47pts求调
查看原帖
47pts求调
649246
nothing__楼主2024/11/20 21:57
#include<bits/stdc++.h>
#define int long long
using namespace std;
// Reads the next decimal integer from stdin.
// Every character that is not a digit is skipped first; a '-' met while
// skipping makes the result negative. Returns 0 when the input ends before
// any digit is found.
int read() {
	int x=0, w=1;
	// Keep getchar()'s own return type: storing it in a char would make EOF
	// indistinguishable from data and could hand isdigit() a negative value.
	auto ch=getchar();
	while(ch!=EOF&&!isdigit(ch)) {if(ch=='-') w=-1; ch=getchar();}
	// isdigit(EOF) is false, so this loop also stops at end of input.
	while(isdigit(ch)) {x=x*10+(ch-'0'); ch=getchar();}
	return x*w;
}
const int N=2e5+10;        // capacity of the per-vertex arrays
const int inf=0x3f3f3f3f;  // large sentinel value
int n, q;                  // read first in main: vertex count and number of operations
std::vector<int> g[N];     // adjacency lists of the tree
int tsp, rk[N];            // tsp: last dfn handed out; rk[d]: vertex whose dfn is d
// Per-vertex data of the heavy-light decomposition:
//   fa  - parent (0 for the root)        siz - size of the subtree
//   son - heavy child (0 if none)        top - topmost vertex of its heavy chain
//   dep - depth, root has depth 1        dfn - position in the heavy-first order
struct node{int fa, siz, son, top, dep, dfn;} t[N];

// First pass over the subtree of x: records parent, depth and subtree size
// and selects the heavy child (the child with the largest subtree).
// fa is the parent of x; the root is called with fa == 0.
void dfs1(int x, int fa) {
	t[x].fa=fa; t[x].dep=t[fa].dep+1;
	t[x].siz=1;
	for(auto y:g[x]) {
		if(y==fa) continue;
		dfs1(y, x);
		t[x].siz+=t[y].siz;
		// son starts at 0 and t[0].siz stays 0 (assuming vertices are numbered
		// from 1), so the first child always replaces the initial value.
		if(t[y].siz>t[t[x].son].siz) t[x].son=y;
	}
}

// Second pass: assigns the chain top and the dfn of every vertex in the
// subtree of x. The dfn must be handed out here rather than in dfs1, because
// only this pass descends into the heavy child first; that is what makes the
// vertices of one heavy chain occupy consecutive dfn values, which the
// range updates along a chain depend on.
void dfs2(int x, int top) {
	t[x].top=top;
	t[x].dfn=++tsp; rk[tsp]=x;
	if(!t[x].son) return ;
	dfs2(t[x].son, top);
	for(auto y:g[x]) {
		if(y==t[x].fa||y==t[x].son) continue;
		dfs2(y, y);   // every light child starts a new chain
	}
}
// Segment tree node: covers positions [l, r]; lc / rc are the indices of the
// children inside tr (-1 when the node has none); c is the sum over the
// interval; tag is the per-position increment not yet pushed to the children.
// tr is the node pool, trlen the number of nodes allocated so far.
struct trnode{int l, r, lc, rc, c, tag;} tr[N<<2]; int trlen;
// Index of the left child of node x.
#define lc(x) tr[x].lc
// Index of the right child of node x.
#define rc(x) tr[x].rc
// Number of positions covered by node x.
#define len(x) (tr[x].r-tr[x].l+1)
void pushdown(int now) {
	if(!tr[now].tag) return ;
	int tag=tr[now].tag;
	tr[lc(now)].c+=len(lc(now))*tag; tr[lc(now)].tag+=tag;
	tr[rc(now)].c+=len(rc(now))*tag; tr[rc(now)].tag+=tag;
	tr[now].tag=0; return ;
}
void build(int l, int r) {
	int now=++trlen;
	tr[now]={l, r, -1, -1, 0, 0};
	if(l==r) return ;
	int mid=(l+r)>>1;
	tr[now].lc=trlen+1, build(l, mid);
	tr[now].rc=trlen+1, build(mid+1, r);
}
// Adds 1 to every position of [l, r] inside the subtree rooted at node `now`.
// An empty range (l > r) or a range that misses the node's interval is a no-op.
void update(int now, int l, int r) {
	// Without this guard an inverted range such as [k+1, k] could walk down to
	// a leaf it does not cover and then touch that leaf's children, whose
	// index is -1.
	if(l>r||l>tr[now].r||r<tr[now].l) return ;
	if(l<=tr[now].l&&r>=tr[now].r) {
		// Fully covered: update the sum and remember the increment lazily.
		tr[now].c+=len(now); tr[now].tag++;
		return ;
	}
	pushdown(now);
	int mid=(tr[now].l+tr[now].r)>>1;
	if(l<=mid) update(lc(now), l, r);
	if(r>mid) update(rc(now), l, r);
	tr[now].c=(tr[lc(now)].c+tr[rc(now)].c);
}
// Adds 1 to every edge on the tree path between x and y. An edge is stored at
// the dfn of its deeper endpoint.
// Assumes t[].dfn numbers the vertices of each heavy chain consecutively from
// the chain top downwards.
void change(int x, int y) {
	while(t[x].top!=t[y].top) {
		// Always lift the endpoint whose chain top is deeper.
		if(t[t[x].top].dep<t[t[y].top].dep) swap(x, y);
		// Chain top .. x, which includes the edge from the top to its parent.
		update(1, t[t[x].top].dfn, t[x].dfn);
		x=t[t[x].top].fa;
	}
	// Both ends met at one vertex: no edge is left, and dfn+1 .. dfn would be
	// an inverted range.
	if(x==y) return ;
	if(t[x].dep>t[y].dep) swap(x, y);
	// Skip dfn[x]: that slot belongs to the edge above x, which is off the path.
	update(1, t[x].dfn+1, t[y].dfn);
}
// Returns the sum of the positions of [l, r] that lie inside the subtree
// rooted at node `now`; 0 for an empty range or one that misses the node.
int query(int now, int l, int r) {
	if(l>r||l>tr[now].r||r<tr[now].l) return 0;
	if(l<=tr[now].l&&r>=tr[now].r) return tr[now].c;
	pushdown(now);
	int mid=(tr[now].l+tr[now].r)>>1;
	// A range straddling mid needs both halves, and every path through the
	// function must return a value.
	int res=0;
	if(l<=mid) res+=query(lc(now), l, r);
	if(r>mid) res+=query(rc(now), l, r);
	return res;
}
// Reads the tree (n vertices, n-1 edges) and q operations from stdin.
// An operation whose word starts with 'P' adds 1 to every edge on the path
// x..y; any other operation prints the current value of the edge joining x
// and y.
signed main() {
	n=read(), q=read();
	if(n<1) return 0;   // nothing to build; build(1, n) would never terminate
	for(int i=1,x,y;i<n;i++) {
		x=read(), y=read();
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs1(1, 0); dfs2(1, 1);
	build(1, n);
	while(q--) {
		char s[2];
		// Width 1 keeps the token inside s; stop when the input ends early.
		if(scanf("%1s", s)!=1) break;
		int x=read(), y=read();
		if(s[0]=='P') change(x, y);
		else {
			// The edge value lives at the child endpoint: make x the child.
			if(t[y].fa==x) swap(x, y);
			printf("%lld\n", query(1, t[x].dfn, t[x].dfn));
		}
	}
	return 0;
}
2024/11/20 21:57
加载中...