[BZOJ4127] Abs

题目描述

Description

给定一棵树,设计数据结构支持以下操作
1 u v d  表示将路径 (u,v) 加d
2 u v 表示询问路径 (u,v) 上点权绝对值的和

Input

第一行两个整数n和m,表示结点个数和操作数
接下来一行n个整数a_i,表示点i的权值
接下来n-1行,每行两个整数u,v表示存在一条(u,v)的边
接下来m行,每行一个操作,输入格式见题目描述

Output

对于每个询问输出答案

Sample Input

4 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4

Sample Output

10
13
9

HINT

对于100%的数据,n,m <= 10^5 且 0<= d,|a_i|<= 10^8

题目分析

树剖裸题无疑 但是线段树部分要怎么写?
注意到边的权值只会改大不会改小 那我们可以这样做:
存储区间最大的负数 区间正数的和 区间负数的和 区间正数的个数 区间负数的个数 lazy标记
当一段区间上累加的lazy标记大于等于区间最大负数的绝对值时,暴力重构子树(其实就是把lazy暴力推下去)
否则就根据自己储存的内容直接修改即可
时间均摊复杂度

#include <cstdio>
#include <cstring>
#include <set>
#include <map>
#include <vector>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
int n,m;
int tot;
int head[100010],to[200010],net[200010],val[200010];
void add(int x,int y)
{
    net[++tot]=head[x],head[x]=tot,to[tot]=y;
}
int deep[100100],fa[100100],size[100010],top[100010],son[100100];
int number[100100],value[100100],cnt;
void dfs(int x)
{
    deep[x]=deep[fa[x]]+1,size[x]=1;
    for(int i=head[x];i;i=net[i])
        if(to[i]!=fa[x])
        {
            fa[to[i]]=x,dfs(to[i]);
            size[x]+=size[to[i]];
            if(size[son[x]]<size[to[i]]) son[x]=to[i];
        }
}
void dfs2(int x,int temp)
{
    top[x]=temp,number[x]=++cnt;
    value[cnt]=val[x];
    if(son[x]) dfs2(son[x],temp);
    for(int i=head[x];i;i=net[i])
        if(to[i]!=son[x]&&to[i]!=fa[x])
            dfs2(to[i],to[i]);
}
struct your
{
    int x,y,num[2];
    long long sum[2];
    int maxx;
    long long add;
}a[100010*4];
void pushup(int num)
{
    a[num].num[1]=a[num<<1].num[1]+a[num<<1|1].num[1];
    a[num].num[0]=a[num<<1].num[0]+a[num<<1|1].num[0];
    a[num].sum[1]=(long long) a[num<<1].sum[1]+a[num<<1|1].sum[1];
    a[num].sum[0]=(long long) a[num<<1].sum[0]+a[num<<1|1].sum[0];
    int tmp=-1*0x7f7f7f7f;
    if(a[num<<1].maxx<0) tmp=max(tmp,a[num<<1].maxx);
    if(a[num<<1|1].maxx<0) tmp=max(tmp,a[num<<1|1].maxx);
    if(tmp==-0x7f7f7f7f) a[num].maxx=0;
    else a[num].maxx=tmp;
}
void build(int dx,int dy,int num)
{
    a[num].x=dx,a[num].y=dy;
    if(dx==dy)
    {
        a[num].sum[(value[dx]<0)?0:1]=value[dx];
        if(value[dx]<0) a[num].maxx=value[dx];
        else a[num].maxx=0;
        a[num].num[value[dx]>=0]=1;
        return  ;
    }
    int mid=(dx+dy)>>1;
    build(dx,mid,num<<1),build(mid+1,dy,num<<1|1);
    pushup(num);
}
void pushdown(int num)
{
    a[num<<1].sum[0]+=(long long) a[num<<1].num[0]*a[num].add;
    a[num<<1].sum[1]+=(long long) a[num<<1].num[1]*a[num].add;
    a[num<<1].add+=a[num].add,a[num<<1].maxx+=a[num].add;
    a[num<<1|1].sum[0]+=(long long) a[num<<1|1].num[0]*a[num].add;
    a[num<<1|1].sum[1]+=(long long) a[num<<1|1].num[1]*a[num].add;
    a[num<<1|1].add+=a[num].add,a[num<<1|1].maxx+=a[num].add;
    a[num].add=0;
}
void rebuild(int dx,int dy,int c,int num)
{
    if(dx==dy)
    {
        int tmp=a[num].maxx+c;
        a[num].sum[(tmp<0)?0:1]=tmp;
        a[num].sum[(tmp<0)?1:0]=0;
        a[num].num[(tmp<0)?0:1]=1;
        a[num].num[(tmp<0)?1:0]=0;
        a[num].maxx=0;
        return ;
    }
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(a[num<<1].maxx<0&&a[num<<1].maxx*-1<=c) rebuild(dx,mid,c,num<<1);
    else
    {
        a[num<<1].sum[0]+=(long long) a[num<<1].num[0]*c;
        a[num<<1].sum[1]+=(long long) a[num<<1].num[1]*c;
        a[num<<1].add+=c,a[num<<1].maxx+=c;
    }
    if(a[num<<1|1].maxx<0&&a[num<<1|1].maxx*-1<=c) rebuild(mid+1,dy,c,num<<1|1);
    else
    {
        a[num<<1|1].sum[0]+=(long long) a[num<<1|1].num[0]*c;
        a[num<<1|1].sum[1]+=(long long) a[num<<1|1].num[1]*c;
        a[num<<1|1].add+=c,a[num<<1|1].maxx+=c;
    }
    pushup(num);
}
void update(int dx,int dy,int c,int num)
{
    if(a[num].x==dx&&a[num].y==dy)
    {
        if(a[num].maxx<0&&-1*a[num].maxx<=c) 
            rebuild(dx,dy,c,num);
        else
        {
            a[num].sum[0]+=(long long) a[num].num[0]*c;
            a[num].sum[1]+=(long long) a[num].num[1]*c;
            a[num].add+=c,a[num].maxx+=c;
        }
        return ;
    }
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dy<=mid) update(dx,dy,c,num<<1);
    else if(dx>mid) update(dx,dy,c,num<<1|1);
    else update(dx,mid,c,num<<1),update(mid+1,dy,c,num<<1|1);
    pushup(num);
}
long long ask(int dx,int dy,int num)
{

    if(a[num].x==dx&&a[num].y==dy) 
        return abs(a[num].sum[0])+abs(a[num].sum[1]);
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dy<=mid) return ask(dx,dy,num<<1);
    else if(dx>mid) return ask(dx,dy,num<<1|1);
    else return ask(dx,mid,num<<1)+ask(mid+1,dy,num<<1|1); 
}
void change(int x,int y,int c)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        update(number[top[x]],number[x],c,1);
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    update(number[x],number[y],c,1);
}
long long find(int x,int y)
{
    long long sum=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        sum+=ask(number[top[x]],number[x],1);
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    sum+=ask(number[x],number[y],1);
    return sum;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&val[i]);
    for(int x,y,i=1;i<n;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    dfs(1),dfs2(1,1),build(1,n,1);
    for(int x,y,c,tmp,i=1;i<=m;i++)
    {
        scanf("%d",&tmp);
        if(tmp==1)
            scanf("%d%d%d",&x,&y,&c),change(x,y,c);
        else scanf("%d%d",&x,&y),printf("%lld\n",find(x,y));
    }
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注