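// Luogu submission record: https://www.luogu.com.cn/record/38834665
// rt（如题）：样例能过，但提交没有通过，有人能帮忙看看是什么原因吗？
// (As in the title: the code passes the sample tests but the submission fails. Can anyone help find out why?)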
#include<bits/stdc++.h>
using namespace std;
// `register` was removed from the language in C++17; clang rejects it by default and GCC
// warns. It never had any semantic effect, so `reg` now expands to nothing.
#define reg
// 64-bit intermediate type for modular products; works in C-style casts like `(ll)a * b`.
using ll = long long;
// Buffered integer reader for stdin.
// NOTE: this namespace used to be wrapped in `extern "C" { ... }`. That gave io::read
// C language linkage, making it the same entity as POSIX ::read(int, void *, size_t)
// whenever <unistd.h> is pulled in, which breaks the build. Nothing here needs C linkage.
namespace io {
#define BUFS 100000
static char in[BUFS], *p = in, *pp = in;
// Next input byte as a value in 0..255, or EOF once stdin is exhausted. The unsigned char
// cast keeps bytes >= 0x80 non-negative, as std::isdigit requires.
#define gc() (p == pp && (pp = (p = in) + std::fread(in, 1, BUFS, stdin), p == pp) ? EOF : static_cast<unsigned char>(*p++))
// Reads the next decimal int from stdin. Any '-' seen before the first digit makes it
// negative. Returns 0 at end of input instead of spinning forever.
inline int read() {
	int x = 0, c;
	bool neg = false;
	// Skip separators; stop on the first digit or at EOF.
	while ((c = gc()) != EOF && !std::isdigit(c)) neg |= (c == '-');
	while (c != EOF && std::isdigit(c)) x = x * 10 + (c - '0'), c = gc();
	return neg ? -x : x;
}
}
#define rd io :: read
// N: bound for the per-node arrays (node ids are 1-based, up to N - 1).
// mod: prime modulus for all probability arithmetic (qpow(x, mod - 2) is used as an inverse).
constexpr int N = 300001;
constexpr int mod = 998244353;
// Binary exponentiation: returns a^b reduced modulo `mod` (b is expected to be >= 0).
inline int qpow(int a, int b) {
	long long base = a;
	long long result = 1;
	while (b != 0) {
		if (b & 1) result = result * base % mod;
		b >>= 1;
		base = base * base % mod;
	}
	return static_cast<int>(result);
}
// Modular inverse of 10000: internal-node inputs are probabilities scaled by 10^4.
const int inv_ = qpow(10000, mod - 2);
// Tree data, 1-based node ids:
//   V[i]  - internal node: its "take max" probability p_i (mod `mod`);
//           leaf: after main compresses it, the 1-based rank of its weight in U.
//   U     - sorted leaf weights, U[1..cnt].
//   D     - D[k] = probability (mod `mod`) that the root ends up with value U[k].
//   ch    - up to two children per node; slot 0 is filled first.
int n, m, V[N], cnt, U[N], D[N], ch[N][2];
// Dynamically allocated segment-tree nodes over value ranks [1, cnt]:
//   val    = sum of probabilities in the node's range,
//   tag    = pending multiplicative factor for the children,
//   ls, rs = child indices, where 0 means an empty subtree (seg[0].val must stay 0).
// Sizing: every leaf inserts its own root-to-leaf path of at most ceil(log2(cnt)) + 1 <= 20
// nodes, and there are at most N / 2 + 1 leaves. The old bound N << 2 (~1.2M nodes) overflowed
// for n = 3e5, where up to ~150000 * 19 ~= 2.85M nodes are needed.
struct Node {int val, tag, ls, rs;} seg[(N / 2 + 1) * 20];
int ptr, root[N];
inline void pushtag(int k, int v) {seg[k].val = (ll)seg[k].val * v % mod, seg[k].tag = (ll)seg[k].tag * v % mod;}
// Recompute node k's sum from its two children. A missing child is index 0, which relies
// on seg[0].val staying 0.
inline void pushup(int k) {
	const int left = seg[k].ls;
	const int right = seg[k].rs;
	seg[k].val = (seg[left].val + seg[right].val) % mod;
}
// Hand node k's pending multiplier to whichever children exist, then reset it to 1.
inline void pushdown(int k) {
	const int factor = seg[k].tag;
	if (factor != 1) {
		if (seg[k].ls) pushtag(seg[k].ls, factor);
		if (seg[k].rs) pushtag(seg[k].rs, factor);
		seg[k].tag = 1;
	}
}
// Point-assign value v at rank p in the tree rooted at k, which covers [l, r].
// Missing nodes on the path are allocated from seg (k is set when a new root is created).
void modify(int &k, int l, int r, int p, int v) {
	if (k == 0) {
		k = ++ptr;
		seg[k].tag = 1;
	}
	if (l == r) {
		seg[k].val = v;
		return;
	}
	pushdown(k);
	const int mid = (l + r) >> 1;
	if (p <= mid) {
		modify(seg[k].ls, l, mid, p, v);
	} else {
		modify(seg[k].rs, mid + 1, r, p, v);
	}
	pushup(k);
}
int merge(int x, int y, int xl, int xr, int yl, int yr, int v) {
if (!x && !y) return 0;
pushdown(x), pushdown(y);
if (!y) {
pushtag(x, ((ll)v * yl % mod + (ll)(1 - v + mod) * yr % mod) % mod);
return x;
}
if (!x) {
pushtag(y, ((ll)v * xl % mod + (ll)(1 - v + mod) * xr % mod) % mod);
return y;
}
int Lx = seg[seg[x].ls].val, Rx = seg[seg[x].rs].val, Ly = seg[seg[y].rs].val, Ry = seg[seg[y].rs].val;
seg[x].ls = merge(seg[x].ls, seg[y].ls, xl, ((ll)xr + Rx) % mod, yl, ((ll)yr + Ry) % mod, v);
seg[x].rs = merge(seg[x].rs, seg[y].rs, ((ll)xl + Lx) % mod, xr, ((ll)yl + Ly) % mod, yr, v);
pushup(x);
return x;
}
// Build root[u] (the value distribution of u's subtree) for every node u in x's subtree.
// A leaf gets a one-point tree, a single-child node reuses its child's tree, and a
// two-child node merges both children with its own max-probability V[u].
// Iterative on purpose: the tree can be a chain of ~3e5 nodes, and recursing once per
// level can overflow the call stack. Nodes are listed in BFS order (parents before
// children) and processed in reverse, so every child is finished before its parent.
void dfs1(int x) {
	std::vector<int> order(1, x);
	for (std::size_t i = 0; i < order.size(); ++i) {
		const int u = order[i];
		if (ch[u][0]) order.push_back(ch[u][0]);
		if (ch[u][1]) order.push_back(ch[u][1]);
	}
	for (auto it = order.rbegin(); it != order.rend(); ++it) {
		const int u = *it;
		if (!ch[u][0]) {
			modify(root[u], 1, cnt, V[u], 1);
		} else if (!ch[u][1]) {
			root[u] = root[ch[u][0]];
		} else {
			root[u] = merge(root[ch[u][0]], root[ch[u][1]], 0, 0, 0, 0, V[u]);
		}
	}
}
// Walk the tree rooted at k over ranks [l, r], flushing lazy tags downward, and store
// each rank's final value in D[rank].
void dfs2(int k, int l, int r) {
	if (l != r) {
		pushdown(k);
		const int mid = (l + r) >> 1;
		dfs2(seg[k].ls, l, mid);
		dfs2(seg[k].rs, mid + 1, r);
		return;
	}
	D[l] = seg[k].val;
}
int main() {
n = rd();
for (reg int i = 1, fa; i <= n; ++i)
fa = rd(), ch[fa][0] ? ch[fa][1] = i : ch[fa][0] = i;
for (reg int i = 1; i <= n; ++i)
V[i] = rd(), ch[i][0] ? V[i] = (ll)V[i] * inv_ % mod : U[++cnt] = V[i];
sort(U + 1, U + cnt + 1);
for (reg int i = 1; i <= n; ++i)
ch[i][0] || (V[i] = lower_bound(U + 1, U + cnt + 1, V[i]) - U);
dfs1(1), dfs2(root[1], 1, cnt);
reg int ans = 0;
for (reg int i = 1; i <= cnt; ++i)
ans = ((ll)ans + (ll)i * U[i] % mod * D[i] % mod * D[i] % mod) % mod;
printf("%d", ans);
return 0;
}