90分,开O2后第四个点WA,求助...
#include <bits/stdc++.h>
#define int long long
#define maxn 1000005
#define re register
using namespace std;
// p: visiting order; head/tot: adjacency-list heads and edge counter.
int n, p[maxn], head[maxn], tot;
// Per-vertex HLD data (depth, parent, subtree size, heavy son, chain top,
// dfs order id) plus the segment tree (tag = pending add, ans = range sum).
// The segment tree is laid out as a heap rooted at 1 over positions 1..n, so
// node indices reach about 4n: tag[] and ans[] need 4x the vertex bound.
// Sizing them with plain maxn let large inputs write past the end (undefined
// behavior, which can surface as WA only when compiled with -O2).
int depth[maxn], fa[maxn], tag[maxn << 2], ans[maxn << 2], sz[maxn], son[maxn], top[maxn], cnt, id[maxn];
// Adjacency-list edge: v is the target vertex, pre is the index of the
// previous edge out of the same source vertex (0 terminates the list).
// Each undirected tree edge is stored twice, so at most 2(n-1) slots are used
// (indices 1..2(n-1)); 2 * maxn slots are enough.
struct edge {
    int v, pre;
} e[maxn << 1];
// Reads one decimal integer (optionally preceded by '-') from stdin, skipping
// every leading character that is neither a digit nor '-'.
// Returns 0 if end of input is reached before such a character is found.
// The character is kept in an int so getchar()'s EOF is not truncated to a
// char; the previous char-based loop spun forever at end of input.
inline int read() {
    int w = 1, q = 0;
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == EOF) return 0;
    if (ch == '-') w = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') q = q * 10 + ch - '0', ch = getchar();
    return w * q;
}
// Prints x in decimal to stdout, with a leading '-' when negative.
// No separator or newline is emitted.
inline void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char digits[24];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + x % 10);
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar(digits[--len]);
}
void add(int u, int v) {
e[++tot].v = v;
e[tot].pre = head[u];
head[u] = tot;
}
// Left child of segment-tree node x in the heap layout (root = 1).
int ls(int x) {
    return 2 * x;
}
// Right child of segment-tree node x in the heap layout (root = 1).
int rs(int x) {
    return 2 * x + 1;
}
void push_up(int p) {
ans[p] = ans[ls(p)] + ans[rs(p)];
}
// Applies "+k to every position" to node p, which covers [l, r].
// The node's sum grows by k per position, and k is queued in tag[p] for
// the children.
void f(int p, int l, int r, int k) {
    const int len = r - l + 1;
    tag[p] += k;
    ans[p] += len * k;
}
// Moves node p's pending addition onto its two children and clears it.
// [l, r] is the range covered by p.
void push_down(int p, int l, int r) {
    const int pending = tag[p];
    if (pending == 0) return;  // applying +0 to the children would change nothing
    const int mid = (l + r) / 2;
    f(ls(p), l, mid, pending);
    f(rs(p), mid + 1, r, pending);
    tag[p] = 0;
}
void update(int nl, int nr, int l, int r, int p, int k) {
if(nl <= l && r <= nr) {
ans[p] += (r - l + 1) * k;
tag[p] += k;
return;
}
int mid = (l + r) >> 1;
push_down(p, l, r);
if(nl <= mid) update(nl, nr, l, mid, ls(p), k);
if(nr > mid) update(nl, nr, mid + 1, r, rs(p), k);
push_up(p);
}
int query(int nl, int nr, int l, int r, int p) {
int res = 0;
if(nl <= l && r <= nr) return ans[p];
int mid = (l + r) >> 1;
push_down(p, l, r);
if(nl <= mid) res += query(nl, nr, l, mid, ls(p));
if(nr > mid) res += query(nl, nr, mid + 1, r, rs(p));
return res;
}
// First HLD pass over the subtree of `now`, entered from parent `fath`.
// Records parent, depth and subtree size for each vertex.
// Sets son[now] to the child with the largest subtree; on ties, the first
// such child in adjacency order wins.
void dfs1(int now, int fath) {
    fa[now] = fath;
    depth[now] = depth[fath] + 1;
    sz[now] = 1;
    int heaviest = -1;
    for (int i = head[now]; i != 0; i = e[i].pre) {
        const int child = e[i].v;
        if (child == fath) continue;
        dfs1(child, now);
        sz[now] += sz[child];
        if (sz[child] > heaviest) {
            heaviest = sz[child];
            son[now] = child;
        }
    }
}
void dfs2(int now, int tp) {
top[now] = tp;
id[now] = ++cnt;
if(!son[now]) return;
dfs2(son[now], tp);
for(re int i = head[now];i;i = e[i].pre) {
int v = e[i].v;
if(v != fa[now] && v != son[now]) {
dfs2(v, v);
}
}
}
// Adds 1 to every vertex on the tree path between x and y, endpoints
// included, by climbing one heavy chain at a time.
void up(int x, int y) {
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain top is deeper.
        if (depth[top[x]] < depth[top[y]]) std::swap(x, y);
        const int chain_top = top[x];
        update(id[chain_top], id[x], 1, n, 1, 1);
        x = fa[chain_top];
    }
    // Both endpoints now lie on one chain; the shallower one has the smaller id.
    if (depth[x] > depth[y]) std::swap(x, y);
    update(id[x], id[y], 1, n, 1, 1);
}
signed main() {
n = read();
for(re int i = 1;i <= n;++i) p[i] = read();
for(re int i = 1;i <= n - 1;++i) {
int u, v;
u = read(), v = read();
add(u, v);
add(v, u);
}
dfs1(1, 0);
dfs2(1, 1);
for(re int i = 2;i <= n;++i) up(p[i - 1], p[i]);
for(re int i = 1;i <= n;++i) {
if(p[1] != i) write(query(id[i], id[i], 1, n, 1) - 1);
else write(query(id[i], id[i], 1, n, 1));
puts("");
}
}