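// RT (as the title says): the approach is heavy-light decomposition plus a segment tree supporting range assignment, but I can't find where it goes wrong. QAQ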
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<vector>
using namespace std;
// Array bound for vertices; presumably n <= 4e5 in the problem - confirm.
#define N 400100
// NOTE: every later `int` in this file is really `long long` (hence `signed main`
// and the %lld formats); the 64-bit colour masks rely on this width.
#define int long long
// n: vertex count, m: query count, cnt: number of edges stored so far,
// tot: last DFS timestamp handed out.
int n,m,cnt,tot;
// Per-vertex arrays (1-indexed):
//   head - first edge of the vertex's adjacency list (0 = no edges)
//   siz  - subtree size
//   dep  - declared but never assigned in the visible code
//   dfn  - DFS timestamp; rk is its inverse (rk[dfn[v]] == v)
//   tp   - top of the vertex's heavy chain (computed, but not read by the queries shown)
//   fa   - parent, son - heavy child (0 = none), num - initial colour
int head[N],siz[N],dep[N],dfn[N],tp[N],fa[N],son[N],rk[N],num[N];
// Adjacency-list edge: nxt = next edge out of the same vertex (0 = end), to = endpoint.
// Each undirected edge is stored in both directions, hence N*2.
struct Edge{
int nxt,to;
}edge[N*2];
// Segment-tree node covering DFS positions [l, r]:
//   c  - bitmask of colours present in the range (bit k set <=> colour k present)
//   lz - pending "assign whole range" tag, stored as a one-hot mask (0 = none)
// 4*N nodes suffice for a recursively built tree over N leaves.
struct Tree{
int l,r,c,lz;
}tree[N*4];
// Appends the directed edge from->to to the front of `from`'s adjacency list.
// Edge indices start at 1 so that 0 can terminate the list.
void add(int from,int to)
{
    int id=++cnt;
    edge[id].to=to;
    edge[id].nxt=head[from];
    head[from]=id;
}
// Computes fa[], siz[] and son[] (heavy child) for the subtree rooted at s,
// whose parent is `pre` (0 for the root).
//
// Uses an explicit stack instead of recursion. The recursive form nests once
// per tree level, so a path-shaped tree with n near N (~4e5) needs ~4e5 frames,
// which may overflow a default-size call stack. Children are still visited in
// adjacency-list order, and each child is finished before the next sibling
// starts. So siz[] and son[] match the recursive version exactly, including
// ties: the first child with the strictly largest subtree becomes son.
void dfs1(int s,int pre)
{
    // Each entry: (vertex, next edge index of that vertex still to scan; 0 = done).
    std::vector<std::pair<int,int> > stk;
    fa[s]=pre;
    siz[s]=1;
    stk.push_back(std::pair<int,int>(s,head[s]));
    while(!stk.empty())
    {
        int u=stk.back().first;
        int e=stk.back().second;
        if(e)
        {
            // Advance u's edge cursor before (possibly) descending.
            stk.back().second=edge[e].nxt;
            int v=edge[e].to;
            if(v==fa[u])continue;
            fa[v]=u;
            siz[v]=1;
            stk.push_back(std::pair<int,int>(v,head[v]));
            continue;
        }
        // Every edge of u has been scanned: its subtree is complete.
        stk.pop_back();
        if(stk.empty())break;
        // Fold u into its parent immediately, as the recursive version did on return.
        int p=stk.back().first;
        siz[p]+=siz[u];
        if(siz[u]>siz[son[p]])son[p]=u;
    }
}
// Assigns DFS timestamps dfn[] (inverse in rk[]) and chain tops tp[] for the
// subtree rooted at s, whose chain is headed by `top`. The heavy child son[u]
// is always visited before any light child, so every subtree occupies the
// contiguous timestamp range [dfn[u], dfn[u]+siz[u]-1].
//
// Uses an explicit stack instead of recursion. The recursion depth would equal
// the tree height, which can reach ~4e5 on a path and may overflow a
// default-size call stack. The visiting order is identical to the recursive form.
void dfs2(int s,int top)
{
    // Each entry: (vertex, cursor). cursor == -1 means the heavy child has not
    // been visited yet. Otherwise cursor is the next edge index to scan for
    // light children (0 = none left).
    std::vector<std::pair<int,int> > stk;
    bool pending=true;   // is there a vertex waiting to be entered?
    int v=s,vtop=top;    // that vertex and the top of its chain
    while(pending||!stk.empty())
    {
        if(pending)
        {
            // Enter v: give it the next timestamp and record its chain top.
            pending=false;
            dfn[v]=++tot;
            rk[tot]=v;
            tp[v]=vtop;
            stk.push_back(std::pair<int,int>(v,-1));
            continue;
        }
        int u=stk.back().first;
        int e=stk.back().second;
        if(e==-1)
        {
            // The heavy child comes first and continues u's chain.
            stk.back().second=head[u];
            if(son[u])
            {
                pending=true;
                v=son[u];
                vtop=tp[u];
            }
            continue;
        }
        if(!e)
        {
            stk.pop_back();  // u's whole subtree has been timestamped
            continue;
        }
        stk.back().second=edge[e].nxt;
        int w=edge[e].to;
        if(w==fa[u]||w==son[u])continue;
        // A light child starts a new chain headed by itself.
        pending=true;
        v=w;
        vtop=w;
    }
}
// Recomputes node id's colour set as the union (bitwise OR) of its two children.
void pushup(int id)
{
    int lc=id*2;
    int rc=lc+1;
    tree[id].c=tree[lc].c|tree[rc].c;
}
// Hands a pending "assign whole range" tag from node id down to both children.
// A tag of 0 means nothing is pending. Tags written by range assignment are
// one-hot masks, so they are non-zero.
void pushdown(int id)
{
    int tag=tree[id].lz;
    if(tag==0)return;
    for(int ch=id*2;ch<=id*2+1;++ch)
    {
        tree[ch].lz=tag;
        tree[ch].c=tag;
    }
    tree[id].lz=0;
}
// Builds the segment tree over DFS positions [l, r]. The leaf at position p
// holds the one-hot colour mask of vertex rk[p]; inner nodes hold the OR of
// their children.
void build(int id,int l,int r)
{
    tree[id].l=l;
    tree[id].r=r;
    if(l!=r)
    {
        int mid=l+(r-l)/2;
        build(id*2,l,mid);
        build(id*2+1,mid+1,r);
        pushup(id);
        return;
    }
    tree[id].c=1ll<<num[rk[l]];
}
// Assigns colour k to every DFS position in [l, r] that lies under node id.
// Fully covered nodes are overwritten and tagged lazily instead of recursing.
void change(int id,int l,int r,int k)
{
    if(tree[id].l>=l&&tree[id].r<=r)
    {
        tree[id].lz=tree[id].c=1ll<<k;
        return;
    }
    pushdown(id);
    int mid=(tree[id].l+tree[id].r)/2;
    if(mid>=l)change(id*2,l,r,k);
    if(mid<r)change(id*2+1,l,r,k);
    pushup(id);
}
// Returns the union (bitwise OR) of colour masks over DFS positions [l, r]
// that lie under node id.
int find(int id,int l,int r)
{
    if(tree[id].l>=l&&tree[id].r<=r)return tree[id].c;
    pushdown(id);
    int mid=(tree[id].l+tree[id].r)/2;
    int fromLeft=(l<=mid)?find(id*2,l,r):0;
    int fromRight=(mid<r)?find(id*2+1,l,r):0;
    return fromLeft|fromRight;
}
// Reads the tree (rooted at vertex 1) and answers the queries from stdin:
//   "1 x y" - recolour every vertex in the subtree of x with colour y;
//   "2 x"   - print the number of distinct colours in the subtree of x.
// Colours are bits of a 64-bit mask, so they are assumed to lie in [0, 62]
// (presumably 1..60 per the problem - confirm).
// The %lld formats match `int` only because of the file's `#define int long long`.
signed main()
{
    // n < 1 would make build(1,1,n) recurse without ever reaching a leaf.
    if(scanf("%lld%lld",&n,&m)!=2||n<1)return 0;
    for(int i=1;i<=n;++i)
        if(scanf("%lld",&num[i])!=1)return 0;
    for(int i=1;i<n;++i)
    {
        int x,y;
        if(scanf("%lld%lld",&x,&y)!=2)return 0;
        add(x,y);
        add(y,x);
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    for(int i=1;i<=m;++i)
    {
        int k,x,y;
        // Stop on truncated input instead of acting on uninitialised values.
        if(scanf("%lld",&k)!=1)break;
        if(k==1)
        {
            if(scanf("%lld%lld",&x,&y)!=2)break;
            // The subtree of x is the contiguous DFS range [dfn[x], dfn[x]+siz[x]-1].
            change(1,dfn[x],dfn[x]+siz[x]-1,y);
        }
        else if(k==2)
        {
            if(scanf("%lld",&x)!=1)break;
            int mask=find(1,dfn[x],dfn[x]+siz[x]-1);
            // One set bit per distinct colour. Named `colours` so it does not
            // shadow the global edge counter `cnt`.
            int colours=0;
            for(;mask;mask&=mask-1)++colours;  // clears the lowest set bit each step
            printf("%lld\n",colours);
        }
    }
    return 0;
}