写的换根 DP，可是交上去一直 TLE。本地没有问题，对拍什么的都过了，n 开到 $10^6$ 也过了。之前 $O(n^2)$ 的暴力就能过，就这个超时……请问这是为什么啊 qwq
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout the file.
#define il inline
// `register` was deprecated in C++11 and removed in C++17 (clang rejects it as
// an error under -std=c++17), so `re` now expands to nothing: existing
// `re int i` declarations still compile, as plain `int i`.
#define re
// Rep: ascending inclusive loop i = s..e; Dep: descending inclusive loop i = s..e.
// Arguments are parenthesized so expressions such as `a ? b : c` bind correctly.
#define Rep(i, s, e) for (re int i = (s); i <= (e); ++i)
#define Dep(i, s, e) for (re int i = (s); i >= (e); --i)
// Redirects stdin/stdout to <a>.in / <a>.out (meant for local testing only).
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)
// Capacity of the per-vertex arrays; presumably sized for n up to about 2e6
// (vertices are 1-indexed) — confirm against the problem limits.
const int N = 2000010;
// Reads one decimal integer from `in` (stdin by default), skipping any leading
// non-digit characters; a '-' seen among them makes the result negative.
// Returns 0 if the stream ends before any digit is found. The previous version
// stored the character in a `char` and never checked for EOF. On EOF it
// therefore looped forever, which shows up as TLE when stdin is empty or was
// closed by a failed freopen.
inline int read(FILE *in = stdin) {
    int x = 0;
    bool positive = true;
    int c = getc(in);  // int, not char, so EOF is representable
    while (c != EOF && !isdigit(c)) {
        if (c == '-') positive = false;
        c = getc(in);
    }
    while (isdigit(c)) {  // isdigit(EOF) is false, so this also stops at EOF
        x = x * 10 + (c - '0');
        c = getc(in);
    }
    return positive ? x : -x;
}
// Input: number of vertices and the 0/1 label string, stored 1-indexed in s[1..n].
int n;
char s[N];
// hav[i] becomes true when s[i] == '1' (filled in by main).
bool hav[N];
// Forward-star adjacency storage: head[u] is the index of the most recently
// added edge leaving u (0 = none); tot counts edges added so far.
int tot;
int head[N];
// One stored directed edge: index of the next edge from the same vertex, and target.
struct node {
    int nxt;
    int to;
};
node edge[N << 1];
// Prepends the directed edge u -> v to u's adjacency list.
il void adde(int u, int v) {
    ++tot;
    edge[tot].nxt = head[u];
    edge[tot].to = v;
    head[u] = tot;
}
// Per-vertex DP tables, indexed by vertex id (filled by dfs_down and dp).
int dep[N];      // depth, with vertex 1 as the root
int down[N];     // sum of distances from the vertex to the '1' vertices in its subtree
int downf[N];
int sz[N];       // number of '1' vertices in the subtree
int nmax[N][3];  // [1] / [2]: largest / second-largest child value
int g[N];
int up[N];
int f[N];
int Ans[N];
int now;         // not referenced in the visible code
// Best answer so far; 0x3f3f3f3f means "no valid answer yet".
int ans = 0x3f3f3f3f;
// Folds candidate value x into vertex u's running top-two maxima.
// nmax[u][1] is the largest value seen so far and nmax[u][2] the second
// largest; both start from the global arrays' zero initialization.
// Previously this was declared `int` but had no return statement. Flowing off
// the end of a non-void function is undefined behavior, and optimizing
// compilers (e.g. GCC at -O2) may treat that path as unreachable and fall
// through into unrelated code. A typical symptom is "works locally, TLE/RE on
// the judge". Nothing uses a result, so the function is now void.
il void upd(int u, int x) {
    if (x > nmax[u][1]) {
        nmax[u][2] = nmax[u][1];
        nmax[u][1] = x;
    } else if (x > nmax[u][2]) {
        nmax[u][2] = x;
    }
}
// Value that child subtree u contributes to its parent's nmax:
// down[u] + downf[u] + twice the number of '1' vertices in u's subtree.
il int get(int u) {
    int total = down[u];
    total += downf[u];
    total += 2 * sz[u];
    return total;
}
void dfs_down(int u, int fa) {
dep[u] = dep[fa] + 1, sz[u] = hav[u]; int nowmax = 0;
for (re int e = head[u]; e; e = edge[e].nxt) {
int v = edge[e].to;
if (v == fa) continue;
dfs_down(v, u), down[u] += down[v] + sz[v], sz[u] += sz[v], upd(u, get(v));
}
downf[u] = nmax[u][1] > down[u] ? nmax[u][1] - down[u] : down[u] & 1;
}
// Second DFS: rerooting pass from vertex 1. Call it after dfs_down, once main
// has set g[1] = down[1]. For the root, nothing is computed here: up[1] and
// f[1] keep their zero initial values. For every other vertex u with parent fa:
//   up[u]  = sum of distances from u to the '1' vertices outside u's subtree,
//   g[u]   = up[u] + down[u], the total distance from u to every '1' vertex,
//   f[u]   = a value computed with the same formula shape as downf, over fa's
//            side of the tree (everything outside u's subtree), plus
//            sz[1] - sz[u],
//   Ans[u] = the downf-style check with u as the root; main accepts u when it is 0.
// NOTE: recursion depth equals the tree height (up to n on a path).
void dp(int u, int fa) {
if (u != 1) {
// up[fa] covers fa's outside part; down[fa] - down[u] - sz[u] covers fa's other
// subtrees. Each of the sz[1] - sz[u] outside '1' vertices is one edge farther from u.
up[u] = down[fa] - down[u] - sz[u] - sz[u] + up[fa] + sz[1], g[u] = up[u] + down[u];
// nowmax: largest part seen from fa once u's subtree is removed. That is either
// fa's best child value other than u's own get(u) (falling back to the second
// maximum when get(u) is the maximum) or the parent side f[fa] + up[fa].
// nowg: distance from fa to the '1' vertices outside u's subtree.
int nowmax = max(get(u) == nmax[fa][1] ? nmax[fa][2] : nmax[fa][1], f[fa] + up[fa]), nowg = g[fa] - down[u] - sz[u];
// Same rule as downf: excess of the largest part over the total, else parity.
f[u] = nowmax > nowg ? nowmax - nowg : nowg & 1;
// Add the number of '1' vertices outside u's subtree.
f[u] += sz[1] - sz[u];
// Evaluate u itself as the root: its largest part is either its best child
// subtree (nmax[u][1]) or the outside component (up[u] + f[u]).
nowg = g[u], nowmax = max(nmax[u][1], up[u] + f[u]);
Ans[u] = nowmax > nowg ? nowmax - nowg : nowg & 1;
}
// Recurse into the children of u (every neighbor except the parent).
for (re int e = head[u]; e; e = edge[e].nxt) {
int v = edge[e].to;
if (v == fa) continue;
dp(v, u);
}
}
// Reads n, the 0/1 label string s[1..n] and the n - 1 tree edges. Then runs
// dfs_down and dp rooted at vertex 1. Finally prints the smallest g[i] / 2 over
// vertices whose check value is 0 (downf[1] for the root, Ans[i] otherwise),
// or -1 if there is none.
int main() {
    // File redirection is for local testing only: compile with -DLOCAL to read
    // data.in and write data.out. An online judge supplies input on stdin and
    // has no data.in. The failed freopen then leaves stdin closed, so the
    // original read() spun forever on EOF (TLE). stdout would also have gone to
    // a file the judge never reads.
#ifdef LOCAL
    file(data);
#endif
    n = read(), dep[0] = -1;  // dep[0] = -1 gives the root (vertex 1) depth 0
    scanf("%s", s + 1);
    Rep(i, 1, n) hav[i] = (s[i] == '1');
    Rep(i, 2, n) {
        int u = read(), v = read();
        adde(u, v), adde(v, u);
    }
    dfs_down(1, 0);
    if (!downf[1]) ans = down[1] >> 1;
    g[1] = down[1];
    dp(1, 0);
    Rep(i, 2, n) {
        if (!Ans[i]) ans = min(ans, g[i] >> 1);
    }
    printf("%d\n", ans == 0x3f3f3f3f ? -1 : ans);
    return 0;
}