20分求查错(带注释详解)
查看原帖
20分求查错(带注释详解)
173397
取啥名好楼主2020/9/17 13:57
#include<bits/stdc++.h>
using namespace std;
int n,a,b,cnt=0,jl[1000010]={0},col[1000010]={0},sum[1000010]={0},head[2000010]={0},size[1000010]={0};
int js[1100000]={0},bian[2000010]={0};
struct node{int to,next;}tu[1000010]={0};
queue<int> q[1000010];
void pre_dfs(int x,int fa){
    size[x]=1;int pre_jl=sum[col[x]],last=0;sum[col[x]]++;
    for(int i=head[x];i;i=tu[i].next){
        if(tu[i].to==fa)continue;
        pre_dfs(tu[i].to,x);
        size[x]+=size[tu[i].to];
        if(pre_jl!=sum[col[x]])last=tu[i].to;
        bian[i]=size[tu[i].to];
        if(i%2)bian[i+1]=n-size[tu[i].to];
        else bian[i-1]=n-size[tu[i].to];//bian是记录该边所连的点为根的子树大小,去的bian的大小就是子树大小,回来的就是n-子树大小。bian从1开始存,所以要这样处理。
    }//cout<<x<<" "<<pre_jl<<" "<<sum[col[x]]<<endl;
    if(sum[col[x]]==jl[col[x]]&&pre_jl==0)q[col[x]].push(x),js[x]=last;//如果本来一个点都没扫到,现在扫到的点数又等于所有该颜色点的个数,那这个点一定是一个端点
    else if(sum[col[x]]==pre_jl+1)q[col[x]].push(x),js[x]=fa;//如果下面没有该颜色的点,那这个点也是端点
    //js纪录的是通往链中间那坨的点,在后面计算答案要排除的
}
void solve_1(int x){
    queue<int> qq;int sum=0,ans=0,temp;
    for(int i=head[x];i;i=tu[i].next)
        qq.push(bian[i]),sum+=bian[i];
    while(!qq.empty()){
        temp=qq.front();qq.pop();
        sum-=temp;ans+=sum*temp;
    }cout<<(ans+(n-1))<<endl;
}//所有子树两两乘积和+n-1
void solve_2(int x,int y){
    int sum_x=1,sum_y=1;
    for(int i=head[x];i;i=tu[i].next)
        if(tu[i].to!=js[x])sum_x+=bian[i];
    for(int i=head[y];i;i=tu[i].next)
        if(tu[i].to!=js[y])sum_y+=bian[i];
    cout<<sum_x*sum_y<<endl;
}//去掉中间一坨点后,剩下两颗子树的大小乘积
void add(int u,int v){
    cnt++;tu[cnt].to=v;tu[cnt].next=head[u];head[u]=cnt;
    cnt++;tu[cnt].to=u;tu[cnt].next=head[v];head[v]=cnt;
}
int main(){
    cin>>n;for(int i=1;i<=n;i++)cin>>col[i],jl[col[i]]++;//记录每种颜色的点数
    for(int i=1;i<n;i++)cin>>a>>b,add(a,b);
    pre_dfs(1,0);for(int i=1;i<=n;i++){
       // cout<<q[i].size()<<" ";
        if(q[i].size()>2)cout<<0<<endl;
        if(q[i].size()==1)solve_1(q[i].front());
        if(q[i].size()==0)cout<<((n*(n-1))/2)<<endl;
        if(q[i].size()==2){int temp=q[i].front();q[i].pop();solve_2(temp,q[i].front());}
        //分类讨论
    }
    system("pause");return 0;
}
2020/9/17 13:57
加载中...