[BZOJ2243] [SDOI2011]染色

题目描述

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示xy之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

题目分析:

裸树剖+线段树区间合并 合并时需要判断一下左儿子的右端点和右儿子的左端点是否相同
并且在查找时候还要注意讨论color[fa[top[x]]]和color[top[x]]是否相同
一开始点权写错了竟然还能过样例QAQ

#include <cstdio>
#include <cstring>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
int n,m;
int net[500010],head[500010],to[500010];
int val[500010],top[500010],number[500010],v[500010];
int fa[500010],size[500010],son[500010],deep[500010];
int tot,cnt;
struct your
{
    int x,y;
    int color;
    int l,r;
    int add;
}a[500000];
void add(int x,int y)
{
    net[++tot]=head[x];
    to[tot]=y;
    head[x]=tot;
}
void dfs(int x,int temp)
{
    fa[x]=temp;
    deep[x]=deep[temp]+1;
    size[x]=1;
    for(int i=head[x];i;i=net[i])
    {
        if(to[i]==temp) continue;
        dfs(to[i],x);
        size[x]+=size[to[i]];
        if(size[son[x]]<size[to[i]]) son[x]=to[i];
    }
}
void dfs2(int x,int temp)
{
    number[x]=++cnt;
    top[x]=temp;
    val[number[x]]=v[x];
    if(son[x]) dfs2(son[x],temp);
    for(int i=head[x];i;i=net[i])
        if(to[i]!=son[x]&&to[i]!=fa[x])
            dfs2(to[i],to[i]);
}
void build(int dx,int dy,int num)
{
    a[num].x=dx,a[num].y=dy;
    if(dx==dy)
    {
        a[num].color=1;
        a[num].l=a[num].r=val[dx];
        return ;
    }
    int mid=(dx+dy)>>1;
    build(dx,mid,num<<1);
    build(mid+1,dy,num<<1|1);
    a[num].l=a[num<<1].l,a[num].r=a[num<<1|1].r;
    int tmp=a[num<<1].color+a[num<<1|1].color;
    a[num].color=(a[num<<1].r==a[num<<1|1].l)?tmp-1:tmp;
}
void pushdown(int num)
{
    a[num<<1].add=a[num<<1|1].add=a[num].add;
    a[num<<1].color=a[num<<1|1].color=1;
    a[num<<1].l=a[num<<1].r=a[num].add;
    a[num<<1|1].l=a[num<<1|1].r=a[num].add;
    a[num].add=0;
}
void update(int dx,int dy,int c,int num)
{
    if(a[num].x==dx&&a[num].y==dy)
    {
        a[num].color=1;
        a[num].add=c;
        a[num].l=a[num].r=c;
        return ;
    }
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dx>mid) update(dx,dy,c,num<<1|1);
    else if(dy<=mid) update(dx,dy,c,num<<1);
    else update(dx,mid,c,num<<1),update(mid+1,dy,c,num<<1|1);
    a[num].l=a[num<<1].l,a[num].r=a[num<<1|1].r;
    a[num].color=a[num<<1].color+a[num<<1|1].color-(a[num<<1].r==a[num<<1|1].l);
}
void change(int x,int y,int c)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        update(number[top[x]],number[x],c,1);
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    update(number[x],number[y],c,1);
}
int ask(int dx,int dy,int num)
{
    if(a[num].x==dx&&a[num].y==dy) return a[num].color;
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dx>mid) return ask(dx,dy,num<<1|1);
    if(dy<=mid) return ask(dx,dy,num<<1);
    return ask(dx,mid,num<<1)+ask(mid+1,dy,num<<1|1)-(a[num<<1].r==a[num<<1|1].l);
}
int findcolor(int dx,int num)
{
    if(a[num].x==dx&&a[num].y==dx) return a[num].l;
    if(a[num].add) pushdown(num);
    int mid=(a[num].x+a[num].y)>>1;
    if(dx<=mid) return findcolor(dx,num<<1);
    return findcolor(dx,num<<1|1);
}
int find(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        ans+=ask(number[top[x]],number[x],1);
        if(findcolor(number[top[x]],1)==findcolor(number[fa[top[x]]],1)) ans--;
        x=fa[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    ans+=ask(number[x],number[y],1);
    return ans;
}
char s[5];
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&v[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    dfs(1,0);
    dfs2(1,1);
    build(1,n,1);
    for(int i=1;i<=m;i++)
    {
        scanf("%s",&s[0]);
        if(s[0]=='C')
        {
            int x,y,c;
            scanf("%d%d%d",&x,&y,&c);
            change(x,y,c);
        }
        else 
        {
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",find(x,y));
        }
    }
}

发表评论

邮箱地址不会被公开。 必填项已用*标注