为什么求答案的get_t函数里必须用这种带LCA的办法,而用注释里树剖模板的方式去求会炸呢?
#include<bits/stdc++.h>
using namespace std;
const int MAXN=100005;
struct wzy{
int to;
int nxt;
}edge[MAXN<<1];
struct xjc{
int l;
int r;
int lval;
int rval;
int len;
int ans;
int tag;
}t[MAXN<<2];
int ccb,n,m,cnt=0,head[MAXN];
int dep[MAXN],dfn[MAXN],rnk[MAXN],fa[MAXN],size[MAXN],top[MAXN],hson[MAXN];
xjc zero(xjc x){
x.ans=x.l=x.len=x.lval=x.r=x.rval=x.tag=0;
return x;
}
void build114514(int a,int b){
edge[++cnt].to=b;
edge[cnt].nxt=head[a];
head[a]=cnt;
return ;
}
xjc merge(xjc a,xjc b){
xjc num;
num.lval=a.lval;
num.rval=b.rval;
num.ans=a.ans+b.ans;
if(a.rval==b.lval) num.ans++;
num.len=a.len+b.len;
return num;
}
void pushup(int p){
t[p]=merge(t[p<<1],t[p<<1|1]);
t[p].l=t[p<<1].l;
t[p].r=t[p<<1|1].r;
t[p].tag=0;
return ;
}
void add(int p,int h){
t[p].ans=t[p].len-1;
t[p].lval=h;
t[p].rval=h;
t[p].tag=h;
return ;
}
void pushdown(int p){
if(t[p].tag==0) return ;
add(p<<1,t[p].tag);
add(p<<1|1,t[p].tag);
t[p].tag=0;
return ;
}
void dfs1(int u,int f){
dep[u]=dep[f]+1;
size[u]=1;
fa[u]=f;
int maxnum=0;
for(int i=head[u];i!=-1;i=edge[i].nxt){
int v=edge[i].to;
if(v==f) continue;
dfs1(v,u);
if(size[v]>maxnum){
maxnum=size[v];
hson[u]=v;
}
size[u]+=size[v];
}
return ;
}
void dfs2(int u,int topf){
dfn[u]=++cnt;
rnk[cnt]=u;
top[u]=topf;
if(hson[u]==0) return ;
dfs2(hson[u],topf);
for(int i=head[u];i!=-1;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa[u]||v==hson[u]) continue;
dfs2(v,v);
}
return ;
}
int lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y]) return x;
else return y;
}
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
t[p].len=r-l+1;
t[p].ans=0;
t[p].tag=0;
if(l==r){
t[p].lval=l;
t[p].rval=l;
return ;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
pushup(p);
return ;
}
void change(int p,int L,int R,int h){
if(t[p].l>=L&&t[p].r<=R){
add(p,h);
return ;
}
pushdown(p);
int mid=(t[p].l+t[p].r)>>1;
if(L<=mid) change(p<<1,L,R,h);
if(R>mid) change(p<<1|1,L,R,h);
pushup(p);
return ;
}
xjc get(int p,int L,int R){
if(t[p].l>=L&&t[p].r<=R){
return t[p];
}
pushdown(p);
int mid=(t[p].l+t[p].r)>>1;
if(R<=mid) return get(p<<1,L,R);
if(L>mid) return get(p<<1|1,L,R);
return merge(get(p<<1,L,R),get(p<<1|1,L,R));
}
void change_t(int x,int y,int h){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
change(1,dfn[top[x]],dfn[x],h);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
change(1,dfn[x],dfn[y],h);
return ;
}
int get_t(int x,int y){
xjc al,al2;
al=zero(al);
al2=zero(al2);
int man=lca(x,y);
/*while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
al=merge(get(1,dfn[top[x]],dfn[x]),al);
al=merge(get(1,dfn[fa[top[x]]],dfn[top[x]]),al);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
if(x!=y) al=merge(get(1,dfn[x],dfn[y]),al);*/
while(top[x]!=top[man]){
al=merge(get(1,dfn[top[x]],dfn[x]),al);
x=fa[top[x]];
}
al=merge(get(1,dfn[man],dfn[x]),al);
while(top[y]!=top[man]){
al2=merge(get(1,dfn[top[y]],dfn[y]),al2);
y=fa[top[y]];
}
al2=merge(get(1,dfn[man],dfn[y]),al2);
return al.ans+al2.ans;
}
int main(){
scanf("%d",&ccb);
for(int z=1;z<=ccb;z++){
memset(head,-1,sizeof(head));
memset(edge,0,sizeof(edge));
//memset(top,0,sizeof(top));
memset(hson,0,sizeof(hson));
scanf("%d %d",&n,&m);
cnt=0;
for(int i=1,u,v;i<=n-1;i++){
scanf("%d %d",&u,&v);
build114514(u,v);
build114514(v,u);
}
dfs1(1,0);
cnt=0;
dfs2(1,1);
build(1,1,n);
for(int i=1,op,a,b;i<=m;i++){
scanf("%d %d %d",&op,&a,&b);
if(op==1){
change_t(a,b,i+n);
}
if(op==2) printf("%d\n",get_t(a,b));
}
}
return 0;
}