P4074不过阳历求条
查看原帖
P4074不过阳历求条
1601965
fangtongsheng楼主2025/2/5 16:54

已知问题可能出现在时间更新上

含有调试

#include<bits/stdc++.h> 
#define N 400400
using namespace std;
int n,m,qw,Qt;
vector<int> R[N];
int v[N],w[N],c[N];
int s[N],e[N],t[N],st[N][22],dep[N],cnt=0;
struct point {
	int l,r,t,id,lca;
}q[N],p[N];int pl=0,ql=0;
bool cmp(point aa,point bb){
	if(aa.l/Qt!=bb.l/Qt) return aa.l<bb.l;
	if(aa.r/Qt!=bb.r/Qt) return aa.r<bb.r;
	return aa.t<bb.t;
}
int hap[N];
int used[N];
long long answer=0,ans[N];
void add(int pos){
	pos=t[pos];
	if(used[pos]==0){
		hap[c[pos]]++;
		answer+=(long long)w[hap[c[pos]]]*v[c[pos]];
	}
	else {
		answer-=(long long)w[hap[c[pos]]]*v[c[pos]];
		hap[c[pos]]--;
	}
	used[pos]++;
}
void del(int pos){
	pos=t[pos];
	if(used[pos]==1){
		answer-=(long long)w[hap[c[pos]]]*v[c[pos]];
		hap[c[pos]]--;
	}
	else {
		hap[c[pos]]++;
		answer+=(long long)w[hap[c[pos]]]*v[c[pos]];
	}
	used[pos]--;
}
void upd(int pos,int tm) { 
	if(used[p[tm].l]==1) {
		del(p[tm].l);
		swap(c[p[tm].l],p[tm].r);
		add(p[tm].l);
	}
	else swap(c[p[tm].l],p[tm].r);
}
void dfs(int u,int f){
	st[u][0]=f,dep[u]=dep[f]+1;
	for(int i=1;i<=20;i++){
		int p=st[u][i-1];
		if(!p) break;
		st[u][i]=st[p][i-1];
	}
	s[u]=++cnt;
	for(int i=0;i<R[u].size();i++){
		int to=R[u][i];
		if(to==f) continue;
		dfs(to,u);
	}
	e[u]=++cnt;
	t[e[u]]=t[s[u]]=u;
}
int LCA(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=20;i>=0;i--) if(dep[u]-(1<<i)>=dep[v]) u=st[u][i];
	if(u==v) return u;
	for(int i=20;i>=0;i--) if(st[u][i]&&st[v][i]&&st[u][i]!=st[v][i])
		u=st[u][i],v=st[v][i];
	return st[u][0];
}
int main(){
	cin>>n>>m>>qw;
	Qt=pow(n,0.666);
	for(int i=1;i<=m;i++) scanf("%d",&v[i]);
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		R[u].push_back(v),R[v].push_back(u);
	}
	for(int i=1;i<=n;i++) scanf("%d",&c[i]);
	dfs(1,cnt=0);
	for(int i=1;i<=qw;i++){
		int ty,u,v;
		scanf("%d%d%d",&ty,&u,&v);
		if(ty==0) {
			p[++pl]=(point){u,v,0,0,0};
		} else {
			q[++ql]=(point){u,v,pl,ql,0};
		}
	}
	for(int i=1;i<=ql;i++){
		q[i].lca=LCA(q[i].l,q[i].r);
	//	printf("%d %d => %d \n",q[i].l,q[i].r,q[i].lca);
		int u=q[i].l,v=q[i].r;
		if(s[u]>s[v]) swap(u,v);
		u=e[u],v=s[v];
		q[i].l=u,q[i].r=v;
	}
	sort(q+1,q+1+ql,cmp);
	int l=1,r=0,tm=0;
	for(int i=1;i<=ql;i++)
	{
		printf("### %d(%d) %d(%d) %d ###\n",q[i].l,t[q[i].l],q[i].r,t[q[i].r],q[i].t);
		while(l>q[i].l) add(--l);printf("ans:%d \n",answer);
		while(r<q[i].r) add(++r);printf("ans:%d \n",answer);
		while(l<q[i].l) del(l++);printf("ans:%d \n",answer);
		while(r>q[i].r) del(r--);printf("ans:%d \n",answer);
		while(tm>q[i].t) upd(i,tm--);printf("ans:%d \n",answer);
		while(r<q[i].t) upd(i,++tm);printf("ans:%d \n",answer);
		add(s[q[i].lca]);
		ans[q[i].id]=answer;
		printf("ANSWER:%d \n",answer);
		del(s[q[i].lca]);
		puts("****");
	}
	for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
	return 0;
}

2025/2/5 16:54
加载中...