求助卡常
查看原帖
求助卡常
128870
chen_qian楼主2021/5/3 18:02

RT,这个题写了一天了。通过的点不到0.2s就能过,而且按 Loj 数据规模来看,基本上都是 200000。但是就是无法通过本题。

#include<bits/stdc++.h>
#define pb push_back 
#define N 1000005 
#define INF 0X3F3F3F3F
using namespace std;
int n,m,L,R;
int val[N],head[N],idx;
struct edge{
	int v,w,next;
}e[N<<1];
struct node{
	int x,c,d,mxd;
}stk[N];
int top;
int ans=-INF;
bool cmp1(node x,node y){
	if(x.c!=y.c) return x.c<y.c;
	return x.d<y.d;
}
bool cmp2(node x,node y){
	if(x.mxd!=y.mxd) return x.mxd<y.mxd;
	if(x.c!=y.c) return x.c<y.c;
	return x.d<y.d;
}
void add(int u,int v,int w){
	e[++idx].v=v;
	e[idx].w=w;
	e[idx].next=head[u];
	head[u]=idx;
} 
bool vis[N];
int maxn[N],root,size[N],tot;
void getroot(int x,int f){
	size[x]=maxn[x]=1;
	for(int i=head[x];i;i=e[i].next){
		int y=e[i].v;
		if(y==f||vis[y]) continue;
		getroot(y,x);
		size[x]+=size[y];
		maxn[x]=max(maxn[x],size[y]);
	}
	maxn[x]=max(maxn[x],tot-size[x]);
	if(maxn[x]<maxn[root]) root=x;
}
int getdeep(int x,int f,int dep){
	int maxd=dep;
	for(int i=head[x];i;i=e[i].next){
		int y=e[i].v;
		if(y==f||vis[y]) continue;
		maxd=max(maxd,getdeep(y,x,dep+1));
	}
	return maxd;
}
vector<int> v;
int dep[N],sum[N];
struct Node{
	int x,f,pre;
};
void bfs(int x,int f,int pre){
	queue<Node> q;
	q.push((Node){x,f,pre});
	while(!q.empty()){
		int x=q.front().x,fa=q.front().f,pre=q.front().pre;
		q.pop();
		v.pb(x);
		for(int i=head[x];i;i=e[i].next){
			int y=e[i].v,z=e[i].w;
			if(vis[y]||y==fa) continue;
			if(z!=pre) sum[y]=sum[x]+val[z];
			else sum[y]=sum[x];
			dep[y]=dep[x]+1;
			q.push((Node){y,x,z});
		}
	}
}
int same[N],dif[N],len1,len2;
int q[N];
void calc1(){
	reverse(v.begin(),v.end());
	for(int i=1;i<=len1;i++) q[i]=0;
	int k=0,hd=1,tl=0;
	for(int i=0;i<v.size();i++){
		int ql=max(0,L-dep[v[i]]),qr=min(len1,R-dep[v[i]]);
		if(ql>len1) break;
		while(k<=qr){
			while(dif[k]>dif[q[tl]]&&tl) tl--;
			q[++tl]=k;
			k++;	
		}
		while(q[hd]<ql&&hd<=tl) hd++;
		if(tl>=hd){
			ans=max(ans,dif[q[hd]]+sum[v[i]]);
		}
	} 
}
void calc2(int w){
	for(int i=1;i<=len2;i++) q[i]=0;
	int k=1,hd=1,tl=0;
	for(int i=0;i<v.size();i++){
		int ql=max(1,L-dep[v[i]]),qr=min(len2,R-dep[v[i]]);
		if(ql>len2) break;
		while(k<=qr){
			while(same[k]>same[q[tl]]&&tl) tl--;
			q[++tl]=k;
			k++;	
		}
		while(q[hd]<ql&&hd<=tl) hd++;
		if(tl>=hd){
			ans=max(ans,same[q[hd]]+sum[v[i]]-w);
		}
	} 
}
void update(){
	for(int i=1;i<=max(len1,len2);i++) dif[i]=max(dif[i],same[i]);
	len1=max(len1,len2);
	for(int i=1;i<=len2;i++) same[i]=-INF; 
	len2=0;
} 
void solve(int x){
	top=0;
	for(int i=head[x];i;i=e[i].next){
		int y=e[i].v,z=e[i].w;
		if(vis[y]) continue;
		stk[++top]=(node){y,z,getdeep(y,x,1),0}; 
	}
	sort(stk+1,stk+top+1,cmp1);
	stk[top].mxd=stk[top].d;
	for(int i=top-1;i>=1;i--){
		if(stk[i].c==stk[i+1].c) stk[i].mxd=stk[i+1].mxd;
		else stk[i].mxd=stk[i].d;
	} 
	sort(stk+1,stk+top+1,cmp2);
	for(int i=0;i<=stk[top].mxd;i++) same[i]=dif[i]=-INF;
	dif[0]=0;
	len1=0,len2=0;
	for(int i=1;i<=top;i++){
		if(stk[i].c!=stk[i-1].c) update();
		int y=stk[i].x;
		v.clear();
		sum[y]=val[stk[i].c];
		dep[y]=1;
		bfs(y,x,stk[i].c);
		calc1();
		calc2(val[stk[i].c]);
		for(int j=0;j<v.size();j++) same[dep[v[j]]]=max(same[dep[v[j]]],sum[v[j]]); 
		len2=max(len2,dep[v[0]]);
	}
}
void divide(int x){
	vis[x]=1;
	solve(x);
	for(int i=head[x];i;i=e[i].next){
		int y=e[i].v;
		if(vis[y]) continue;
		tot=maxn[0]=size[y];
		root=0;
		getroot(y,x);
		divide(root);
	}
}
int main(){
	//freopen("data.in","r",stdin);
	//freopen("my.out","w",stdout);
	scanf("%d%d%d%d",&n,&m,&L,&R);	
	for(int i=1;i<=m;i++) scanf("%d",&val[i]);
	for(int i=1;i<n;i++){
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		add(u,v,w);
		add(v,u,w);
	}
	maxn[0]=tot=n;
	root=0;
	getroot(1,0);
	divide(root);
	printf("%d\n",ans);
	return 0;
}
/*
13 1 2 4
-43 
9 6 1
13 5 1
3 5 1
10 5 1
1 4 1
9 1 1
5 8 1
9 3 1
2 3 1
9 11 1
8 7 1
3 12 1
*/

2021/5/3 18:02
加载中...