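// Approach: heavy-light decomposition + segment tree + differencing over query prefixes.
// Implemented following the editorial's approach; the author could not find where the bug was (qwq).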
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
// Array bound used for every per-node table in this file.
constexpr int MAXN = 50005;
// Shorthand for a 64-bit integer type.
using LL = long long;
// Modulus applied to every stored sum and every printed answer.
constexpr int MOD = 201314;
// Adjacency lists of the tree (node indices are 1-based after input shifting in main()).
vector<int> E[MAXN];
// Segment-tree node covering positions [l, r] of the HLD ordering.
//   sum - sum of the values currently stored for positions [l, r]
//   lzy - addition applied to this node but not yet pushed to its children
struct Node
{
    int l;
    int r;
    int sum;
    int lzy;
};
// Heap-style storage: children of node k are 2k and 2k+1.
Node tr[MAXN << 3];
// One half of a differenced query: the prefix value for positions [1, r] evaluated at node z.
//   op  - 1 for the half that is subtracted, 2 for the half that is added
//   ans - value computed for this half during the sweep in main()
//   ID  - index of the original query this half belongs to
struct Q
{
    int r;
    int z;
    int op;
    int ans;
    int ID;
};
Q que[MAXN * 2];
// Per-query answers, accumulated from the two halves of each query in main().
int rec[MAXN];
// Heavy-light decomposition tables (node 0 acts as the "none" sentinel):
int fa[MAXN];   // parent
int dep[MAXN];  // depth; the root gets dep[0] + 1 = 1
int siz[MAXN];  // subtree size
int w[MAXN];    // node weights copied into nw[]; never assigned in this file, so all zero
int son[MAXN];  // heavy child, 0 when the node has no children
int top[MAXN];  // head of the heavy chain containing the node
int id[MAXN];   // position of the node in the HLD order
int nw[MAXN];   // w[] re-indexed by HLD position
int n;          // number of nodes
int m;          // number of queries
int cnt = 0;    // next HLD position counter used by dfs2()
// First HLD pass over the subtree of u (whose parent is fath): records parent,
// depth and subtree size, and picks as son[u] the child with the largest subtree.
// NOTE(review): recursion depth equals the tree height, which can be ~n on a
// path-shaped tree; may need a larger stack on some judges.
void dfs1(int fath, int u)
{
    fa[u] = fath;
    dep[u] = dep[fath] + 1;
    siz[u] = 1;
    for (size_t k = 0; k < E[u].size(); ++k)
    {
        const int child = E[u][k];
        if (child != fath)
        {
            dfs1(u, child);
            siz[u] += siz[child];
            // son[u] == 0 initially and siz[0] == 0, so the first child always wins.
            if (siz[son[u]] < siz[child])
                son[u] = child;
        }
    }
}
// Second HLD pass: numbers nodes so that every heavy chain occupies a contiguous
// range of positions, and records top[u] = tp, the head of u's chain.
void dfs2(int tp, int u)
{
    ++cnt;
    top[u] = tp;
    id[u] = cnt;
    nw[cnt] = w[u];
    if (!son[u])
        return;                  // no children
    dfs2(tp, son[u]);            // the heavy child extends the current chain
    for (const int child : E[u])
        if (child != fa[u] && child != son[u])
            dfs2(child, child);  // each light child heads a new chain
}
void pushup(int op)
{
tr[op].sum = tr[op*2].sum + tr[op*2+1].sum;
}
void pushdown(int op);
// Initializes segment-tree node op to cover [L, R] with zero sum and no lazy tag,
// recursing into both halves for non-leaf ranges.
void Build(int op, int L, int R)
{
    tr[op].l = L;
    tr[op].r = R;
    tr[op].sum = 0;
    tr[op].lzy = 0;
    if (L != R)
    {
        const int mid = (L + R) >> 1;
        Build(2 * op, L, mid);
        Build(2 * op + 1, mid + 1, R);
        pushup(op);
    }
}
// Returns the sum of positions [L, R] modulo MOD. [L, R] must lie inside node op's
// range; pending lazy tags are pushed down on the way.
int query(int op, int L, int R)
{
    if (L == tr[op].l && R == tr[op].r)
        return tr[op].sum % MOD;
    pushdown(op);
    const int mid = (tr[op].l + tr[op].r) >> 1;
    const int lc = op * 2;
    const int rc = op * 2 + 1;
    if (R <= mid)
        return query(lc, L, R);        // entirely in the left half
    if (L > mid)
        return query(rc, L, R);        // entirely in the right half
    const int leftPart = query(lc, L, mid);
    const int rightPart = query(rc, mid + 1, R);
    return (leftPart + rightPart) % MOD;
}
// Adds val to every position in [L, R], which must lie inside node op's range.
// Fully covered nodes are updated in place and tagged lazily. Their sum and
// lazy tag are stored reduced modulo MOD.
void add(int op, int L, int R, int val)
{
    if (tr[op].l == L && tr[op].r == R)
    {
        // val can be a pushed-down lazy tag (up to MOD - 1) and the length can be
        // ~n, so the product does not fit in int: compute it in 64-bit.
        const long long len = R - L + 1;
        tr[op].sum = static_cast<int>((tr[op].sum + len * val) % MOD);
        tr[op].lzy = static_cast<int>((static_cast<long long>(tr[op].lzy) + val) % MOD);
        return;
    }
    pushdown(op);
    const int mid = (tr[op].l + tr[op].r) >> 1;
    if (R <= mid)
        add(op * 2, L, R, val);
    else if (L > mid)
        add(op * 2 + 1, L, R, val);
    else
    {
        add(op * 2, L, mid, val);
        add(op * 2 + 1, mid + 1, R, val);
    }
    pushup(op);
}
// Applies node op's pending lazy addition to both children (via add() over each
// child's full range) and clears the tag. Does nothing when no tag is pending.
void pushdown(int op)
{
    const int tag = tr[op].lzy;
    if (tag != 0)
    {
        const int lc = op * 2;
        const int rc = op * 2 + 1;
        add(lc, tr[lc].l, tr[lc].r, tag);
        add(rc, tr[rc].l, tr[rc].r, tag);
        tr[op].lzy = 0;
    }
}
// Adds val to every node on the tree path between u and v. It repeatedly lifts the
// endpoint whose chain head is deeper, so each heavy chain is one contiguous range.
void add_path(int u, int v, int val)
{
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        add(1, id[top[u]], id[u], val);
    }
    // Same chain now: positions run from the shallower node to the deeper one.
    const int shallow = (dep[u] < dep[v]) ? u : v;
    const int deep = (shallow == u) ? v : u;
    add(1, id[shallow], id[deep], val);
}
// Returns, modulo MOD, the sum of the values on every node of the tree path
// between u and v, using the same chain-lifting walk as add_path().
int query_path(int u, int v)
{
    int total = 0;
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        total = (total + query(1, id[top[u]], id[u])) % MOD;
    }
    // Same chain now: positions run from the shallower node to the deeper one.
    const int shallow = (dep[u] < dep[v]) ? u : v;
    const int deep = (shallow == u) ? v : u;
    return (total + query(1, id[shallow], id[deep])) % MOD;
}
// Sort predicate: orders query halves by ascending prefix end r.
bool cmp(Q a, Q b)
{
    return b.r > a.r;
}
// Reads the tree and the queries, answers every query offline and prints the
// answers in input order.
//
// Sweep: for i = 1..n, add 1 along the root->i path. After processing prefix r,
// the root->z path sum equals, over all i <= r, the number of common ancestors of
// i and z, i.e. sum dep(LCA(i, z)) with the root at depth 1. Each query is split
// into two prefix halves (prefix R minus prefix L-1) that are answered during the
// sweep and combined in rec[].
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    // Input node numbers are 0-based with root 0. Shift them by +1 so node 0 can act
    // as the "none" sentinel in fa[] / son[].
    for (int i = 2; i <= n; i++)
    {
        int f;
        cin >> f;
        f++;
        E[i].push_back(f);
        E[f].push_back(i);
    }
    for (int i = 1; i <= m; i++)
    {
        int L, R, z;
        cin >> L >> R >> z;
        // All three are 0-based node numbers. z must be shifted too; otherwise
        // query_path(1, z) would walk from node 0, which is not part of the tree.
        L++, R++, z++;
        que[i].r = L - 1, que[i].z = z;
        que[i].op = 1, que[i].ID = i;
        que[i + m].r = R, que[i + m].z = z;
        que[i + m].op = 2, que[i + m].ID = i;
    }
    dfs1(0, 1);
    dfs2(1, 1);
    Build(1, 1, n);

    const int total = 2 * m;
    sort(que + 1, que + 1 + total, cmp);

    int ind = 1;
    // Halves with r == 0 (queries whose original l is 0) cover an empty prefix, so
    // their value is 0. They sort to the front and must be consumed here: the sweep
    // below starts at i = 1 and would never advance past them, leaving every later
    // half unanswered.
    while (ind <= total && que[ind].r == 0)
    {
        que[ind].ans = 0;
        ind++;
    }
    for (int i = 1; i <= n && ind <= total; i++)
    {
        add_path(1, i, 1);
        while (ind <= total && que[ind].r == i)
        {
            que[ind].ans = query_path(1, que[ind].z);
            ind++;
        }
    }
    // Combine halves. Each ans is in [0, MOD), so rec[] stays in (-MOD, MOD).
    for (int i = 1; i <= total; i++)
    {
        if (que[i].op == 1)
            rec[que[i].ID] -= que[i].ans;
        else
            rec[que[i].ID] += que[i].ans;
    }
    for (int i = 1; i <= m; i++)
        cout << (rec[i] + MOD) % MOD << '\n';
    return 0;
}