[BZOJ3626] [LNOI2014]LCA

题目描述

Description

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

Input

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

Output

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

Sample Input

5 2
0
0
1
1
1 4 3
1 4 2

Sample Output

8
5

HINT

共5组数据,n与q的规模分别为10000,20000,30000,40000,50000。

题目分析:

首先我们考虑一下只有一组区间的情况
我们可以发现,一个点和另一个点的LCA的深度 =在将一个点到1的路径+1后 另一个点与1路径的和
并且这个东西是可以相加减的 即sum[1,l-1]=sum[1,r]-sum[l,r];
这说明每一个区间询问是可以差分的
于是我们想到离线区间,依次将1~n到1的路径+1,遇到一个左端点就减去值 遇到右端点就再加上
这样就过了
使用树链剖分 在O(log^2n)实现查询路径和 更新路径值

#include <cstdio>
#include <cstring>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
int n,m;
int mod=201314;
struct way
{
    int x,c;
    int flag,id;
}q[120000];
int head[60000],net[110000],to[110000];
int tot;
void add(int x,int y)
{
    net[++tot]=head[x];
    to[tot]=y;
    head[x]=tot;
}
int top[60000],number[60000];
int deep[60000],son[60000],size[60000],fa[60000];
int cnt;
void dfs(int x,int temp)
{
    fa[x]=temp;
    size[x]=1;
    deep[x]=deep[temp]+1;
    for(int i=head[x];i;i=net[i])
        if(to[i]!=temp)
        {
            dfs(to[i],x);
            if(size[to[i]]>size[son[x]]) son[x]=to[i];
            size[x]+=size[to[i]];
        }
    return ;
}
void dfs2(int x,int temp)
{
    top[x]=temp;
    number[x]=++cnt;
    if(son[x]) dfs2(son[x],temp);
    for(int i=head[x];i;i=net[i])
        if(to[i]!=fa[x]&to[i]!=son[x])
            dfs2(to[i],to[i]);
}
struct your
{
    int x,y;
    int add;
    int sum,val;
}a[250000];
void pushdown(int num)
{
    a[num<<1].add=(a[num<<1].add+a[num].add)%mod;
    a[num<<1|1].add=(a[num<<1|1].add+a[num].add)%mod;
    a[num<<1].sum=(a[num<<1].sum+(a[num<<1].y-a[num<<1].x+1)%mod*a[num].add%mod)%mod;
    a[num<<1|1].sum=(a[num<<1|1].sum+(a[num<<1|1].y-a[num<<1|1].x+1)%mod*a[num].add%mod)%mod;
    a[num].add=0;
}
void build(int dx,int dy,int num)
{
    a[num].x=dx,a[num].y=dy;
    if(dx==dy) return ;
    int mid=(dx+dy)>>1;
    build(dx,mid,num<<1);
    build(mid+1,dy,num<<1|1);
    return ;
}   
void update(int dx,int dy,int c,int num)
{
    if(a[num].x==dx&a[num].y==dy)
    {
        a[num].add=(a[num].add+c)%mod;
        a[num].sum=(a[num].sum%mod+(a[num].y-a[num].x+1)%mod*c%mod)%mod;
        return;
    }
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dx>mid) update(dx,dy,c,num<<1|1);
    else if(dy<=mid) update(dx,dy,c,num<<1);
    else update(dx,mid,c,num<<1),update(mid+1,dy,c,num<<1|1);
    a[num].sum=(a[num<<1].sum+a[num<<1|1].sum)%mod;
}
int ask(int dx,int dy,int num)
{
    if(a[num].x==dx&a[num].y==dy)
        return a[num].sum%mod;
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dx>mid) return ask(dx,dy,num<<1|1)%mod;
    if(dy<=mid) return ask(dx,dy,num<<1)%mod;
    return (ask(dx,mid,num<<1)+ask(mid+1,dy,num<<1|1))%mod;
}
void change(int x,int y,int c)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        update(number[top[x]],number[x],c,1); 
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    update(number[x],number[y],c,1);
}
int find(int x,int y)
{
    int sum=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        sum=(sum+ask(number[top[x]],number[x],1))%mod;
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    sum=(sum+ask(number[x],number[y],1))%mod;
    return sum;
}
int cmp(way j,way k)
{
    return j.x<k.x;
}
int now=1;
int ans[60000];
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=2;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        x++;
        add(i,x),add(x,i);
    }
    for(int i=1;i<=m;i++)
    {
        int l,r,z;
        scanf("%d%d%d",&l,&r,&z);
        z++;
        q[(i<<1)-1].x=l,q[(i<<1)-1].flag=0,q[(i<<1)-1].c=z,q[(i<<1)-1].id=i;
        q[i<<1].x=r+1,q[i<<1].flag=1,q[i<<1].c=z,q[i<<1].id=i;
    }
    dfs(1,0),dfs2(1,1);
    build(1,n,1);
    sort(q+1,q+2*m+1,cmp);
    for(int i=1;i<=2*m;i++)
    {
        while(now<=q[i].x) change(1,now,1),now++;
        if(!q[i].flag) ans[q[i].id]-=find(1,q[i].c);
        else ans[q[i].id]=(ans[q[i].id]+mod+find(1,q[i].c))%mod;
    }
    for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
    return 0;
}

 

发表评论

邮箱地址不会被公开。 必填项已用*标注