蒟蒻求助
查看原帖
蒟蒻求助
159959
虫洞吞噬者楼主2022/3/9 10:55

如题(RT)。大致思路是树链剖分 + 线段树维护区间覆盖,但不知道是哪里写错了 QAQ

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of the per-vertex arrays below; the edge and segment-tree arrays
// are sized as multiples of it.
#define N 400100
// From this point on every `int` token in the file expands to `long long`;
// hence main is declared `signed main` and all I/O uses %lld.
#define int long long
// n, m: the two header values read first in main (n bounds the vertex loops
//       and the segment tree; m is the number of operations processed).
// cnt:  number of adjacency-list entries handed out by add().
// tot:  running DFS-order counter advanced in dfs2().
int n,m,cnt,tot;
// head: index of the first adjacency entry of a vertex (0 = empty list).
// siz:  subtree size computed by dfs1().      fa:  parent recorded by dfs1().
// son:  child with the largest subtree (0 = none), chosen in dfs1().
// dfn:  DFS-order index assigned by dfs2().   rk:  inverse of dfn (index -> vertex).
// tp:   chain-top vertex recorded by dfs2().
// num:  per-vertex input value; build() uses it as a bit position.
// dep:  declared but neither written nor read by the code in this file.
int head[N],siz[N],dep[N],dfn[N],tp[N],fa[N],son[N],rk[N],num[N];
// One directed entry in the linked adjacency lists filled by add().
struct Edge{
	int nxt;	// index of the next entry in the same vertex's list (0 ends the list)
	int to;		// vertex this entry points at
}edge[N*2];
// One segment-tree node; nodes are stored heap-style (children of id are id*2 and id*2+1).
struct Tree{
	int l;		// left end of the covered DFS-order interval (set in build)
	int r;		// right end of the covered DFS-order interval (set in build)
	int c;		// bitwise OR of the single-bit masks stored in this interval
	int lz;		// pending "assign this mask" tag for the children; 0 means no tag
}tree[N*4];
void add(int from,int to)
{
	edge[++cnt].nxt=head[from];
	edge[cnt].to=to;
	head[from]=cnt;
}
// First traversal: records the parent of s, accumulates its subtree size and
// picks the child with the largest subtree as son[s].
//   s   - vertex being visited
//   pre - vertex we arrived from (skipped when walking s's adjacency list)
void dfs1(int s,int pre)
{
	siz[s]=1;
	fa[s]=pre;
	int e=head[s];
	while(e)
	{
		const int child=edge[e].to;
		e=edge[e].nxt;
		if(child!=pre)
		{
			dfs1(child,s);
			siz[s]+=siz[child];
			// strict comparison: on ties the earlier-visited child stays the heavy one
			if(siz[son[s]]<siz[child])son[s]=child;
		}
	}
}
// Second traversal: assigns DFS-order indices (dfn / rk) and the chain top of
// every vertex. The heavy child is visited before the others, so it receives
// the index right after s.
//   s   - vertex being visited
//   top - first vertex of the chain s belongs to
void dfs2(int s,int top)
{
	tp[s]=top;
	rk[++tot]=s;
	dfn[s]=tot;
	const int heavy=son[s];
	if(heavy!=0)dfs2(heavy,top);
	for(int e=head[s];e!=0;e=edge[e].nxt)
	{
		const int child=edge[e].to;
		// every remaining child starts a chain of its own
		if(child!=fa[s]&&child!=heavy)dfs2(child,child);
	}
}
// Recomputes node id's mask as the OR of its two children's masks.
void pushup(int id)
{
	const int lc=id*2;
	const int rc=lc+1;
	tree[id].c=tree[lc].c|tree[rc].c;
}
// Moves a pending assignment tag from node id to both children: each child
// takes the tag as its mask and as its own pending tag, then id's tag is cleared.
void pushdown(int id)
{
	const int tag=tree[id].lz;
	if(tag!=0)
	{
		for(int child=id*2;child<=id*2+1;++child)
		{
			tree[child].lz=tag;
			tree[child].c=tag;
		}
		tree[id].lz=0;
	}
}
void build(int id,int l,int r)
{
	tree[id].l=l;tree[id].r=r;
	if(l==r)
	{
		int cur=num[rk[l]];
		tree[id].c=1ll<<cur;
		return;
	}
	int mid=(l+r)/2;
	build(id*2,l,mid);
	build(id*2+1,mid+1,r);
	pushup(id);
}
void change(int id,int l,int r,int k)
{
	if(l<=tree[id].l&&tree[id].r<=r)
	{
		tree[id].c=1ll<<k;
		tree[id].lz=1ll<<k;
		return;
	}
	pushdown(id);
	int mid=(tree[id].l+tree[id].r)/2;
	if(l<=mid)change(id*2,l,r,k);
	if(r>mid)change(id*2+1,l,r,k);
	pushup(id);
}
int find(int id,int l,int r)
{
	if(l<=tree[id].l&&tree[id].r<=r)return tree[id].c;
	int ans=0;
	int mid=(tree[id].l+tree[id].r)/2;
	pushdown(id);
	if(l<=mid)ans|=find(id*2,l,r);
	if(r>mid)ans|=find(id*2+1,l,r);
	return ans;
}
signed main()
{
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n;++i)scanf("%lld",&num[i]);
	for(int i=1;i<n;++i)
	{
		int x,y;
		scanf("%lld%lld",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(1,0);
	dfs2(1,1);
	build(1,1,n);
	for(int i=1;i<=m;++i)
	{
		int k,x,y,z;
		scanf("%lld",&k);
		if(k==1)
		{
			scanf("%lld%lld",&x,&y);
			change(1,dfn[x],dfn[x]+siz[x]-1,y);
		}
		else if(k==2)
		{
			scanf("%lld",&x);
			int qwq=find(1,dfn[x],dfn[x]+siz[x]-1);
			int cnt=0;
			while(qwq)
			{
				if(qwq&1)++cnt;
				qwq/=2;
			}
			printf("%lld\n",cnt);
		}
	}
	return 0;
}
2022/3/9 10:55
加载中...