求助
查看原帖
求助
22930
Lolierl楼主2020/4/26 16:52

大致思路:分治求左右两区间内部和,然后跨过中点的区间,从m+1m + 1rr枚举右端点,用线段树维护[l,m][l, m]作为开头到当前这段区间的不同颜色个数,但是拍不上答案偏大

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm> 
#define ll long long
using namespace std; 

const int N = 1e5 + 5, M = 1e9 + 7; 
int n, a[N], w[N], s[N], mem[N], s2[N]; 
ll ans; 

struct tree
{
	int l, r, lazy; 
	ll sum; 
}b[N * 8]; 

void buildtree(int t, int l, int r)
{
	b[t].l = l; b[t].r = r; b[t].sum = b[t].lazy = 0; 
	if(l == r)return; 
	int m = (l + r) >> 1; 
	buildtree(t * 2, l, m); 
	buildtree(t * 2 + 1, m + 1, r); 
}

void pushdown(int t)
{
	int x = b[t].lazy; 
	if(x)
	{
		b[t].sum = (b[t].sum + 1ll * x * (b[t].r - b[t].l + 1) % M) % M;  
		b[t * 2].lazy += x; b[t * 2 + 1].lazy += x; 
		b[t].lazy = 0; 
	}
}

void update(int t, int l, int r, int k, int opt)
{
	if(l > r)return; 
	pushdown(t); 
	if(b[t].l >= l && b[t].r <= r)
	{
		if(opt == 0)b[t].lazy += k; else b[t].sum = k; 
		return; 
	}
	int m = (b[t].l + b[t].r) >> 1; 
	if(l <= m)update(t * 2, l, r, k, opt); 
	if(r > m)update(t * 2 + 1, l, r, k, opt); 
	pushdown(t * 2); pushdown(t * 2 + 1); 
	b[t].sum = (b[t * 2].sum + b[t * 2 + 1].sum) % M; 
}

int query(int t, int l, int r)
{
	if(l > r)return 0; 
	pushdown(t); 
	if(b[t].l >= l && b[t].r <= r)
		return b[t].sum; 
	
	int m = (b[t].l + b[t].r) >> 1;
	ll ret = 0; 
	if(l <= m)ret += query(t * 2, l, r); 
	if(r > m)ret += query(t * 2 + 1, l, r); 
	return ret; 
}
void solve(int l, int r)
{
	if(l > r)return; 
	if(l == r){ans++; return; }
	
	int m = (l + r) >> 1; 
	solve(l, m); solve(m + 1, r); 
	for(int i = m + 1; i <= r; i++)mem[a[i]] = l - 1; 
	for(int i = l; i <= m; i++)mem[a[i]] = i; 

	int p = 0; 
	ll ret = 0; 
	for(int i = m; i >= l; i--)
	{
		if(!s[a[i]])p++; 
		update(1, i, i, p, 1); 
		ret = (ret + 1ll * p * p % M) % M; 
		s[a[i]] = 1; 
	}
	for(int i = m + 1; i <= r; i++)
	{
		if(!s2[a[i]])ret = (ret + 2 * query(1, mem[a[i]] + 1, m) + m - mem[a[i]]) % M; 
		ans = (ans + ret) % M; 
		update(1, mem[a[i]] + 1, m, 1, 0); 
		s2[a[i]] = 1; 
	}
	for(int i = l; i <= r; i++)
		s[a[i]] = mem[a[i]] = s2[a[i]] = 0;  
}
int main()
{
	scanf("%d", &n); 
	for(int i = 1; i <= n; i++)
	{
		scanf("%d", &a[i]); 
		w[i] = a[i]; 
	}
	sort(w + 1, w + n + 1); 
	for(int i = 1; i <= n; i++)
		a[i] = lower_bound(w + 1, w + n + 1, a[i]) - w; 
	
	if(n <= 1000)
	{
		int ans = 0, p = 0; 
		for(int l = 1; l <= n; l++)
		{
			for(int i = 1; i <= n; i++)
				s[i] = 0; 
			p = 0; 
			for(int r = l; r <= n; r++)
			{
				if(!s[a[r]])p++; 
				ans = (ans + 1ll * p * p % M) % M; 
				s[a[r]] = 1; 
			}
		}
		printf("%lld\n", ans); 
		return 0; 
	}
	buildtree(1, 1, n); 
	solve(1, n); 
	printf("%lld\n", ans); 
	return 0; 
}
2020/4/26 16:52
加载中...