本地测得静态空间大约 400MB,提交后 MLE 了;动态空间开销都是线性级别的,求问为什么会 MLE。
代码:
#include <bits/stdc++.h>
#define ll long long
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define pii pair<int, int>
#define mp make_pair
#define pb push_back
#define ld lower_bound
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define drep(i, a, b) for (int i = a; i >= b; i--)
#define ud upper_bound
#define N 1000005
#define M 100005
#define mem(s) memset(s, 0, sizeof(s))
#define fi first
#define se second
#define ull unsigned long long
using namespace std;
bool mbg;
const int inf = 1e9;
pii cl = mp(inf, 0);
vector <int> G[N], Qx[N], Qy[N];
int dfn[N], to[N], fa[N], R[N], col[N], X[M], Y[M], siz[N], n, m, x, tot;
ll ans;
inline void dfs(int u) {
dfn[u] = ++tot; to[tot] = u; siz[u] = 1;
for (int v : G[u]) dfs(v), siz[u] += siz[v];
R[u] = tot;
}
// Summary of candidate partners. mx is the smallest (value, component) pair
// seen so far; sec is the smallest pair whose component differs from mx's.
// Unused slots hold the sentinel cl = (inf, 0).
struct node {
pii mx, sec;
// Empty summary: both slots are the sentinel.
node () { mx = sec = cl; }
node (pii a, pii b = cl) { mx = a, sec = b; }
// Inserts a single candidate s.
// NOTE(review): not called anywhere in the visible code.
inline void ins(pii s) {
if (mx.fi > s.fi) swap(mx, s);
if (s.se != mx.se && s.fi < sec.fi) sec = s;
}
// Combines two summaries. On equal mx values the left operand's mx is kept
// (strict <). sec takes the smaller pair, comparing value first, then component.
inline friend node operator + (node a, node b) {
if (b.mx.fi < a.mx.fi) swap(a, b);
// b's best entry only qualifies as a runner-up if its component differs from
// the overall best; otherwise b's own runner-up is the candidate.
a.sec = min(a.sec, a.mx.se == b.mx.se ? b.sec : b.mx);
return a;
}
// Shifts both stored values by k (sentinel slots are shifted too).
inline friend node operator + (node a, int k) {
a.mx.fi += k; a.sec.fi += k;
return a;
}
// t[i]: summary gathered for item i in the current boruvka() round.
// mnx / mny: per-vertex subtree summaries built by dfs1.
// All three arrays are default-constructed (filled with cl) at program start.
} t[M], mnx[N], mny[N];
void solve1() {
node s;
rep (i, 1, m) s = s + node(mp(siz[X[i]] + siz[Y[i]], col[i]));
rep (i, 1, m) t[i] = t[i] + (s + (siz[X[i]] + siz[Y[i]]));
}
/**/
inline void dfs1(int u) {
for (int v : G[u]) dfs1(v);
for (int i : Qx[u]) mnx[u] = mnx[u] + node(mp(siz[Y[i]] - siz[X[i]], col[i]));
for (int i : Qy[u]) mny[u] = mny[u] + node(mp(siz[X[i]] - siz[Y[i]], col[i]));
for (int v : G[u]) mnx[u] = mnx[u] + mnx[v], mny[u] = mny[u] + mny[v];
for (int i : Qx[u]) t[i] = t[i] + (mnx[u] + (siz[X[i]] + siz[Y[i]]));
for (int i : Qy[u]) t[i] = t[i] + (mny[u] + (siz[X[i]] + siz[Y[i]]));
}
void solve2() {
rep (i, 1, n) mnx[i] = mny[i] = node();
dfs1(1);
}
/**/
inline void dfs2(int u, node mnx, node mny) {
for (int i : Qx[u]) mnx = mnx + node(mp(siz[X[i]] + siz[Y[i]], col[i]));
for (int i : Qx[u]) t[i] = t[i] + (mnx + (siz[Y[i]] - siz[X[i]]));
for (int i : Qy[u]) mny = mny + node(mp(siz[X[i]] + siz[Y[i]], col[i]));
for (int i : Qy[u]) t[i] = t[i] + (mny + (siz[X[i]] - siz[Y[i]]));
for (int v : G[u]) dfs2(v, mnx, mny);
}
void solve3() {
dfs2(1, node(), node());
}
/**/
struct SGT {
int tp, ls[M * 25], rs[M * 25], tot;
node seg[M * 25];
inline int newnode() {
return ++tot;
}
inline void upd(int &p, int x, node c, int l = 1, int r = n) {
if (l > x || r < x) return ;
if (!p) p = newnode(); seg[p] = seg[p] + c;
if (l == r) return ;
int mid = (l + r) >> 1;
upd(ls[p], x, c, l, mid); upd(rs[p], x, c, mid + 1, r);
}
inline node qry(int p, int ql, int qr, int l = 1, int r = n) {
if (l > qr || r < ql || ql > qr) return node();
if (ql <= l && r <= qr) return seg[p];
int mid = (l + r) >> 1;
return qry(ls[p], ql, qr, l, mid) + qry(rs[p], ql, qr, mid + 1, r);
}
inline void merge(int &p, int q, int l = 1, int r = n) {
if (!p || !q) return (void)(p |= q);
seg[p] = seg[p] + seg[q];
if (l == r) return ;
int mid = (l + r) >> 1;
merge(ls[p], ls[q], l, mid); merge(rs[p], rs[q], mid + 1, r);
}
void init() {
rep (i, 1, tot) seg[i] = node(), ls[i] = rs[i] = 0;
tot = 0;
}
} T;
// rt[u]: root (inside T) of the segment tree dfs3 builds for u's subtree.
int rt[N];
// Orders vertices by decreasing subtree size.
// NOTE(review): not called anywhere in the visible code.
inline bool cmp(int x, int y) {
    return siz[y] < siz[x];
}
inline void dfs3(int u) {
for (int i : Qx[u]) T.upd(rt[u], dfn[Y[i]], node(mp(-siz[X[i]] - siz[Y[i]], col[i])));
for (int v : G[u]) dfs3(v), T.merge(rt[u], rt[v]);
for (int i : Qx[u]) t[i] = t[i] + (T.qry(rt[u], dfn[Y[i]], R[Y[i]]) + (siz[X[i]] + siz[Y[i]]));
}
void solve4() { T.init(); rep (i, 1, n) rt[i] = 0; dfs3(1); }
/**/
// Slot count for zkw::s. build() picks the smallest power of two above n,
// which is 2^20 when n < 2^20 (assumed from N = 1000005), so every index
// touched stays below 2^21 + 2.
constexpr int P = (1 << 21) + 5;
// Rollback log shared by B (zkw) and H (SGT2), which run one after the other
// in boruvka(). Each entry is (tree index, value before the modification) and
// back() replays entries in reverse. All M * 50 entries (about 100 MB) are
// default-constructed at program start.
// NOTE(review): presumably sized for up to ~40 logged nodes per range update
// over up to M pending updates; verify before shrinking.
pair<int, node> stk[M * 50];
// Bottom-up (zkw-style) segment tree with leaves at s[m + pos]. Point updates
// combine c into the leaf and all of its ancestors, logging each old value to
// stk so back() can undo them.
struct zkw {
    node s[P];
    int m = 1, tp = 0;
    // Grows the base to the smallest power of two strictly greater than n.
    void build(int n) {
        while (m <= n) m *= 2;
    }
    // Undoes logged modifications until the log height is back to `top`.
    inline void back(int top) {
        for (; tp > top; --tp) s[stk[tp].fi] = stk[tp].se;
    }
    inline void upd(int x, node c) {
        for (int p = x + m; p > 0; p /= 2) {
            stk[++tp] = mp(p, s[p]);
            s[p] = s[p] + c;
        }
    }
    // Combination over [l, r], walking the exclusive bounds l-1 and r+1 upward.
    inline node qry(int l, int r) {
        node res;
        int lo = l + m - 1, hi = r + m + 1;
        while ((lo ^ hi) != 1) {
            if (!(lo & 1)) res = res + s[lo ^ 1];
            if (hi & 1) res = res + s[hi ^ 1];
            lo >>= 1;
            hi >>= 1;
        }
        return res;
    }
} B;
/**/
inline void dfs4(int u) {
int tp = B.tp;
for (int i : Qx[u]) B.upd(dfn[Y[i]], node(mp(siz[X[i]] - siz[Y[i]], col[i])));
for (int i : Qx[u]) t[i] = t[i] + (B.qry(dfn[Y[i]], R[Y[i]]) + (siz[Y[i]] - siz[X[i]]));
for (int v : G[u]) dfs4(v); B.back(tp);
}
inline void dfs5(int u) {
int tp = B.tp;
for (int i : Qy[u]) B.upd(dfn[X[i]], node(mp(siz[Y[i]] - siz[X[i]], col[i])));
for (int i : Qy[u]) t[i] = t[i] + (B.qry(dfn[X[i]], R[X[i]]) + (siz[X[i]] - siz[Y[i]]));
for (int v : G[u]) dfs5(v); B.back(tp);
}
// Sizes B for n positions, then runs both root-path passes over it.
void solve5() {
    B.build(n);
    dfs4(1);  // items grouped by X, querying Y's subtree
    dfs5(1);  // items grouped by Y, querying X's subtree
}
/**/
struct SGT2 {
node t[N << 2];
int tp;
inline void back(int top) {
while (tp > top) t[stk[tp].fi] = stk[tp].se, tp--;
}
inline void upd(int p, int ql, int qr, node c, int l = 1, int r = n) {
if (ql > r || qr < l) return ;
if (ql <= l && r <= qr) return (void)(stk[++tp] = mp(p, t[p]), t[p] = t[p] + c);
int mid = (l + r) >> 1;
upd(p << 1, ql, qr, c, l, mid); upd(p << 1 | 1, ql, qr, c, mid + 1, r);
}
inline node qry(int p, int x, int l = 1, int r = n) {
if (l > x || r < x) return node();
if (l == r) return t[p];
int mid = (l + r) >> 1;
return qry(p << 1, x, l, mid) + qry(p << 1 | 1, x, mid + 1, r) + t[p];
}
} H;
inline void dfs6(int u) {
int tp = H.tp;
for (int i : Qx[u]) H.upd(1, dfn[Y[i]], R[Y[i]], node(mp(siz[X[i]] + siz[Y[i]], col[i])));
for (int i : Qx[u]) t[i] = t[i] + (H.qry(1, dfn[Y[i]]) + (-siz[X[i]] - siz[Y[i]]));
for (int v : G[u]) dfs6(v); H.back(tp);
}
// Runs the range-update / point-query pass. dfs6 rolls H back on exit from
// every vertex, so H ends in the state it started in.
void solve6() { dfs6(1); }
/**/
// Union-find lookup over items: returns the root of u's set and points every
// item on the way directly at that root (full path compression).
inline int get(int u) {
    int root = u;
    while (col[root] != root) root = col[root];
    while (u != root) {
        const int up = col[u];
        col[u] = root;
        u = up;
    }
    return root;
}
// Edge proposed by component u towards component v, with weight w.
struct edge {
    int u;
    int v;
    int w;
};
// Proposals collected in the current boruvka() round.
vector<edge> S;
bool boruvka() {
rep (i, 1, m) t[i] = node();
solve1();
solve2();
solve3();
solve4();
solve5();
solve6();
S.clear();
rep (i, 1, m) if (i ^ col[i]) t[col[i]] = t[col[i]] + t[i];
rep (i, 1, m) {
if (col[i] == i) {
pii s = (t[i].mx.se == i ? t[i].sec : t[i].mx);
S.pb({i, s.se, s.fi});
}
}
for (edge E : S) {
if (get(E.u) != get(E.v)) col[get(E.u)] = get(E.v), ans += E.w;
}
int cnt = 0;
rep (i, 1, m) cnt += (i == get(i));
return (cnt > 1);
}
void solve() {
cin >> n >> m;
rep (i, 2, n) cin >> x, G[x].pb(i);
dfs(1);
rep (i, 1, m) {
cin >> X[i] >> Y[i]; col[i] = i;
Qx[X[i]].pb(i); Qy[Y[i]].pb(i);
}
while (boruvka());
cout << ans;
}
bool med;
signed main() {
// IOS;
std::cerr << '\n' << (&mbg - &med) * (1.l / (1<<20)) << '\n';
int T; T = 1;
while (T--) solve();
}