蒟蒻用的是线段树合并。比赛时没做出来,赛后有大佬告诉我要把询问离线处理,改成离线之后就过了。但我不太明白其中的原因(也可能是线段树合并没学明白):线段树是自底向上合并的,为什么全部合并完之后再统一查询就会出错呢?
两份代码的区别只在 solve 函数、main 函数中处理询问的部分,以及与询问相关的全局声明(q、ans、struct node)。
AC代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define NDEBUG
#include<assert.h>
using namespace std;
// Short aliases used throughout the solution.
using vi = std::vector<int>;
using ll = long long;
using ull = unsigned long long;
using db = double;
using ui = unsigned int;
// Reads the next (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// Returns 0 if end of input is reached before any digit. The character is
// held in an int so EOF (-1) is distinguishable from every byte value. The
// previous `char` version looped forever at EOF, and on unsigned-char
// platforms EOF never compared below '0' at all.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Competitive-programming shorthands.
// NOTE(review): `endl` is redefined to '\n' (no flush); any later spelling of
// std::endl in this file would break.
#define endl '\n'
#define rd read()
#define pb push_back
// Byte-fill a whole array. `a` must be an array, not a pointer, or sizeof(a)
// is only the pointer size. No trailing ';', so `mst(a, 0);` is a single
// statement and is safe inside if/else.
#define mst(a, b) memset((a), (b), sizeof(a))
#define inf 0x3f3f3f3f
#define linf 0x3f3f3f3f3f3f3f3f
#define mod ((int)1e9+7)
// Fully parenthesized so it expands safely inside any expression.
#define maxn ((int)(2e5 + 5))
// A directed edge u -> v; inp() adds one for every vertex v in 2..n.
struct edge {
    int u, v;
    edge(int _u = 0, int _v = 0) : u(_u), v(_v) {}
};
// e[u]: edges leaving u.
vector<edge> e[maxn];
// Appends the edge u -> v to u's list.
void ins(int u, int v) { e[u].emplace_back(u, v); }
int n;              // number of vertices
int m;              // number of queries
int tot;            // last used index in the node pool t[]
int rt[maxn];       // rt[v]: root of v's segment tree (counts per depth)
int depth[maxn];    // depth[v]; dfs() gives the root depth 1
int fa[maxn][19];   // fa[v][k]: 2^k-th ancestor of v
int lg[maxn];       // lg[i] = floor(log2(i)) + 1 for i >= 1, lg[0] = 0
int ans[maxn];      // ans[i]: answer to the i-th query
// One offline query: output slot `id`, vertex `u`, and the depth `d` to count
// (main() stores it already shifted to the depth[] numbering).
struct node {
    int id;
    int u;
    int d;
};
// q[v]: queries attached to vertex v, answered inside solve().
vector<node> q[maxn];
// Dynamically allocated segment-tree node over depths [1, n]:
// ls/rs are child indices into t[] (0 = empty), x = count in the range.
struct tree {
    int ls, rs, x;
};
// Node pool; index 0 is the shared all-zero "empty" node.
// Sizing: main() pre-allocates n roots, and each upd() allocates at most one
// node per level below the root. That is ceil(log2(n)) <= 18 levels for
// n < maxn <= 2^18. The in-place merge() allocates nothing, so
// tot <= 19 * n < 20 * maxn. The former maxn * 600 (~1.4 GB of static
// storage) was far more than needed.
tree t[maxn * 20];
// Reads n, then for each vertex child = 2..n the vertex it hangs from
// (stored as the edge parent -> child), then the query count m.
void inp() {
    cin >> n;
    for(int child = 2; child <= n; ++child) {
        const int parent = rd;
        ins(parent, child);
    }
    cin >> m;
}
// Fills lg[i] = floor(log2(i)) + 1 for 1 <= i <= n (lg[0] stays 0):
// the value grows by one exactly when i reaches the next power of two.
void log2() {
    for(int i = 1; i <= n; ++i) {
        const bool hitsPowerOfTwo = (1 << lg[i - 1]) == i;
        lg[i] = lg[i - 1] + (hitsPowerOfTwo ? 1 : 0);
    }
}
void dfs(int u, int fath) {
fa[u][0] = fath, depth[u] = depth[fath] + 1;
for(int i = 1; i <= lg[depth[u]]; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(unsigned int i = 0; i < e[u].size(); ++i)
if(e[u][i].v != fath)
dfs(e[u][i].v, u);
}
// Lowest common ancestor of x and y via binary lifting
// (not called anywhere in this program).
inline int lca(int x, int y) {
    if(depth[x] < depth[y]) swap(x, y);
    // Lift x by the largest power of two not exceeding the depth gap.
    while(depth[x] > depth[y]) {
        const int gap = depth[x] - depth[y];
        x = fa[x][lg[gap] - 1];
    }
    if(x == y) return x;
    for(int k = lg[depth[x]] - 1; k >= 0; --k) {
        if(fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// Recomputes node u's count as the sum of its children's counts
// (an absent child is index 0, whose count is never written and stays 0).
inline void pushup(int u) {
    t[u].x = t[t[u].ls].x + t[t[u].rs].x;
}
// Adds 1 at position x in the tree rooted at u, which covers [l, r].
// Missing nodes on the path are taken from the pool; u itself is assigned
// a new index when it was 0 (it is passed by reference).
void upd(int &u, int l, int r, int x) {
    if(u == 0) u = ++tot;
    if(l != r) {
        const int mid = (l + r) >> 1;
        if(x > mid) upd(t[u].rs, mid + 1, r, x);
        else upd(t[u].ls, l, mid, x);
        pushup(u);
        return;
    }
    t[u].x += 1;
}
// Merges tree v into tree u over [l, r] and returns the resulting root.
// The merge is destructive. Nodes of u are overwritten in place, and
// subtrees only v has are linked into the result (shared, not copied), so any
// later update or merge on the result can also change what v's old root
// reaches. That is why solve() answers a vertex's queries right after
// building its tree, before the parent merges it away.
int merge(int u, int v, int l, int r) {
    if(u == 0) return v;
    if(v == 0) return u;
    if(l == r) {
        t[u].x += t[v].x;
        return u;
    }
    const int mid = (l + r) >> 1;
    t[u].ls = merge(t[u].ls, t[v].ls, l, mid);
    t[u].rs = merge(t[u].rs, t[v].rs, mid + 1, r);
    pushup(u);
    return u;
}
// Returns the count stored at position x in the tree rooted at u over [l, r].
// Missing nodes are index 0, so absent positions read as 0.
// Positions outside [l, r] (a depth above n, or below 1) also return 0.
// Without this check, x > r descended to leaf r (and x < l to leaf l) and
// returned that boundary depth's count instead.
int query(int u, int l, int r, int x) {
    if(x < l || x > r) return 0;
    while(u != 0 && l != r) {
        const int mid = (l + r) >> 1;
        if(x <= mid) {
            u = t[u].ls;
            r = mid;
        } else {
            u = t[u].rs;
            l = mid + 1;
        }
    }
    return t[u].x;
}
void solve(int u, int fa) {
for(ui i = 0; i < e[u].size(); ++i) {
int v = e[u][i].v;
if(v == fa) continue;
solve(v, u);
rt[u] = merge(rt[u], rt[v], 1, n);
}
upd(rt[u], 1, n, depth[u]);
for(ui i = 0; i < q[u].size(); ++i)
ans[q[u][i].id] = query(rt[u], 1, n, q[u][i].d);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("D:\\Chrome Downloadings\\input.txt", "r", stdin);
freopen("D:\\Chrome Downloadings\\output.txt", "w", stdout);
#endif
inp();
tot = n;
for(int i = 1; i <= n; ++i)
rt[i] = i;
for(int i = 1; i <= m; ++i) {
int v = rd, d = rd;
q[v].pb({i, v, d + 1});
}
log2(), dfs(1, 0);
solve(1, 0);
for(int i = 1; i <= m; ++i)
printf("%d\n", ans[i]);
return 0;
}
错误代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define NDEBUG
#include<assert.h>
using namespace std;
// Short aliases used throughout the solution.
using vi = std::vector<int>;
using ll = long long;
using ull = unsigned long long;
using db = double;
using ui = unsigned int;
// Reads the next (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// Returns 0 if end of input is reached before any digit. The character is
// held in an int so EOF (-1) is distinguishable from every byte value. The
// previous `char` version looped forever at EOF, and on unsigned-char
// platforms EOF never compared below '0' at all.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Competitive-programming shorthands.
// NOTE(review): `endl` is redefined to '\n' (no flush); any later spelling of
// std::endl in this file would break.
#define endl '\n'
#define rd read()
#define pb push_back
// Byte-fill a whole array. `a` must be an array, not a pointer, or sizeof(a)
// is only the pointer size. No trailing ';', so `mst(a, 0);` is a single
// statement and is safe inside if/else.
#define mst(a, b) memset((a), (b), sizeof(a))
#define inf 0x3f3f3f3f
#define linf 0x3f3f3f3f3f3f3f3f
#define mod ((int)1e9+7)
// Fully parenthesized so it expands safely inside any expression.
#define maxn ((int)(2e5 + 5))
// A directed edge u -> v; inp() adds one for every vertex v in 2..n.
struct edge {
    int u, v;
    edge(int _u = 0, int _v = 0) : u(_u), v(_v) {}
};
// e[u]: edges leaving u.
vector<edge> e[maxn];
// Appends the edge u -> v to u's list.
void ins(int u, int v) { e[u].emplace_back(u, v); }
int n;              // number of vertices
int m;              // number of queries
int tot;            // last used index in the node pool t[]
int rt[maxn];       // rt[v]: root of v's segment tree (counts per depth)
int depth[maxn];    // depth[v]; dfs() gives the root depth 1
int fa[maxn][19];   // fa[v][k]: 2^k-th ancestor of v
int lg[maxn];       // lg[i] = floor(log2(i)) + 1 for i >= 1, lg[0] = 0
// NOTE(review): q and ans are not used in this version; queries are read and
// answered directly in main().
vi q[maxn];
vi ans;
// Dynamically allocated segment-tree node over depths [1, n]:
// ls/rs are child indices into t[] (0 = empty), x = count in the range.
struct tree {
    int ls, rs, x;
};
// Node pool; index 0 is the shared all-zero "empty" node.
// Sizing: main() pre-allocates n roots, and each upd() allocates at most one
// node per level below the root. That is ceil(log2(n)) <= 18 levels for
// n < maxn <= 2^18, giving <= 19 * n.
// An in-place merge() allocates nothing. A non-destructive merge() (one new
// node per position present in both inputs) adds at most another 19 * n over
// the whole DFS. 40 * maxn covers either variant; the former maxn * 600
// (~1.4 GB of static storage) was far more than needed.
tree t[maxn * 40];
// Reads n, then for each vertex child = 2..n the vertex it hangs from
// (stored as the edge parent -> child), then the query count m.
void inp() {
    cin >> n;
    for(int child = 2; child <= n; ++child) {
        const int parent = rd;
        ins(parent, child);
    }
    cin >> m;
}
// Fills lg[i] = floor(log2(i)) + 1 for 1 <= i <= n (lg[0] stays 0):
// the value grows by one exactly when i reaches the next power of two.
void log2() {
    for(int i = 1; i <= n; ++i) {
        const bool hitsPowerOfTwo = (1 << lg[i - 1]) == i;
        lg[i] = lg[i - 1] + (hitsPowerOfTwo ? 1 : 0);
    }
}
void dfs(int u, int fath) {
fa[u][0] = fath, depth[u] = depth[fath] + 1;
for(int i = 1; i <= lg[depth[u]]; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(unsigned int i = 0; i < e[u].size(); ++i)
if(e[u][i].v != fath)
dfs(e[u][i].v, u);
}
// Lowest common ancestor of x and y via binary lifting
// (not called anywhere in this program).
inline int lca(int x, int y) {
    if(depth[x] < depth[y]) swap(x, y);
    // Lift x by the largest power of two not exceeding the depth gap.
    while(depth[x] > depth[y]) {
        const int gap = depth[x] - depth[y];
        x = fa[x][lg[gap] - 1];
    }
    if(x == y) return x;
    for(int k = lg[depth[x]] - 1; k >= 0; --k) {
        if(fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// Recomputes node u's count as the sum of its children's counts
// (an absent child is index 0, whose count is never written and stays 0).
inline void pushup(int u) {
    t[u].x = t[t[u].ls].x + t[t[u].rs].x;
}
// Adds 1 at position x in the tree rooted at u, which covers [l, r].
// Missing nodes on the path are taken from the pool; u itself is assigned
// a new index when it was 0 (it is passed by reference).
void upd(int &u, int l, int r, int x) {
    if(u == 0) u = ++tot;
    if(l != r) {
        const int mid = (l + r) >> 1;
        if(x > mid) upd(t[u].rs, mid + 1, r, x);
        else upd(t[u].ls, l, mid, x);
        pushup(u);
        return;
    }
    t[u].x += 1;
}
int merge(int u, int v, int l, int r) {
if(!u || !v) return u + v;
int mid = (l + r) >> 1;
if(l == r) {
t[u].x += t[v].x;
return u;
}
t[u].ls = merge(t[u].ls, t[v].ls, l, mid);
t[u].rs = merge(t[u].rs, t[v].rs, mid + 1, r);
pushup(u);
return u;
}
// Returns the count stored at position x in the tree rooted at u over [l, r].
// Missing nodes are index 0, so absent positions read as 0.
// Positions outside [l, r] (a depth above n, or below 1) also return 0.
// Without this check, x > r descended to leaf r (and x < l to leaf l) and
// returned that boundary depth's count instead.
int query(int u, int l, int r, int x) {
    if(x < l || x > r) return 0;
    while(u != 0 && l != r) {
        const int mid = (l + r) >> 1;
        if(x <= mid) {
            u = t[u].ls;
            r = mid;
        } else {
            u = t[u].rs;
            l = mid + 1;
        }
    }
    return t[u].x;
}
// Post-order over the subtree of u (parent fa). It builds rt[u], the vertex
// count per depth in u's subtree, which main() queries after the DFS.
// u's own depth is inserted BEFORE the children are merged in. At that point
// rt[u] is still the slot main() assigned to u (rt[u] == u, with no
// children), so upd() only writes into nodes belonging to u's tree alone.
// Inserting afterwards walks, and modifies, nodes that rt[u] may share with
// its children's trees. For every rt[v] to stay valid after the DFS, merge()
// must also leave its inputs untouched.
// NOTE(review): recursion depth equals tree height; confirm stack size.
void solve(int u, int fa) {
    upd(rt[u], 1, n, depth[u]);
    for(ui i = 0; i < e[u].size(); ++i) {
        const int v = e[u][i].v;
        if(v == fa) continue;
        solve(v, u);
        rt[u] = merge(rt[u], rt[v], 1, n);
    }
}
int main() {
#ifndef ONLINE_JUDGE
    // Local testing only: redirect stdio to files (hard-coded Windows paths).
    freopen("D:\\Chrome Downloadings\\input.txt", "r", stdin);
    freopen("D:\\Chrome Downloadings\\output.txt", "w", stdout);
#endif
    inp();
    // Vertices 1..n own pool slots 1..n as their initial (empty) roots.
    tot = n;
    for(int v = 1; v <= n; ++v) rt[v] = v;
    log2();
    dfs(1, 0);
    solve(1, 0);
    // Queries are answered online, after the whole tree has been merged.
    for(int k = 1; k <= m; ++k) {
        const int v = rd;
        // Shifted by one to match depth[], where the root is 1
        // (presumably the input counts the root as depth 0).
        const int target = rd + 1;
        // Depths above v cannot occur in v's subtree.
        if(target < depth[v]) {
            puts("0");
            continue;
        }
        printf("%d\n", query(rt[v], 1, n, target));
    }
    return 0;
}