RT,这个题写了一天了。通过的点不到0.2s就能过,而且按 Loj 数据规模来看,基本上都是 200000。但是就是无法通过本题。
#include<bits/stdc++.h>
#define pb push_back
#define N 1000005
#define INF 0X3F3F3F3F
using namespace std;
int n,m,L,R;
int val[N],head[N],idx;
struct edge{
int v,w,next;
}e[N<<1];
struct node{
int x,c,d,mxd;
}stk[N];
int top;
int ans=-INF;
bool cmp1(node x,node y){
if(x.c!=y.c) return x.c<y.c;
return x.d<y.d;
}
bool cmp2(node x,node y){
if(x.mxd!=y.mxd) return x.mxd<y.mxd;
if(x.c!=y.c) return x.c<y.c;
return x.d<y.d;
}
void add(int u,int v,int w){
e[++idx].v=v;
e[idx].w=w;
e[idx].next=head[u];
head[u]=idx;
}
bool vis[N];
int maxn[N],root,size[N],tot;
void getroot(int x,int f){
size[x]=maxn[x]=1;
for(int i=head[x];i;i=e[i].next){
int y=e[i].v;
if(y==f||vis[y]) continue;
getroot(y,x);
size[x]+=size[y];
maxn[x]=max(maxn[x],size[y]);
}
maxn[x]=max(maxn[x],tot-size[x]);
if(maxn[x]<maxn[root]) root=x;
}
int getdeep(int x,int f,int dep){
int maxd=dep;
for(int i=head[x];i;i=e[i].next){
int y=e[i].v;
if(y==f||vis[y]) continue;
maxd=max(maxd,getdeep(y,x,dep+1));
}
return maxd;
}
vector<int> v;
int dep[N],sum[N];
struct Node{
int x,f,pre;
};
void bfs(int x,int f,int pre){
queue<Node> q;
q.push((Node){x,f,pre});
while(!q.empty()){
int x=q.front().x,fa=q.front().f,pre=q.front().pre;
q.pop();
v.pb(x);
for(int i=head[x];i;i=e[i].next){
int y=e[i].v,z=e[i].w;
if(vis[y]||y==fa) continue;
if(z!=pre) sum[y]=sum[x]+val[z];
else sum[y]=sum[x];
dep[y]=dep[x]+1;
q.push((Node){y,x,z});
}
}
}
int same[N],dif[N],len1,len2;
int q[N];
void calc1(){
reverse(v.begin(),v.end());
for(int i=1;i<=len1;i++) q[i]=0;
int k=0,hd=1,tl=0;
for(int i=0;i<v.size();i++){
int ql=max(0,L-dep[v[i]]),qr=min(len1,R-dep[v[i]]);
if(ql>len1) break;
while(k<=qr){
while(dif[k]>dif[q[tl]]&&tl) tl--;
q[++tl]=k;
k++;
}
while(q[hd]<ql&&hd<=tl) hd++;
if(tl>=hd){
ans=max(ans,dif[q[hd]]+sum[v[i]]);
}
}
}
void calc2(int w){
for(int i=1;i<=len2;i++) q[i]=0;
int k=1,hd=1,tl=0;
for(int i=0;i<v.size();i++){
int ql=max(1,L-dep[v[i]]),qr=min(len2,R-dep[v[i]]);
if(ql>len2) break;
while(k<=qr){
while(same[k]>same[q[tl]]&&tl) tl--;
q[++tl]=k;
k++;
}
while(q[hd]<ql&&hd<=tl) hd++;
if(tl>=hd){
ans=max(ans,same[q[hd]]+sum[v[i]]-w);
}
}
}
void update(){
for(int i=1;i<=max(len1,len2);i++) dif[i]=max(dif[i],same[i]);
len1=max(len1,len2);
for(int i=1;i<=len2;i++) same[i]=-INF;
len2=0;
}
void solve(int x){
top=0;
for(int i=head[x];i;i=e[i].next){
int y=e[i].v,z=e[i].w;
if(vis[y]) continue;
stk[++top]=(node){y,z,getdeep(y,x,1),0};
}
sort(stk+1,stk+top+1,cmp1);
stk[top].mxd=stk[top].d;
for(int i=top-1;i>=1;i--){
if(stk[i].c==stk[i+1].c) stk[i].mxd=stk[i+1].mxd;
else stk[i].mxd=stk[i].d;
}
sort(stk+1,stk+top+1,cmp2);
for(int i=0;i<=stk[top].mxd;i++) same[i]=dif[i]=-INF;
dif[0]=0;
len1=0,len2=0;
for(int i=1;i<=top;i++){
if(stk[i].c!=stk[i-1].c) update();
int y=stk[i].x;
v.clear();
sum[y]=val[stk[i].c];
dep[y]=1;
bfs(y,x,stk[i].c);
calc1();
calc2(val[stk[i].c]);
for(int j=0;j<v.size();j++) same[dep[v[j]]]=max(same[dep[v[j]]],sum[v[j]]);
len2=max(len2,dep[v[0]]);
}
}
void divide(int x){
vis[x]=1;
solve(x);
for(int i=head[x];i;i=e[i].next){
int y=e[i].v;
if(vis[y]) continue;
tot=maxn[0]=size[y];
root=0;
getroot(y,x);
divide(root);
}
}
int main(){
//freopen("data.in","r",stdin);
//freopen("my.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&L,&R);
for(int i=1;i<=m;i++) scanf("%d",&val[i]);
for(int i=1;i<n;i++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
maxn[0]=tot=n;
root=0;
getroot(1,0);
divide(root);
printf("%d\n",ans);
return 0;
}
/*
13 1 2 4
-43
9 6 1
13 5 1
3 5 1
10 5 1
1 4 1
9 1 1
5 8 1
9 3 1
2 3 1
9 11 1
8 7 1
3 12 1
*/