RT 这是我的代码
#include <bits/stdc++.h>
#define N (int)(3e5+5)
#define ls (cur<<1)
#define rs (ls|1)
using namespace std;
struct node {
int x,y,val,id;
}a[N<<1];
struct edge {
int nex,to;
}e[N<<1];
bool vis[N];
int head[N],cnt,tot,fa[N],top[N],dep[N],son[N],id[N],sz[N];
int n,m,pre[N],lsh[N],ans[N],l_cnt;
int rd() {
int f=1,sum=0; char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)) {sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return sum*f;
}
void add(int x,int y) {
e[++cnt]=edge{head[x],y};
head[x]=cnt;
}
void dfs1(int x,int faa) {
dep[x]=dep[faa]+1; fa[x]=faa; sz[x]=1;
for(int i=head[x];i;i=e[i].nex) {
int y=e[i].to;
if(y==faa) continue;
dfs1(y,x); sz[x]+=sz[y];
if(sz[y]>sz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int tp) {
top[x]=tp; id[x]=++tot;
if(son[x]) dfs2(son[x],tp);
for(int i=head[x];i;i=e[i].nex) {
int y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int sum[N<<2];
void add(int cur,int l,int r,int x,int val) {
if(l==r) {
sum[cur]+=val;
return;
}
int mid=(l+r)>>1;
if(x<=mid) add(ls,l,mid,x,val);
else add(rs,mid+1,r,x,val);
sum[cur]=sum[ls]+sum[rs];
}
int query(int cur,int l,int r,int cl,int cr) {
if(cl>cr) return 0;
if(cl<=l&&r<=cr) return sum[cur];
int mid=(l+r)>>1,res=0;
if(cl<=mid) res=query(ls,l,mid,cl,cr);
if(cr>mid) res+=query(rs,mid+1,r,cl,cr);
return res;
}
int qry(int x,int y) {
int res=0;
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) swap(x,y);//cout<<x<<" "<<y<<endl;
res+=query(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
return res+query(1,1,n,id[x],id[y]);
}
node q1[N],q2[N];
void solve(int L,int R,int l,int r) {
if(l>r) return;
if(L==R) {
for(int i=l;i<=r;i++) {
if(a[i].id==-1) continue;
ans[a[i].id]=lsh[L];
// cout<<lsh[L]<<" "<<a[i].val<<endl;
}
//cout<<L<<" "<<l<<" "<<r<<endl;
return;
}
int mid=(L+R)>>1,cnt1=0,cnt2=0;
for(int i=l;i<=r;i++) {
if(a[i].id==-1) {
if(a[i].val<=mid) q1[++cnt1]=a[i];
else q2[++cnt2]=a[i],add(1,1,n,id[a[i].x],a[i].y);
} else {
int res=qry(a[i].x,a[i].y);
if(res>=a[i].val) q2[++cnt2]=a[i];
else a[i].val-=res,q1[++cnt1]=a[i];
}
}
for(int i=1;i<=cnt2;i++) if(q2[i].id==-1) add(1,1,n,id[q2[i].x],-q2[i].y);
for(int i=1;i<=cnt1;i++) a[i+l-1]=q1[i];
for(int i=1;i<=cnt2;i++) a[i+l+cnt1-1]=q2[i];
solve(L,mid,l,l+cnt1-1); solve(mid+1,R,l+cnt1,r);
}
int main() {
n=rd(); m=rd();
int tot1=0,x,y;
for(int i=1;i<=n;i++) pre[i]=rd(),a[++tot1]=node{i,1,pre[i],-1},lsh[++l_cnt]=pre[i];
for(int i=1;i<n;i++) x=rd(),y=rd(),add(x,y),add(y,x);
dfs1(1,0); dfs2(1,1);
for(int i=1;i<=m;i++) {
x=rd();
if(x==0) {
x=rd(); y=rd();
a[++tot1]=node{x,-1,pre[x],-1};
a[++tot1]=node{x,1,y,-1};
lsh[++l_cnt]=pre[x]=y;
} else {
a[++tot1]=node{rd(),rd(),x,i};
vis[i]=1;
}
}
sort(lsh+1,lsh+1+l_cnt);// l_cnt=unique(lsh+1,lsh+1+l_cnt)-lsh-1;
for(int i=1;i<=tot1;i++) {
if(a[i].id==-1) a[i].val=lower_bound(lsh+1,lsh+1+l_cnt,a[i].val)-lsh;
}
solve(0,l_cnt,1,tot1);
for(int i=1;i<=m;i++) {
if(vis[i]) ans[i]?printf("%d\n",ans[i]):puts("invalid request!");
}
return 0;
}
solve(0,l_cnt,1,tot1);
为什么调用值域范围为[0,l_cnt]而不是[1,l_cnt],离散化出来之后不是从1开始的吗? qwq