关于卡常。。。
查看原帖
关于卡常。。。
99643
陈刀仔楼主2022/1/3 13:31

卡不动求帮助qwq

#include <bits/stdc++.h>
using namespace std;
inline int read(){
    int s=0,w=1,ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch)){s=(s<<3)+(s<<1)+ch-48;ch=getchar();}
    return s*w;
}
const int maxn=5e5+50,mod=998244353;
int n,cnt,head[maxn],to[maxn*2],nex[maxn*2],maxd[maxn],rt[maxn],dep[maxn];
void add(int x,int y){to[++cnt]=y;nex[cnt]=head[x];head[x]=cnt;}
int tot,sum[maxn*60],tag[maxn*60],lc[maxn*60],rc[maxn*60];
void pushdown(int x){
    if(tag[x]!=1){
        if(lc[x]==0)lc[x]=++tot,tag[tot]=1;
        if(rc[x]==0)rc[x]=++tot,tag[tot]=1;
        sum[lc[x]]=1ll*sum[lc[x]]*tag[x]%mod;tag[lc[x]]=1ll*tag[lc[x]]*tag[x]%mod;
        sum[rc[x]]=1ll*sum[rc[x]]*tag[x]%mod;tag[rc[x]]=1ll*tag[rc[x]]*tag[x]%mod;
        tag[x]=1;
    }
}
void pushup(int x){sum[x]=(sum[lc[x]]+sum[rc[x]])%mod;}
int update(int x,int l,int r,int pos,int val){
    if(x==0)x=++tot,tag[x]=1;
    if(l==r){sum[x]=val;tag[x]=1;return x;}
    pushdown(x);
    int mid=(l+r)>>1;
    if(pos<=mid)lc[x]=update(lc[x],l,mid,pos,val);
    else rc[x]=update(rc[x],mid+1,r,pos,val);
    pushup(x);
    return x;
}
int query(int x,int l,int r,int L,int R){
    if(x==0)return 0;
    if(l>=L&&r<=R)return sum[x];
    pushdown(x);
    int mid=(l+r)>>1,res=0;
    if(L<=mid)res=(res+query(lc[x],l,mid,L,R))%mod;
    if(R>mid)res=(res+query(rc[x],mid+1,r,L,R))%mod;
    pushup(x);
    return res;
}
inline int merge(int x,int y,int l,int r,int &su,int &sv)
{
	if(!x&&!y)
	{
		return 0;
	}
	if(!x)
	{
		sv=1ll*(sv+sum[y])%mod,tag[y]=1ll*tag[y]*su%mod;
		return sum[y]=1ll*sum[y]*su%mod,y;
	}
	if(!y)
	{
		su=(su+sum[x])%mod,tag[x]=1ll*tag[x]*sv%mod;
		return sum[x]=1ll*sum[x]*sv%mod,x;
	}
	if(l==r)
	{
		int cu=sum[x],cv=sum[y];
		sv=(sv+cv)%mod,sum[x]=(1ll*sum[x]*sv%mod+1ll*sum[y]*su%mod)%mod;
		return su=(su+cu)%mod,x;
	}
	int mid=(l+r)>>1;
	pushdown(x),pushdown(y);
	lc[x]=merge(lc[x],lc[y],l,mid,su,sv);
	rc[x]=merge(rc[x],rc[y],mid+1,r,su,sv);
	pushup(x);
	return x;
}
void dfs(int u,int fa){
    rt[u]=update(rt[u],0,n,maxd[u],1);
    for(int i=head[u];i;i=nex[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,u);
        int tmp=query(rt[v],0,n,0,dep[u]);
        int s1=0,s2=tmp;
        rt[u]=merge(rt[u],rt[v],0,n,s1,s2);
    }
}
void getd(int u,int fa){
    dep[u]=dep[fa]+1;
    for(int i=head[u];i;i=nex[i]){
        int v=to[i];if(v==fa)continue;
        getd(v,u);
    }
}
signed main(){
    n=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    getd(1,0);
    int m=read();
    for(int i=1;i<=m;i++){
        int u=read(),v=read();
        maxd[v]=max(maxd[v],dep[u]);
    }
    dfs(1,0);
    printf("%d\n",query(rt[1],0,n,0,0));
    return 0;
}
2022/1/3 13:31
加载中...