蒟蒻不会计算线段树的空间，请求大佬指导
查看原帖
蒟蒻不会计算线段树的空间，请求大佬指导
118308
SegmentTree楼主2020/9/14 17:13
#pragma GCC optimize("Ofast","-funroll-loops","-fdelete-null-pointer-checks")
#pragma GCC target("ssse3","sse3","sse2","sse","avx2","avx")
#pragma GCC optimize(3,"Ofast","inline")

#include<bits/stdc++.h>
#define gc() getchar()
//#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
using namespace std;
// Array bound for vertices; presumably n <= 5e5 with +10 slack — confirm against the problem limits.
const int N = 5e5 + 10;
// Modulus for all arithmetic (prime; main uses (mod + 1) / 2 as the inverse of 2).
const int mod = 998244353;
// Node-pool multiplier: every segment-tree node array below has Log * N entries.
// Memory: lo, ro, sum, tag and st are each Log * N ints = 100 * 500010 * 4 B ~= 200 MB,
// i.e. about 1 GB in total. Nodes are allocated only by ned(), which is called from
// modify() for a missing node on its root-to-leaf path and from dd() when pushd() pushes
// a tag into a missing child; nodes released by del() are recycled through st.
// NOTE(review): a single modify() path over [0, n - 1] touches about ceil(log2(n)) + 1 ~= 20
// nodes for n = 5e5; size Log from the measured peak of eid rather than guessing.
const int Log = 100;
// Adjacency-list edge: target vertex and index of the next edge out of the same vertex
// (0 terminates the list). Each undirected edge is stored twice, hence N << 1 slots.
struct edge
{
	int to, nxt;
}e[N << 1];
// head[u]: index of the first edge out of u (0 = none); tot: edges stored so far (1-indexed).
int head[N], tot = 0;
// get[v]: largest dep[u] over the input pairs (u, v), or -1 if v has none (filled in main).
// dep[u]: depth of u, with vertex 1 at depth 0 (main sets dep[0] = -1 as the root's parent).
// NOTE(review): under `using namespace std`, the global name `get` may be reported as
// ambiguous with std::get by some compilers / language modes — confirm it compiles on the judge.
int get[N], dep[N];
// n: number of vertices; m: number of pairs read after the edges.
int n, m;
// Segment-tree node pool: lo/ro = child indices (0 = absent), sum = range sum modulo mod,
// tag = pending multiplicative lazy tag (1 = none). root[u] is the root of u's tree and
// eid is the highest node index handed out so far (index 0 is never allocated).
int lo[Log * N], ro[Log * N], sum[Log * N], tag[N * Log], root[N], eid = 0;
// Stack of released node indices, occupying st[1..top].
int st[Log * N], top = 0;
// Buffer and cursors for the fread-based gc() variant that is commented out above.
char buf[1 << 21], *p1 = buf,*p2=buf;
// Read one decimal integer via gc(), skipping any non-digit characters before it;
// a '-' seen while skipping makes the result negative. If the input ends before a
// digit appears, returns 0 instead of looping forever on EOF.
inline int read() 
{
	int x = 0, f = 1;
	// int, not char: EOF must stay distinguishable from a real byte, and isdigit()
	// is only defined for EOF or values representable as unsigned char.
	int c = gc();
	while (c != EOF && !isdigit(c))
	{
		if (c == '-') f = -1;
		c = gc();
	}
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = gc();
	}
	return x * f;
}
// Store the directed edge u -> v at the next free slot (edges are 1-indexed by tot)
// and make it the new head of u's adjacency list.
void addedge(int u, int v)
{
	++tot;
	e[tot].to = v;
	e[tot].nxt = head[u];
	head[u] = tot;
}

// In-place modular addition: x = (x + y) mod `mod`.
// Correct when both operands are already reduced into [0, mod).
void add(int &x, int y)
{
	x += y;
	if (x >= mod)
		x -= mod;
}
// Release node x: clear its children, sum and lazy tag (1 = identity), then push it
// onto the free stack so ned() can hand it out again.
void del(int x)
{
	lo[x] = 0;
	ro[x] = 0;
	sum[x] = 0;
	tag[x] = 1;
	st[++top] = x;
}
// Obtain a node index: reuse a released one from the free stack when available,
// otherwise take the next never-used index and set its tag to the identity 1.
int ned()
{
	if (top <= 0)
	{
		++eid;
		tag[eid] = 1;
		return eid;
	}
	return st[top--];
}
// Apply a pending multiplication by v to the subtree rooted at p: scale its sum and
// compose v into its lazy tag.
// An absent node (p == 0) stands for an all-zero range, which multiplication leaves
// unchanged, so nothing is allocated for it. Previously this created a fresh node here,
// which made every pushd() materialise empty siblings and inflated the node pool.
// p stays a reference so the signature is unchanged for existing callers.
void dd(int &p, int v)
{
	if(!p) return;
	sum[p] = 1ll * sum[p] * v % mod;
	tag[p] = 1ll * tag[p] * v % mod;
}

// Push p's pending multiplicative tag into both children, then reset it to 1.
// Does nothing when no tag is pending.
void pushd(int p)
{
	const int k = tag[p];
	if (k != 1)
	{
		dd(lo[p], k);
		dd(ro[p], k);
		tag[p] = 1;
	}
}

// Recompute p's sum from its two children. A missing child is index 0, so this relies
// on sum[0] remaining 0.
void upd(int p)
{
	const int total = sum[lo[p]] + sum[ro[p]];
	sum[p] = total % mod;
}

// Point update: add v (modulo mod) at index x of the tree rooted at p covering [l, r].
// Missing nodes on the path are allocated, and p is updated if the root was missing.
void modify(int &p, int l, int r, int x, int v)
{
	if (p == 0)
		p = ned();
	if (l != r)
	{
		pushd(p);
		const int mid = (l + r) / 2;
		if (x > mid)
			modify(ro[p], mid + 1, r, x, v);
		else
			modify(lo[p], l, mid, x, v);
		upd(p);
	}
	else
	{
		add(sum[p], v);
	}
}
// Point query: value stored at index x in the tree rooted at p covering [l, r]
// (0 if the path runs into a missing node). Pending tags are pushed down along the way.
int query(int p, int l, int r, int x)
{
	while (p != 0 && l != r)
	{
		pushd(p);
		const int mid = (l + r) / 2;
		if (x <= mid)
		{
			p = lo[p];
			r = mid;
		}
		else
		{
			p = ro[p];
			l = mid + 1;
		}
	}
	return p != 0 ? sum[p] : 0;
}
// Bring a value from [0, 2 * mod) into [0, mod) with one conditional subtraction.
int reduce(int x)
{
	if (x < mod)
		return x;
	return x - mod;
}
// Merge the tree rooted at y into the tree rooted at x over [l, r] and return the merged
// root (x, or y when x is empty). y's nodes that were merged into x are released with del().
// sumx / sumy accumulate the leaf values of x and y already visited. The right child is
// merged before the left one, so at any point they hold sums over indices to the right of
// the current position (suffix sums). dfs() starts both at 0 for each merge.
// Resulting leaf value at a shared index i:
//     x_i * (sum of y over indices > i) + y_i * (sum of x over indices >= i)
// and a subtree present in only one tree is scaled by the other tree's current suffix sum.
int merge(int x, int y, int l, int r, int &sumx, int &sumy)
{
	if(!x && !y) return 0;
	// Base cases: one side empty, or a leaf.
	if(!x || !y || l == r)
	{
		if(l == r && x && y) 
		{
			// Shared leaf: sumx now includes x_i, while sumy still excludes y_i.
			add(sumx, sum[x]);
			sum[x] = (reduce(1ll * sum[x] * sumy % mod + 1ll * sum[y] * sumx % mod));
			add(sumy, sum[y]);
			del(y);
			return x;
		}		
		else 
		{
			// Only one subtree exists: record its sum, then scale it by the other side's suffix sum.
			if(!y) return add(sumx, sum[x]), dd(x, sumy), x;
			return add(sumy, sum[y]), dd(y, sumx), y;
		}
	}
	pushd(x), pushd(y);
	int mid = (l + r) / 2;
	// Right half first so that sumx / sumy are suffix sums when the left half is processed.
	ro[x] = merge(ro[x], ro[y], mid + 1, r, sumx, sumy);
	lo[x] = merge(lo[x], lo[y], l, mid, sumx, sumy);
	upd(x);
	del(y);
	return x;
}
// Fill dep[] for every vertex reachable from u, treating fa as u's parent:
// dep[v] = dep[parent(v)] + 1. Walks the tree with an explicit stack of
// (vertex, parent) pairs instead of recursion.
void init(int u, int fa)
{
	std::vector<std::pair<int, int>> pending;
	pending.emplace_back(u, fa);
	while (!pending.empty())
	{
		const int cur = pending.back().first;
		const int par = pending.back().second;
		pending.pop_back();
		dep[cur] = dep[par] + 1;
		for (int i = head[cur]; i != 0; i = e[i].nxt)
		{
			if (e[i].to != par)
				pending.emplace_back(e[i].to, cur);
		}
	}
}
// Post-order pass that builds root[u], a segment tree indexed by depth over [0, n - 1]:
//  - if get[u] != -1, add mod - 1 (i.e. -1) at index get[u];
//  - add 1 at index dep[u];
//  - merge each child's finished tree into it with merge();
//  - finally fold index dep[u] + 1 into dep[u]: with t = value[dep[u] + 1] and
//    p = value[dep[u]], the result is value[dep[u]] = 2 * (p + t), value[dep[u] + 1] = 0.
void dfs(int u, int fa)
{
	if(get[u] != -1) modify(root[u], 0, n - 1, get[u], mod - 1);
	modify(root[u], 0, n - 1, dep[u], 1);
	for(int i = head[u]; i; i = e[i].nxt)
	{
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v, u);
		int x = 0, y = 0;
		root[u] = merge(root[u], root[v], 0, n - 1, x, y);
	}
	// Index dep[u] + 1 exists only if it is <= n - 1. A vertex at depth n - 1 (the tree
	// is a path hanging from vertex 1, including n == 1) has nothing stored there. Querying
	// or modifying the out-of-range index would fall through to leaf n - 1, which is
	// dep[u] itself, and corrupt its value.
	const bool has_next = dep[u] + 1 < n;
	int t = has_next ? query(root[u], 0, n - 1, dep[u] + 1) : 0;
	int p = query(root[u], 0, n - 1, dep[u]);
	if(has_next) modify(root[u], 0, n - 1, dep[u] + 1, (mod - t) % mod);
	// 2 * t + p can reach ~3e9, so compute it in long long to avoid signed overflow.
	modify(root[u], 0, n - 1, dep[u], (int)((2ll * t + p) % mod));
}
// Read n and the n - 1 undirected tree edges, then m pairs (u, v). For each v, record in
// get[v] the largest dep[u] among its pairs (-1 if none). Run the DP from vertex 1 and print
// the value at index 0 of the root's tree multiplied by the inverse of 2, modulo mod.
int main()
{
	n = read();
	for(int i = 1; i <= n; i++) get[i] = -1;
	int u, v;
	for(int i = 1; i < n; i++)
	{
		u = read(), v = read();
		addedge(u, v);
		addedge(v, u);
	}
	dep[0] = -1; // parent sentinel so vertex 1 gets depth 0
	init(1, 0);
	m = read();
	for(int i = 1; i <= m; i++)
	{
		u = read(), v = read();
		get[v] = max(get[v], dep[u]);
	}
	dfs(1, 0);
	// mod is odd, so 2 * ((mod + 1) / 2) == mod + 1 == 1 (mod mod).
	const long long inv2 = (mod + 1) / 2;
	// The product is a long long, so it must be printed with %lld; passing it to %d is UB.
	printf("%lld\n", 1ll * query(root[1], 0, n - 1, 0) * inv2 % mod);
	return 0;
}
2020/9/14 17:13
加载中...