调了一天,请各位大佬帮忙,感激不尽。
评测记录 从LOJ下数据分析:
TLE:爆栈(树剖dfs1时)
WA:几次询问答案小1或2(怀疑是跳到同一链时合并有误)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned int uint;
const int MAX_V=1e5+10;
vector<int> G[MAX_V];
void addedge(int u,int v)
{
G[u].push_back(v);
G[v].push_back(u);
}
//使用结构体进行两段的合并
struct Result{int l,r,cnt;};//cnt<=0为空线段
Result merge(Result a,Result b)
{
if(a.cnt<=0)return b;
if(b.cnt<=0)return a;
Result r{a.l,b.r,a.cnt+b.cnt};
if(a.r==b.l)r.cnt--;//相同则减一
return r;
}
//线段树部分
uint a,b;
int color,N;
Result dat[262144];
void init(int n,int color[])
{
N=1;while(N<n)N<<=1;
for(int i=0;i<n;i++)dat[i+N-1]=Result{color[i],color[i],1};
for(int i=N-2;i>=0;i--)dat[i]=merge(dat[i*2+1],dat[i*2+2]);
}
void change(uint k,uint l,uint r)
{
if(b<=l||a>=r)return;
if(a<=l&&r<=b){
dat[k]=Result{color,color,1};
return;
}
if(dat[k].cnt==1)dat[k*2+1]=dat[k*2+2]=dat[k];/相当于pushdown
change(k*2+1,l,(l+r)>>1);
change(k*2+2,(l+r)>>1,r);
dat[k]=merge(dat[k*2+1],dat[k*2+2]);//pushup
}
void change(uint l,uint r)
{
a=l;b=r+1;
change(0,0,N);
}
Result query(uint k,uint l,uint r)
{
if(b<=l||a>=r)return Result{0,0,0};//空段返回cnt=0
if(a<=l&&r<=b)return dat[k];
if(dat[k].cnt==1)return dat[k];
return merge(query(k*2+1,l,(l+r)>>1),query(k*2+2,(l+r)>>1,r));
}
Result query(uint l,uint r)
{
assert(l>=0);
a=l;b=r+1;
return query(0,0,N);
}
//树剖部分
int father[MAX_V],depth[MAX_V],siz[MAX_V],son[MAX_V],top[MAX_V],seg[MAX_V],segsize;
void dfs1(int u,int p)
{
father[u]=p;
depth[u]=depth[p]+1;
siz[u]=1;
int s=0;
for(int v:G[u])
if(v!=p)
{
dfs1(v,u);
siz[u]+=siz[v];
if(siz[s]<siz[v])s=v;
}
son[u]=s;
}
void dfs2(int u)
{
seg[u]=segsize++;
if(son[u])
{
top[son[u]]=top[u];
dfs2(son[u]);
}
for(int v:G[u])
if(top[v]<0)
{
top[v]=v;
dfs2(v);
}
}
int readint()
{
register int x=0;
register char c=getchar();
while(c<'-')c=getchar();
for(;c>='0'&&c<='9';c=getchar())
x=x*10+c-'0';
return x;
}
int color1[MAX_V],color2[MAX_V];
int main()
{
#ifndef ONLINE_JUDGE
freopen("stdin.txt", "r", stdin);
#endif
//初始化
memset(top,-1,sizeof(top));
int n=readint(),m=readint();
for(int i=0;i<n;i++)color1[i]=readint();
for(int i=1;i<n;i++)addedge(readint()-1,readint()-1);
top[0]=0;
dfs1(0,-1);
dfs2(0);
for(int i=0;i<n;i++)color2[seg[i]]=color1[i];
init(n,color2);
//处理操作
while(m--)
{
char c=getchar();
while(c<'-')c=getchar();
int u=readint()-1,v=readint()-1;
if(c=='C'){//更改操作
color=readint();
while(top[u]!=top[v])
{
if(depth[top[u]]<depth[top[v]])swap(u,v);
change(seg[top[u]],seg[u]);
u=father[top[u]];
}
if(depth[u]>depth[v])swap(u,v);
change(seg[u],seg[v]);
}else{//查询操作
Result ru=Result{0,0,0},rv=ru;
while(top[u]!=top[v])
{
if(depth[top[u]]<depth[top[v]])swap(u,v),swap(ru,rv);
ru=merge(query(seg[top[u]],seg[u]),ru);
u=father[top[u]];
}
if(u!=v){
if(depth[u]>depth[v])swap(u,v),swap(ru,rv);
rv=merge(query(seg[u],seg[v]),rv);
}
int ans=ru.cnt+rv.cnt;
if(ru.l==rv.l)ans--;//相同则减一
printf("%d\n",ans);
}
}
return 0;
}