似乎多建了很多节点,但是不知道是哪里出了锅 QAQ
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
struct node
{
int lson,rson;
int sum_;
} tree[(N<<2)+N*17];
int root[N],tot=0;//版本号用节点编号
int head[N],ver[N<<1],nxt[N<<1],_tot=0;
void add(int x,int y)
{
ver[++_tot]=y;
nxt[_tot]=head[x];
head[x]=_tot;
}
int n,m;
int arr[N];
int dpt[N],id[N],size_[N],top[N],cnt=0;//树剖相关
int fa[N],son[N];
int MAX;
vector<int> nums;
int find(int x)
{
return lower_bound(nums.begin(),nums.end(),x)-nums.begin();
}
#define lnode tree[node].lson
#define rnode tree[node].rson
int build(int start,int end)
{
int node=++tot;
tree[node].sum_=0;
if(start==end) return node;
int mid=start+end>>1;
lnode=build(start,mid);
rnode=build(mid+1,end);
return node;
}
#define lnode1 tree[node1].lson
#define rnode1 tree[node1].rson
int insert(int node,int start,int end,int x)
{
int node1=++tot;
tree[node1]=tree[node];
if(start==end)
{
tree[node1].sum_++;
return node1;
}
int mid=start+end>>1;
if(x<=mid) lnode1=insert(lnode,start,mid,x);
else rnode1=insert(rnode,mid+1,end,x);
tree[node1].sum_=tree[lnode1].sum_+tree[rnode1].sum_;
return node1;
}
#define lnode2 tree[node2].lson
#define rnode2 tree[node2].rson
#define lnode3 tree[node3].lson
#define rnode3 tree[node3].rson
int query(int node,int node1,int node2,int node3,int start,int end,int k)
{
if(start==end) return start;
int mid=start+end>>1;
int tmp=tree[node].sum_+tree[node1].sum_-tree[node2].sum_-tree[node3].sum_;//树上差分
if(k<=tmp) return query(lnode,lnode1,lnode2,lnode3,start,mid,k);
else return query(rnode,rnode1,rnode2,rnode3,mid+1,end,k-tmp);
}
void dfs1(int x,int f)
{
root[x]=insert(root[f],0,MAX,find(arr[x]));
dpt[x]=dpt[f]+1;
fa[x]=f;
size_[x]=1;
for(int i=head[x]; i; i=nxt[i])
{
int y=ver[i];
if(y==f) continue;
dfs1(y,x);
size_[x]+=size_[y];
if(size_[son[x]]<size_[y]) son[x]=y;
}
}
void dfs2(int x,int t)
{
id[x]=++cnt;
top[x]=t;
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x]; i; i=nxt[i])
{
int y=ver[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dpt[top[x]]<dpt[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dpt[x]>dpt[y]) swap(x,y);
return x;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1; i<=n; i++)
{
scanf("%d",&arr[i]);
nums.push_back(arr[i]);
}
sort(nums.begin(),nums.end());
nums.erase(unique(nums.begin(),nums.end()),nums.end());
for(int i=1; i<n; i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
MAX=nums.size()-1;
root[0]=build(0,MAX);
dfs1(1,0);
dfs2(1,1);
int last=0;
for(int i=1; i<=m; i++)
{
int x,y,k;
scanf("%d%d%d",&x,&y,&k);
int LCA=lca(x,y);
last=nums[query(root[x^last],root[y],root[LCA],root[fa[LCA]],0,MAX,k)];
printf("%d\n",last);
}
return 0;
}