萌新求助,RE On Test 10
查看原帖
萌新求助,RE On Test 10
310802
_Diu_楼主2022/1/25 16:15

rt

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Global limits and constants.
//   N   - size of the per-vertex / per-dfn-position arrays below.
//   M   - capacity of the persistent segment-tree node pool tr[].
//   lim - main() rebuilds every version from scratch (resetting tot to 0) once
//         more than lim pool nodes are in use; M-lim is the headroom left for
//         the single Add() performed when tot<=lim.
//   mod - 2^30-1; main() masks the previous answer with it before XOR-decoding
//         the next query's arguments.
const int N=2e5+10,M=2e7,lim=1.9e7,mod=(1ll<<30)-1;
// n = vertex count, m = query count, fa[v] = parent of v (fa[1] = 0),
// rt[i] = root of persistent version i (root paths of a[1..i] added),
// tot = number of pool nodes in use, last = value of tot when the current
// Add() started (nodes with index <= last belong to older versions and are
// copied before modification), a = the current permutation.
int n,m,fa[N],rt[N],tot,last,a[N];
// sum_dep[i] = dep[a[1]] + ... + dep[a[i]];
// sum[k]     = prefix sum of val[] over dfn order (positions 1..k);
// val[v]     = weight of the edge (fa[v], v), 0 for vertex 1;
// dep[v]     = weighted distance from vertex 1 to v.
ll sum_dep[N],sum[N],val[N],dep[N];
// Heavy-light decomposition: son = heavy child, sz = subtree size,
// top = head of v's heavy chain, dfn = position in dfs2 order, id = dfn counter.
int son[N],sz[N],top[N],dfn[N],id;
// Adjacency-list entry: neighbour v reached through an edge of weight w.
struct edge{int v,w;};
vector<edge> g[N];
// Node of the persistent segment tree over dfn positions (tags never pushed down).
//   ls, rs - child indices into tr; 0 means "untouched subtree" (tr[0] stays all-zero)
//   sum    - sum over the node's range of val[pos] * (number of updates recorded
//            at this node or below that cover pos)
//   tg     - number of updates that fully covered this node's range
// NOTE(review): with 24-byte nodes (typical 64-bit layout) tr[] occupies roughly
// 2e7*24 = 480 MB, and the initial build loop in main() never checks tot < M -
// exceeding M there would write past the array, a plausible cause of the RE.
struct tree{
	int ls,rs;
	ll sum,tg;
}tr[M];
void dfs(int u){
	sz[u]=1;
	for(int i=0;i<g[u].size();i++){
		int v=g[u][i].v;
		if(v==fa[u])continue;
		dep[v]=dep[u]+g[u][i].w;
		fa[v]=u,val[v]=g[u][i].w;
		dfs(v);
		sz[u]+=sz[v];
		if(sz[v]>sz[son[u]])son[u]=v;
	}
}
void dfs2(int u,int head){
	top[u]=head,dfn[u]=++id,sum[id]=sum[id-1]+val[u];
	if(son[u])dfs2(son[u],head);
	for(int i=0;i<g[u].size();i++){
		int v=g[u][i].v;
		if(v==son[u]||v==fa[u])continue;
		dfs2(v,v);
	}  
}
void update(int &p,int l,int r,int x,int y){
	if(p<=last)tr[++tot]=tr[p],p=tot;
	if(x<=l&&r<=y)return void((tr[p].sum+=sum[r]-sum[l-1])+(tr[p].tg++));
	int mid=l+r>>1;
	if(x<=mid)update(tr[p].ls,l,mid,x,y);
	if(y>mid)update(tr[p].rs,mid+1,r,x,y);
	tr[p].sum=tr[tr[p].ls].sum+tr[tr[p].rs].sum+(sum[r]-sum[l-1])*tr[p].tg;
}
// Returns the sum, over dfn positions in [x,y] ∩ [l,r], of val[pos] times the
// number of updates covering pos in the version containing node p.
// Tags are never pushed down, so cnt carries the tg values accumulated on p's
// ancestors. p == 0 denotes an untouched subtree, where only those inherited
// tags contribute.
ll query(int p,int l,int r,int x,int y,ll cnt){
	if(!p)return cnt*(sum[min(r,y)]-sum[max(l,x)-1]);
	if(x<=l&&r<=y)return tr[p].sum+cnt*(sum[r]-sum[l-1]);
	const int mid=(l+r)>>1;
	const ll inherited=cnt+tr[p].tg;   // tags seen by both children
	ll res=0;
	if(x<=mid)res+=query(tr[p].ls,l,mid,x,y,inherited);
	if(y>mid)res+=query(tr[p].rs,mid+1,r,x,y,inherited);
	return res;
}
// Builds persistent version i as a copy of version i-1 plus one coverage of
// every vertex on the path from u up to vertex 1. Each position's val[] is the
// weight of the edge to its parent. The path is processed one heavy chain at a
// time: [dfn[top], dfn[cur]].
void Add(int u,int i){
	last=tot;          // nodes allocated from here on are private to version i
	rt[i]=rt[i-1];
	for(int cur=u;cur;cur=fa[top[cur]])
		update(rt[i],1,n,dfn[top[cur]],dfn[cur]);
}
// Returns the sum of dist(a[i], u) over i = 1..x, using version x.
// It relies on dist(a,u) = dep[a] + dep[u] - 2*dep[lca(a,u)]:
//  - the dep[a[i]] terms come from sum_dep[x];
//  - the dep[u] terms contribute x*dep[u];
//  - the lca terms equal the weighted coverage of the u -> root path in version x,
//    gathered chain by chain.
ll Sum(int x,int u){
	ll total=sum_dep[x]+dep[u]*(ll)x;
	for(int cur=u;cur;cur=fa[top[cur]])
		total-=2*query(rt[x],1,n,dfn[top[cur]],dfn[cur],0);
	return total;
}
signed main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	for(int i=1,u,v,w;i<n;i++){
		scanf("%d%d%d",&u,&v,&w);
		g[u].push_back({v,w});
		g[v].push_back({u,w});
	}
	dfs(1),dfs2(1,1);
	for(int i=1;i<=n;i++)sum_dep[i]=sum_dep[i-1]+dep[a[i]];
	for(int i=1;i<=n;i++)Add(a[i],i);
	ll lst=0;
	for(int op,x,y,z;m--;){
		scanf("%d%d",&op,&x),x^=lst;
		if(op^2){
			scanf("%d%d",&y,&z),y^=lst,z^=lst;
			lst=Sum(y,z)-Sum(x-1,z);
			printf("%lld\n",lst);
			lst&=mod;
		}else{
			swap(a[x],a[x+1]);
			sum_dep[x]=sum_dep[x-1]+dep[a[x]];
			if(tot<=lim)Add(a[x],x);
			else{
				tot=0;
				for(int i=1;i<=n;i++)Add(a[i],i);
			}
		}
	}
}
2022/1/25 16:15
加载中...