求助 Codeforces #695(Div.2)E
  • 板块题目总版
  • 楼主Gary88
  • 当前回复4
  • 已保存回复4
  • 发布时间2021/1/14 16:27
  • 上次更新2023/11/5 04:50:54
查看原帖
求助 Codeforces #695(Div.2)E
104963
Gary88楼主2021/1/14 16:27

这道题我想了一整天才A掉

于是我想参考一下dalao们的做法

如下:

#include <bits/stdc++.h>
using namespace std;
#define int long long

const int N = 5e5;
vector<int> adj[N];
int a[N], par[N], n;

map<int, vector<int>> v, times;
int euler[N * 2 - 1], tin[N], tout[N], c = 0;

set<pair<int, int>> g;
int dp[N], ans;

void dfs(int v, int p = -1) {
    par[v] = p;
    tin[v] = c;
    euler[c++] = v;

    for (int i : adj[v]) {
        if (i == p)
            continue;
        dfs(i, v);
        euler[c++] = v;
    }

    tout[v] = c - 1;
}

void examine(int v) {
    int sum = 0;

    for (int i : adj[v]) {
        if (i == par[v])
            continue;

        int count = upper_bound(times[a[v]].begin(), times[a[v]].end(), tout[i]) - lower_bound(times[a[v]].begin(), times[a[v]].end(), tin[i]);
        if (count > 0)
            g.insert({v, i});
        sum += count;
    }

    sum = times[a[v]].size() - sum - 1;
    if (sum)
        g.insert({v, par[v]});
}

int setup(int v) {
    for (int i : adj[v]) {
        if (i != par[v])
            dp[v] += setup(i);
    }
    return dp[v] + g.count({v, par[v]});
}

void reroot(int v) {
    if (dp[v] == g.size())
        ans++;

    for (int i : adj[v]) {
        if (i == par[v])
            continue;

        dp[v] -= dp[i];
        dp[v] -= g.count({i, v});
        dp[i] += dp[v];
        dp[i] += g.count({v, i});

        reroot(i);

        dp[i] -= g.count({v, i});
        dp[i] -= dp[v];
        dp[v] += g.count({i, v});
        dp[v] += dp[i];
    }
}

int solveTestCase() {
    cin >> n;
    for (int i = 0; i < n; i++)
        cin >> a[i];

    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs(0);
    for (int i = 0; i < n; i++)
        v[a[i]].push_back(i);

    for (auto i : v) {
        if (i.second.size() == 1)
            continue;

        for (int j : i.second)
            times[i.first].push_back(tin[j]);
        sort(times[i.first].begin(), times[i.first].end());
        for (int j : i.second)
            examine(j);
    }

    setup(0);
    reroot(0);

    cout << ans;
}

 main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int t = 1;
    //cin >> t;
    while (t--)
        solveTestCase();
}

我发现我看不懂Orz

顺带我用的是线段树合并,附上我的代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<map>
#define mod 1000000007
using namespace std;
int n,p,h[200001],a[200001],tot,dft,s[200001],dfn[200001],no[200001],num[200001];
map<int,int>m;
int root[200001],cnt;
struct pp
{
	int to,ne;
}t[400000];
void add(int x,int y)
{
	t[++p].to=y;
	t[p].ne=h[x];
	h[x]=p;
}
struct tree1
{
	int tag,sum;
}tt[1000001];
struct tree
{
	int l,r,num;
}tr[5000001];
void change(int l,int r,int rt,int L,int R)
{
	if(L<=l&&r<=R)
	{
		tt[rt].tag=1;
		tt[rt].sum=0;
		return;
	}
	if(tt[rt].tag)
	{
		tt[rt].tag=0;
		tt[rt<<1].tag=tt[rt<<1|1].tag=1;
		tt[rt<<1].sum=tt[rt<<1|1].sum=0;
	}
	int mid=(l+r)>>1;
	if(L<=mid)
	change(l,mid,rt<<1,L,R);
	if(R>mid)
	change(mid+1,r,rt<<1|1,L,R);
	tt[rt].sum=tt[rt<<1].sum+tt[rt<<1|1].sum;
}
void insert(int l,int r,int &rt,int x)
{
	rt=++cnt;
	if(l==r)
	{
		tr[rt].num=1;
		return;
	}
	int mid=(l+r)>>1;
	if(x<=mid)
	insert(l,mid,tr[rt].l,x);
	else
	insert(mid+1,r,tr[rt].r,x);
}
void merge(int l,int r,int fa,int rt)
{
	if(l==r)
	{
		tr[rt].num+=tr[fa].num;
		return;
	}
	int mid=(l+r)>>1;
	if(tr[rt].l&&tr[fa].l)
	{
		merge(l,mid,tr[fa].l,tr[rt].l);
	}
	else
	tr[rt].l=tr[rt].l|tr[fa].l;
	if(tr[rt].r&&tr[fa].r)
	{
		merge(mid+1,r,tr[fa].r,tr[rt].r);
	}
	else
	tr[rt].r=tr[rt].r|tr[fa].r;
}
int ask(int l,int r,int rt,int x)
{
	if(l==r)
	return tr[rt].num;
	int mid=(l+r)>>1;
	if(x<=mid)
	return ask(l,mid,tr[rt].l,x);
	else
	return ask(mid+1,r,tr[rt].r,x);
}
void dfs(int x,int fa)
{
	dfn[x]=++dft;
	s[x]=1;
	insert(1,tot,root[x],a[x]);
	for(int i=h[x];i;i=t[i].ne)
	{
		if(t[i].to!=fa)
		{
			dfs(t[i].to,x);
			s[x]+=s[t[i].to];
			int k=ask(1,tot,root[t[i].to],a[x]);
			if(k)
			{
				if(k+1==num[a[x]])
				{
					change(1,n,1,1,dfn[t[i].to]-1);
					if(dfn[t[i].to]+s[t[i].to]<=n)
					change(1,n,1,dfn[t[i].to]+s[t[i].to],n);
				}
				else 
				{
					change(1,n,1,1,n);
				}
			}
			merge(1,tot,root[t[i].to],root[x]);
		}
	}
	int k=ask(1,tot,root[x],a[x]);
	if(k==1&&num[a[x]]!=1)
	{
		change(1,n,1,dfn[x],dfn[x]+s[x]-1);
	}
}
void build(int l,int r,int rt)
{
	if(l==r)
	{
		tt[rt].sum=1;
		return;
	}
	int mid=(l+r)>>1;
	build(l,mid,rt<<1);
	build(mid+1,r,rt<<1|1);
	tt[rt].sum=tt[rt<<1].sum+tt[rt<<1|1].sum;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
	{
		int x;
		scanf("%d",&x);
		if(!m[x])
		m[x]=++tot,no[tot]=x;
		a[i]=m[x];
		num[a[i]]++;
	}
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x); 
	}
	build(1,n,1);
	dfs(1,0);
	printf("%d",tt[1].sum);
	return 0;
}

我只想问问那些dalao们的代码是什么意思啊?

(我只是通过变量和函数名严谨推理看出来了dfs和欧拉序)

2021/1/14 16:27
加载中...