这道题我想了一整天才A掉
于是我想参考一下dalao们的做法
如下:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 5e5;
vector<int> adj[N];
int a[N], par[N], n;
map<int, vector<int>> v, times;
int euler[N * 2 - 1], tin[N], tout[N], c = 0;
set<pair<int, int>> g;
int dp[N], ans;
void dfs(int v, int p = -1) {
par[v] = p;
tin[v] = c;
euler[c++] = v;
for (int i : adj[v]) {
if (i == p)
continue;
dfs(i, v);
euler[c++] = v;
}
tout[v] = c - 1;
}
void examine(int v) {
int sum = 0;
for (int i : adj[v]) {
if (i == par[v])
continue;
int count = upper_bound(times[a[v]].begin(), times[a[v]].end(), tout[i]) - lower_bound(times[a[v]].begin(), times[a[v]].end(), tin[i]);
if (count > 0)
g.insert({v, i});
sum += count;
}
sum = times[a[v]].size() - sum - 1;
if (sum)
g.insert({v, par[v]});
}
int setup(int v) {
for (int i : adj[v]) {
if (i != par[v])
dp[v] += setup(i);
}
return dp[v] + g.count({v, par[v]});
}
void reroot(int v) {
if (dp[v] == g.size())
ans++;
for (int i : adj[v]) {
if (i == par[v])
continue;
dp[v] -= dp[i];
dp[v] -= g.count({i, v});
dp[i] += dp[v];
dp[i] += g.count({v, i});
reroot(i);
dp[i] -= g.count({v, i});
dp[i] -= dp[v];
dp[v] += g.count({i, v});
dp[v] += dp[i];
}
}
int solveTestCase() {
cin >> n;
for (int i = 0; i < n; i++)
cin >> a[i];
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(0);
for (int i = 0; i < n; i++)
v[a[i]].push_back(i);
for (auto i : v) {
if (i.second.size() == 1)
continue;
for (int j : i.second)
times[i.first].push_back(tin[j]);
sort(times[i.first].begin(), times[i.first].end());
for (int j : i.second)
examine(j);
}
setup(0);
reroot(0);
cout << ans;
}
main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int t = 1;
//cin >> t;
while (t--)
solveTestCase();
}
我发现我看不懂Orz
顺带我用的是线段树合并,附上我的代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<map>
#define mod 1000000007
using namespace std;
int n,p,h[200001],a[200001],tot,dft,s[200001],dfn[200001],no[200001],num[200001];
map<int,int>m;
int root[200001],cnt;
struct pp
{
int to,ne;
}t[400000];
void add(int x,int y)
{
t[++p].to=y;
t[p].ne=h[x];
h[x]=p;
}
struct tree1
{
int tag,sum;
}tt[1000001];
struct tree
{
int l,r,num;
}tr[5000001];
void change(int l,int r,int rt,int L,int R)
{
if(L<=l&&r<=R)
{
tt[rt].tag=1;
tt[rt].sum=0;
return;
}
if(tt[rt].tag)
{
tt[rt].tag=0;
tt[rt<<1].tag=tt[rt<<1|1].tag=1;
tt[rt<<1].sum=tt[rt<<1|1].sum=0;
}
int mid=(l+r)>>1;
if(L<=mid)
change(l,mid,rt<<1,L,R);
if(R>mid)
change(mid+1,r,rt<<1|1,L,R);
tt[rt].sum=tt[rt<<1].sum+tt[rt<<1|1].sum;
}
void insert(int l,int r,int &rt,int x)
{
rt=++cnt;
if(l==r)
{
tr[rt].num=1;
return;
}
int mid=(l+r)>>1;
if(x<=mid)
insert(l,mid,tr[rt].l,x);
else
insert(mid+1,r,tr[rt].r,x);
}
void merge(int l,int r,int fa,int rt)
{
if(l==r)
{
tr[rt].num+=tr[fa].num;
return;
}
int mid=(l+r)>>1;
if(tr[rt].l&&tr[fa].l)
{
merge(l,mid,tr[fa].l,tr[rt].l);
}
else
tr[rt].l=tr[rt].l|tr[fa].l;
if(tr[rt].r&&tr[fa].r)
{
merge(mid+1,r,tr[fa].r,tr[rt].r);
}
else
tr[rt].r=tr[rt].r|tr[fa].r;
}
int ask(int l,int r,int rt,int x)
{
if(l==r)
return tr[rt].num;
int mid=(l+r)>>1;
if(x<=mid)
return ask(l,mid,tr[rt].l,x);
else
return ask(mid+1,r,tr[rt].r,x);
}
void dfs(int x,int fa)
{
dfn[x]=++dft;
s[x]=1;
insert(1,tot,root[x],a[x]);
for(int i=h[x];i;i=t[i].ne)
{
if(t[i].to!=fa)
{
dfs(t[i].to,x);
s[x]+=s[t[i].to];
int k=ask(1,tot,root[t[i].to],a[x]);
if(k)
{
if(k+1==num[a[x]])
{
change(1,n,1,1,dfn[t[i].to]-1);
if(dfn[t[i].to]+s[t[i].to]<=n)
change(1,n,1,dfn[t[i].to]+s[t[i].to],n);
}
else
{
change(1,n,1,1,n);
}
}
merge(1,tot,root[t[i].to],root[x]);
}
}
int k=ask(1,tot,root[x],a[x]);
if(k==1&&num[a[x]]!=1)
{
change(1,n,1,dfn[x],dfn[x]+s[x]-1);
}
}
void build(int l,int r,int rt)
{
if(l==r)
{
tt[rt].sum=1;
return;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
tt[rt].sum=tt[rt<<1].sum+tt[rt<<1|1].sum;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
if(!m[x])
m[x]=++tot,no[tot]=x;
a[i]=m[x];
num[a[i]]++;
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
build(1,n,1);
dfs(1,0);
printf("%d",tt[1].sum);
return 0;
}
我只想问问那些dalao们的代码是什么意思啊?
(我只是通过变量和函数名严谨推理看出来了dfs和欧拉序)