50pts求助
查看原帖
50pts求助
365164
KR_01楼主2021/7/15 08:21
#include<bits/stdc++.h>
using namespace std;
const int mo=998244353;
#define int long long
int n,m,a[100100];
struct hs1{
	int p;
	int v;
}h1[100100];
int h2[100100];
struct hs3{
	int sum;
	vector<int>q;
}h3[100100];
int dy[100100],q[100100];
int cj[100100],sum[100100],deep[100100];
vector<int>ma[100100];
int dfs(int pl){
	if(cj[pl]!=-1){
	//	if(cj[pl]==0)cout<<pl<<" "<<cj[pl]<<endl;
		return cj[pl];
	}
	int re=1;
	for(int i=0;i<ma[pl].size();i++){
		(re*=dfs(ma[pl][i]))%=mo;
	//	if(pl==0)cout<<ma[pl][i]<<" "<<dfs(ma[pl][i])<<endl;
	}
	return cj[pl]=re;
}
signed main(){
//	freopen("P7077_1.in","r",stdin);
	memset(cj,-1,sizeof cj);
	scanf("%lld",&n);
	for(int i=1;i<=n;i++)
		scanf("%lld",&a[i]);
	
	scanf("%lld",&m);
	for(int i=1;i<=m;i++){
		int op,xx,yy;
		scanf("%lld",&op);
		dy[i]=op;
		if(op==1){
			scanf("%lld%lld",&xx,&yy);
			h1[i].p=xx;
			h1[i].v=yy;
		}
		if(op==2){
			scanf("%lld",&xx);
			h2[i]=xx;
		}
		if(op==3){
			scanf("%lld",&xx);
			h3[i].sum=xx;
			for(int j=1;j<=xx;j++)
				scanf("%lld",&yy),
				h3[i].q.push_back(yy),
				ma[i].push_back(yy),
				deep[yy]++;
		}
	}
	for(int i=1;i<=m;i++){
		if(dy[i]==1)cj[i]=1;
		if(dy[i]==2)cj[i]=h2[i];
	//	if(cj[i]==0)cout<<i<<"*"<<endl;
	}
	dy[0]=3;
	int Q;
	scanf("%lld",&Q);
	for(int i=1;i<=Q;i++){
		int xx;
		scanf("%lld",&xx);
		q[i]=xx;
		ma[0].push_back(xx);
		deep[xx]++;
	}
	for(int i=0;i<=m;i++)
		if(dy[i]==3)dfs(i);//,cout<<cj[i]<<"("<<endl
//	cout<<cj[0]<<"*"<<endl;
	sum[0]=1;
	int now=1;
	queue<int>q;
	q.push(0);
	while(q.size()){
		int x=q.front();
	//	cout<<q.size()<<":";
		q.pop();
		now=1;
		//cout<<x<<"*"<<endl;
		for(int i=ma[x].size();i>=1;i--){
			int to=ma[x][i-1];
			deep[to]--;
			if(deep[to]==0)q.push(to);
			(sum[to]+=sum[x]*now%mo)%=mo;
			(now*=cj[to])%=mo;
		//	cout<<now<<" "<<cj[to]<<endl;
		}
	//	
	}
//	cout<<sum[6]<<endl;
//	cout<<now<<endl;
//	for(int i=1;i<=m;i++)
//		cout<<cj[i]<<" ";,cout<<cj[0]<<endl;

	
	for(int i=1;i<=n;i++)
		(a[i]*=cj[0])%=mo;
	for(int i=1;i<=m;i++)
		if(dy[i]==1){
			(a[h1[i].p]+=sum[i]*h1[i].v%mo)%=mo;
		}
	for(int i=1;i<=n;i++)
		cout<<a[i]<<" ";
	puts("");
}

wa的蛮离谱的

测评这里 大佬救命

2021/7/15 08:21
加载中...