HDU 5893 List wants to travel(树链剖分)

forever97 posted @ 2016年9月18日 23:12 in 数据结构-树链剖分 with tags 树链剖分 线段树 , 673 阅读

 

【题目链接】 http://acm.hdu.edu.cn/showproblem.php?pid=5893

 

【题目大意】

   给出一棵树,每条边上都有一个边权,现在有两个操作,操作一要求将x到y路径上所有边更改为c权值,操作二要求查询x到y的路径上有几段连续的权值相同的。

 

【题解】

   首先由于是边权,所以把所有边的存下来,做一遍剖分,将权值保存在每条边深度较深的点上,作为点权,用区间合并线段树维护区间内的线段段数,沿链修改的时候注意剖分出的区间的起点是不更新的,因为边权变成点权之后链修改的LCA是不修改的。查询的时候由于边权转点权之后点权位置的特殊性,我们每次在查询a到b之间的答案的时候,首先求出两者的LCA,同时求出LCA到a和b之间的第一个点,求分别求出a和b与其第二root之间的答案,再判断一下交接处的情况,就能计算出答案。

 

【代码】

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=1000000; 
int tot,x,d[N],num[N],ed=0,u,w,n,m,i,v[N],vis[N],f[N],g[N],nxt[N],size[N],son[N],st[N],en[N],dfn,top[N],t;char ch;
void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
void dfs(int x){
    size[x]=1;
    for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){
        f[v[i]]=x,d[v[i]]=d[x]+1;
        dfs(v[i]),size[x]+=size[v[i]];
        if(size[v[i]]>size[son[x]])son[x]=v[i];
    }
}
void dfs2(int x,int y){
    if(x==-1)return;
    st[x]=++dfn;top[x]=y;
    if(son[x])dfs2(son[x],y);
    for(int i=g[x];i;i=nxt[i])if(v[i]!=son[x]&&v[i]!=f[x])dfs2(v[i],v[i]);
    en[x]=dfn;
}
int T[N<<2],mark[N<<2],cl[N<<2],cr[N<<2],L,R;
void up(int x){
    T[x]=T[x<<1]+T[x<<1|1]-(cr[x<<1]==cl[x<<1|1]);
    cl[x]=cl[x<<1];
    cr[x]=cr[x<<1|1];
}
void pushdown(int x,int l,int r){
    if(l==r)return;
    if(mark[x]!=-1){
        mark[x<<1]=mark[x<<1|1]=mark[x];
        cl[x<<1]=cl[x<<1|1]=mark[x];
        cr[x<<1]=cr[x<<1|1]=mark[x];
        T[x<<1]=T[x<<1|1]=1;
        mark[x]=-1;
    }
}
void update(int x,int l,int r,int c){
    pushdown(x,l,r);
    if(L<=l&&r<=R){T[x]=1;mark[x]=cl[x]=cr[x]=c;return;}
    int mid=(l+r)/2;
    if(L<=mid)update(x<<1,l,mid,c);
    if(mid<R)update(x<<1|1,mid+1,r,c);
    up(x);
}
void update(int l,int r,int c){
    if(l>r)return;
    L=l;R=r; update(1,1,n,c);
}
int query(int x,int l,int r){
    pushdown(x,l,r);
    if(L<=l&&r<=R)return T[x];
    int mid=(l+r)/2,ret=0;
    if(L<=mid)ret+=query(x<<1,l,mid);
    if(mid<R)ret+=query(x<<1|1,mid+1,r);
    if(L<=mid&&mid<R)ret-=(cr[x<<1]==cl[x<<1|1]);
    return ret;
}
int query(int l,int r){L=l;R=r;return query(1,1,n);}
int color(int x,int l,int r,int f){
    if(l==r)return cl[x];
	  pushdown(x,l,r);
	  int mid=(l+r)/2;
	  if(f<=mid)return color(x<<1,l,mid,f);
	  return color(x<<1|1,mid+1,r,f);
}
int query(int l){return color(1,1,n,l);}
void chain(int x,int y,int c){
    for(;top[x]!=top[y];x=f[top[x]]){
        if(d[top[x]]<d[top[y]]){int z=x;x=y;y=z;}
        update(st[top[x]],st[x],c);
    }if(d[x]<d[y]){int z=x;x=y;y=z;}
    update(st[y]+1,st[x],c);
}
int find(int x,int y){
	  int ret=0;
    for(;top[x]!=top[y];x=f[top[x]]){
        if(d[top[x]]<d[top[y]]){int z=x;x=y;y=z;}
        ret+=query(st[top[x]],st[x]);
        ret-=(query(st[top[x]])==query(st[f[top[x]]]));
    }if(d[x]<d[y]){int z=x;x=y;y=z;}
    ret+=query(st[y],st[x]);
    return ret;
}
int lca(int x,int y){
    for(;top[x]!=top[y];x=f[top[x]])if(d[top[x]]<d[top[y]]){int z=x;x=y;y=z;}
    return d[x]<d[y]?x:y;
}
int lca2(int x,int y){
    int t;
    while(top[x]!=top[y])t=top[y],y=f[top[y]];
    return x==y?t:son[x];
}
void init(){ 
    for(int i=0;i<n*4;i++)T[i]=1,mark[i]=-1;
    memset(g,dfn=ed=0,sizeof(g));
    memset(v,0,sizeof(v));
    memset(nxt,0,sizeof(nxt));
    memset(son,-1,sizeof(son));
}
int cas;
int e[N][3];
int main(){
    while(~scanf("%d%d",&n,&m)){
        init();
        for(int i=0;i<n-1;i++){
            scanf("%d%d%d",&e[i][0],&e[i][1],&e[i][2]);
            add(e[i][0],e[i][1]);
            add(e[i][1],e[i][0]);
        }dfs(1);dfs2(1,1);
        for(int i=0;i<n-1;i++){
            if(d[e[i][0]]>d[e[i][1]])swap(e[i][0],e[i][1]);
            update(st[e[i][1]],st[e[i][1]],e[i][2]);
        }char op[10]; int a,b,c;
        //for(int i=1;i<=n;i++)printf("%d\n",st[i]);
        while(m--){
            scanf("%s",op);
            scanf("%d%d",&a,&b);
            if(op[0]=='Q'){
                c=lca(a,b);
                int fa=lca2(c,a);
                int fb=lca2(c,b);
                //printf("%d %d %d\n",c,fa,fb);
                //printf("%d %d\n",query(st[fa]),query(st[fb])); 
                if(a==b)puts("0");
                else if(c==a)printf("%d\n",find(b,fb));
                else if(c==b)printf("%d\n",find(a,fa));
                else printf("%d\n",find(a,fa)+find(b,fb)-(query(st[fa])==query(st[fb])));
            }else{
                scanf("%d",&c);
                chain(a,b,c);
            }
        }
    }return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter