Description
给您一颗树,每个节点有个初始值。
现在支持以下两种操作:
1. C i x(0<=x< 2^31) 表示将i节点的值改为x。
2. Q i j x(0<=x< 2^31) 表示询问i节点到j节点的路径上有多少个值为x的节点。
Input
第一行有两个整数N,Q(1 ≤N≤ 100,000;1 ≤Q≤ 200,000),分别表示节点个数和操作个数。
下面一行N个整数,表示初始时每个节点的初始值。
接下来N-1行,每行两个整数x,y,表示x节点与y节点之间有边直接相连(描述一颗树)。
接下来Q行,每行表示一个操作,操作的描述已经在题目描述中给出。
Output
对于每个Q输出单独一行表示所求的答案。
Sample Input
5 6
10 20 30 40 50
1 2
1 3
3 4
3 5
Q 2 3 40
C 1 40
Q 2 3 40
Q 4 5 30
C 3 10
Q 4 5 30
Sample Output
0
1
1
0
题目分析
树剖+动态开点线段树
对于每个权值x开一颗线段树 维护1~n区间
使用map来记录每个root[]的下标
#include <cstdio>
#include <cstring>
#include <set>
#include <map>
#include <vector>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
int n,m;
int val[100100];
int tot;
int head[100100],to[300300],net[300300];
void add(int x,int y)
{
net[++tot]=head[x],head[x]=tot,to[tot]=y;
}
int deep[100100],fa[100100],size[100100];
int top[100100],son[100100],number[100100];
void dfs(int x)
{
deep[x]=deep[fa[x]]+1;
size[x]=1;
for(int i=head[x];i;i=net[i])
if(to[i]!=fa[x])
{
fa[to[i]]=x,dfs(to[i]);
if(size[son[x]]<size[to[i]]) son[x]=to[i];
size[x]+=size[to[i]];
}
}
int cnt;
void dfs2(int x,int temp)
{
top[x]=temp,number[x]=++cnt;
if(son[x]) dfs2(son[x],temp);
for(int i=head[x];i;i=net[i])
if(to[i]!=fa[x]&&to[i]!=son[x])
dfs2(to[i],to[i]);
}
struct your
{
int lson,rson;
int size;
}a[100010*50];
map<int,int>mp;
int nm,color;
int rooot[100100];
void update(int dx,int l,int r,int c,int &x)
{
if(!x) x=++nm;
if(l==r)
{
a[x].size+=c;
return ;
}
int mid=(l+r)>>1;
if(dx>mid) update(dx,mid+1,r,c,a[x].rson);
else update(dx,l,mid,c,a[x].lson);
a[x].size=a[a[x].lson].size+a[a[x].rson].size;
}
int ask(int dx,int dy,int l,int r,int x)
{
if(dx<=l&&r<=dy) return a[x].size;
int mid=(l+r)>>1;
if(dx>mid) return ask(dx,dy,mid+1,r,a[x].rson);
else if(dy<=mid) return ask(dx,dy,l,mid,a[x].lson);
else return ask(dx,mid,l,mid,a[x].lson)+ask(mid+1,dy,mid+1,r,a[x].rson);
}
int find(int x,int y,int c)
{
int sum=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
sum+=ask(number[top[x]],number[x],1,n,rooot[mp[c]]);
x=fa[top[x]];
}
if(deep[x]>deep[y]) swap(x,y);
sum+=ask(number[x],number[y],1,n,rooot[mp[c]]);
return sum;
}
char s[10];
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int x,y,i=1;i<n;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
dfs(1),dfs2(1,1);
for(int i=1;i<=n;i++)
{
if(!mp[val[i]]) mp[val[i]]=++color;
update(number[i],1,n,1,rooot[mp[val[i]]]);
}
for(int x,y,c,i=1;i<=m;i++)
{
scanf("%s",&s[0]);
if(s[0]=='C')
{
scanf("%d%d",&x,&c);
update(number[x],1,n,-1,rooot[mp[val[x]]]);
if(!mp[c]) mp[c]=++color;
val[x]=c;
update(number[x],1,n,1,rooot[mp[val[x]]]);
}
else
{
scanf("%d%d%d",&x,&y,&c);
if(!mp[c]) printf("0\n");
else printf("%d\n",find(x,y,c));
}
}
return 0;
}