求助捉虫-萎掉了的DP
查看原帖
求助捉虫-萎掉了的DP
93701
Morgen_Kornblume楼主2021/7/8 16:17

题意:

给定 1n1041\le n \le 10^4 个整数(1ai1061\le a_i \le 10^6),把他们划分成 1m1001\le m \le 100 个连续的段,每一段的权值记为在本段中出现的不同数字的个数,求能划分出的最大各段权值和。

我的思路:四边形不等式优化区间DP。

代码理论复杂度:O(nmlog(n))O ( n \cdot m \cdot log(n) )

但是遭遇了大面积TLE,求捉虫

代码如下:

#include<iostream>
#include<algorithm>
#include<set>
#include<map>
#pragma GCC optimize(3)
#include<cstring>
using namespace std;

int n,m;
int a[10010];
set<int>uni;
map<int,int>ord;
int w[10010];

struct segement_tree{
	
	int dat[100010];	
	int tag[100010];
	
	void build(int pos,int l,int r){
		tag[pos]=0;
		if(l==r){
			dat[pos]=w[l];
			return;
		}
		int mid=(l+r)>>1;
		build(pos*2,l,mid);
		build(pos*2+1,mid+1,r);
		dat[pos]=dat[pos*2]+dat[pos*2+1];
	}
	
	int query(int pos,int l,int r,int tar){
		if(tag[pos]){
			dat[pos]+=tag[pos]*(r-l+1);
			if(l!=r){
				tag[pos*2]+=tag[pos];
				tag[pos*2+1]+=tag[pos];
			}
			tag[pos]=0;
		}
		if(l==r)return dat[pos];
		int mid=(l+r)>>1;
		if(tar<=mid){
			return query(pos*2,l,mid,tar);
		}
		else{
			return query(pos*2+1,mid+1,r,tar);
		}
	}
	
	void add(int pos,int l,int r,int tl,int tr,int val){
		if(l==tl&&r==tr){
			tag[pos]+=val;
			return;
		}
		dat[pos]+=(tr-tl+1)*val;
		int mid=(l+r)>>1;
		if(tl<=mid){
			add(pos*2,l,mid,tl,min(mid,tr),val);
		}
		if(tr>mid){
			add(pos*2+1,mid+1,r,max(tl,mid+1),tr,val);
		}
	}
	
}sg;

int f[10010],bst[10010];
int last[10010];
int pre[10010];
int occur[10010];

int main(){
	//freopen("dat.in","r",stdin);
	//freopen("ST.out","w",stdout);
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	cin>>n>>m;
	
	for(int i=1;i<=n;i++){
		cin>>a[i];
		if(uni.find(a[i])==uni.end()){
			uni.insert(a[i]);
		}
	}
	
	int tot=0;
	for(int tmp:uni){
		ord[tmp]=++tot;
	}
	
	for(int i=1;i<=n;i++){
		a[i]=ord[a[i]];
		pre[i]=last[a[i]];
		last[a[i]]=i;
	}
	memset(bst,0,sizeof(bst));
	for(int sgt=1;sgt<=m;sgt++){
		memset(occur,0,sizeof(occur));
		memset(w,0,sizeof(w));
		for(int i=n;i>=sgt;i--){
			if(!occur[a[i]]){
				w[i]=w[i+1]+1;
			}
			else w[i]=w[i+1];
			occur[a[i]]++;
		}
		sg.build(1,1,n);
		bst[n+1]=n;
		for(int i=n;i>=sgt;i--){
			int maxx=0,bt=0;
			for(int k=bst[i+1];k>=sgt&&k>=bst[i];k--){
				int wt=sg.query(1,1,n,k);//single point query
				if(f[k-1]+wt>maxx){
					maxx=f[k-1]+wt;
					bt=k;
				}
			}
			bst[i]=bt;
			f[i]=maxx;
			sg.add(1,1,n,pre[i]+1,i,-1);//segment revise
		}
	}
	
	cout<<f[n];
	
	return 0;
}
2021/7/8 16:17
加载中...