求大佬帮忙优化下
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10;
struct Node{
int l,r;
int sum;
int ltag,rtag;
int tag;
inline void cover(int k){
tag=ltag=rtag=k;
sum=1;
}
}tr[N<<3];
struct Edge{
int v,next;
}edge[N<<1];
int head[N],top[N],dep[N],fa[N],id[N],sz[N],son[N],cnt,w[N],nw[N],tot;
inline int read()
{
int x=0,t=1;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
inline void add(int u,int v){
edge[++tot]={v,head[u]};
head[u]=tot;
}
int n,m;
inline void dfs1(int u,int father ,int depth){
dep[u]=depth;sz[u]=1;fa[u]=father;
for(int i=head[u];i;i=edge[i].next){
int v=edge[i].v;
if(v!=father){
dfs1(v,u,depth+1);
if(sz[son[u]]<sz[v])son[u]=v;
}
}
}
inline void dfs2(int u,int topf){
top[u]=topf;id[u]=++cnt;nw[cnt]=w[u];
if(!son[u])return ;
dfs2(son[u],topf);
for(int i=head[u];i;i=edge[i].next){
int v=edge[i].v;
if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
}
inline void push_up(Node &u,Node l,Node r){
u.ltag=l.ltag;
u.rtag=r.rtag;
u.sum=l.sum+r.sum;
if(l.rtag==r.ltag)u.sum--;
}
inline void push_up(int u){
push_up(tr[u],tr[u<<1],tr[u<<1|1]);
}
inline void build(int u,int l,int r){
tr[u]={l,r};
if(l==r){
tr[u].cover(nw[l]);
return;
}
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
push_up(u);
}
inline void push_down(int u){
if(tr[u].tag){
tr[u<<1].cover(tr[u].tag);
tr[u<<1|1].cover(tr[u].tag);
}
tr[u].tag=0;
}
inline void change(int u,int l,int r,int k){
if(l<=tr[u].l&&tr[u].r<=r){
tr[u].cover(k);
return;
}
push_down(u);
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid)change(u<<1,l,r,k);
if(r>mid)change(u<<1|1,l,r,k);
push_up(u);
}
inline Node query(int u,int l,int r){
if(l<=tr[u].l&&tr[u].r<=r){
return tr[u];
}
push_down(u);
int mid=tr[u].l+tr[u].r>>1;
if(r<=mid)return query(u<<1,l,r);
if(l>mid)return query(u<<1|1,l,r);
Node t,t1,t2;
t1=query(u<<1,l,r);t2=query(u<<1|1,l,r);
push_up(t,t1,t2);
return t;
}
inline void change_range(int u,int v,int k){
while(top[u]^top[v]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
change(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]<dep[v])swap(u,v);
change(1,id[v],id[u],k);
}
inline int query_range(int u,int v){
int res=0,lco=-1,rco=-1;
while(top[u]^top[v]){
if(dep[top[u]]<dep[top[v]]){
swap(u,v);
swap(lco,rco);
}
Node t=query(1,id[top[u]],id[u]);
res+=t.sum;
if(lco==t.rtag){
res--;
}
lco=t.ltag;
u=fa[top[u]];
}
if(dep[u]<dep[v]){
swap(u,v);
swap(lco,rco);
}
Node t=query(1,id[v],id[u]);
res+=t.sum;
if(lco==t.rtag)res--;
if(rco==t.ltag)res--;
return res;
}
int main(){
scanf("%d%d",&n,&m);
for(register int i=1;i<=n;i++){
w[i]=read();
}
for(register int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);
add(v,u);
}
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
char op[15];
int u,v,k;
for(register int i=1;i<=m;i++){
scanf("%s%d%d",op,&u,&v);
if(*op=='Q'){
printf("%d\n",query_range(u,v));
}else{
scanf("%d",&k);
change_range(u,v,k);
}
}
}