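// Approach: divide and conquer. Recursively sum the answers for subarrays lying entirely inside
// the left half or the right half; for subarrays crossing the midpoint, sweep the right endpoint
// from m+1 to r and use a segment tree to maintain, for every start position in [l, m], the number
// of distinct values between that start and the current right endpoint.
// Original author's note: the result disagreed with the brute-force checker (answer too large).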
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
// Array capacity (presumably max n = 1e5 plus slack) and the modulus applied to all sums.
const int N = 1e5 + 5;
const int M = 1e9 + 7;
int n;          // number of input elements
int a[N];       // input values (1-based); main replaces them by their compressed ranks
int w[N];       // sorted copy of the input, used for coordinate compression in main
int s[N];       // s[v] == 1 once value v was seen (left-half scan in solve, brute force in main)
int mem[N];     // mem[v]: last position of v inside the left half [l, m] in solve, or l - 1
int s2[N];      // s2[v] == 1 once value v was seen while sweeping the right half in solve
long long ans;  // accumulated answer (reduced modulo M as it is built)
// Segment-tree node covering positions [l, r].
//   lazy: pending addition for every position in [l, r] that has NOT yet been
//         folded into this node's sum (pushdown folds it in and passes it on).
//   sum:  sum of the values in [l, r] modulo M, excluding pending lazies at
//         this node and above.
struct tree
{
    int l;
    int r;
    int lazy;
    long long sum;
};
// Heap-indexed nodes (root 1, children 2t and 2t+1). The array must be larger
// than the usual 4*N because pushdown also writes the lazy field of the child
// slots of leaves (indices up to about 8n; assumes n < N).
tree b[N * 8];
// Initialises node t and, recursively, its whole subtree so that it covers
// positions [l, r] with a zero sum and no pending lazy addition.
void buildtree(int t, int l, int r)
{
    b[t].l = l;
    b[t].r = r;
    b[t].lazy = 0;
    b[t].sum = 0;
    if(l != r)
    {
        int mid = (l + r) >> 1;
        buildtree(t << 1, l, mid);
        buildtree(t << 1 | 1, mid + 1, r);
    }
}
// Folds node t's pending lazy addition into its own sum (lazy * length, mod M)
// and adds the same pending value to both child slots' lazy fields. For a leaf
// those child slots are not part of the tree.
void pushdown(int t)
{
    const int add = b[t].lazy;
    if(add == 0)
        return;
    const long long len = b[t].r - b[t].l + 1;
    b[t].sum = (b[t].sum + add * len % M) % M;
    b[t << 1].lazy += add;
    b[t << 1 | 1].lazy += add;
    b[t].lazy = 0;
}
// Modifies positions [l, r] inside node t's subtree; an empty range (l > r) is a no-op.
//   opt == 0: add k to every position in [l, r] (recorded as a lazy value).
//   opt != 0: overwrite the sum of each fully covered node with k. This is only
//             a correct point assignment when l == r (the covered node is then a
//             leaf), which is how solve uses it. NOTE: for a wider range it would
//             set a whole node's sum to k.
void update(int t, int l, int r, int k, int opt)
{
    if(l > r)
        return;
    pushdown(t);
    const bool covered = l <= b[t].l && b[t].r <= r;
    if(covered)
    {
        if(opt != 0)
            b[t].sum = k;
        else
            b[t].lazy += k;
        return;
    }
    const int mid = (b[t].l + b[t].r) >> 1;
    if(mid >= l)
        update(t << 1, l, r, k, opt);
    if(mid < r)
        update(t << 1 | 1, l, r, k, opt);
    // Fold any lazy just recorded on the children into their sums before re-aggregating.
    pushdown(t << 1);
    pushdown(t << 1 | 1);
    b[t].sum = (b[t << 1].sum + b[t << 1 | 1].sum) % M;
}
int query(int t, int l, int r)
{
if(l > r)return 0;
pushdown(t);
if(b[t].l >= l && b[t].r <= r)
return b[t].sum;
int m = (b[t].l + b[t].r) >> 1;
ll ret = 0;
if(l <= m)ret += query(t * 2, l, r);
if(r > m)ret += query(t * 2 + 1, l, r);
return ret;
}
// Adds to ans (mod M) the sum, over every subarray a[x..y] with l <= x <= y <= r,
// of (number of distinct values in a[x..y])^2.
// Expects s, mem and s2 to be zero for all values occurring in a[l..r]; it resets
// them to zero before returning. Leaves [l, m] of the segment tree are overwritten.
void solve(int l, int r)
{
    if(l > r)
        return;
    if(l == r)
    {
        // A single element has exactly one distinct value: 1^2 = 1.
        ans = (ans + 1) % M;
        return;
    }
    int m = (l + r) >> 1;
    solve(l, m);
    solve(m + 1, r);

    // mem[v] := last position of v inside [l, m], or l - 1 if v does not occur there.
    for(int i = m + 1; i <= r; i++)
        mem[a[i]] = l - 1;
    for(int i = l; i <= m; i++)
        mem[a[i]] = i;

    // Leaf x (l <= x <= m) := number of distinct values in a[x..m];
    // ret := sum of the squares of those counts (mod M).
    int p = 0;
    long long ret = 0;
    for(int i = m; i >= l; i--)
    {
        if(!s[a[i]])
            p++;
        update(1, i, i, p, 1);
        ret = (ret + 1ll * p * p % M) % M;
        s[a[i]] = 1;
    }

    // Sweep the right endpoint y = i. After processing i, ret holds the sum of
    // squared distinct counts of a[x..i] over all starts x in [l, m].
    for(int i = m + 1; i <= r; i++)
    {
        // If a[i] already occurred in a[m+1..i-1], it is already counted for
        // every start x, so neither ret nor the tree may change. (Previously the
        // range increment below also ran in this case, inflating the tree and
        // making the answer too large.)
        if(!s2[a[i]])
        {
            int lo = mem[a[i]] + 1;
            // Starts x in [lo, m] gain one distinct value: (c + 1)^2 - c^2 = 2c + 1.
            ret = (ret + 2ll * query(1, lo, m) + (m - lo + 1)) % M;
            update(1, lo, m, 1, 0);
            s2[a[i]] = 1;
        }
        ans = (ans + ret) % M;
    }

    for(int i = l; i <= r; i++)
        s[a[i]] = mem[a[i]] = s2[a[i]] = 0;
}
// Reads n and a[1..n], replaces each value by its 1-based rank among the sorted values,
// and prints the sum over all subarrays of (number of distinct values)^2 modulo M.
// For n <= 1000 an O(n^2) brute force is used instead of the divide-and-conquer solve.
int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        w[i] = a[i];
    }
    // Coordinate compression so values can index the per-value flag arrays.
    std::sort(w + 1, w + n + 1);
    for(int i = 1; i <= n; i++)
        a[i] = std::lower_bound(w + 1, w + n + 1, a[i]) - w;

    if(n <= 1000)
    {
        // Brute force: fix the left end, extend the right end, and count distinct
        // values with the seen flags in s. The total is a long long that does not
        // shadow the global ans, so it matches the %lld conversion below.
        long long brute = 0;
        for(int l = 1; l <= n; l++)
        {
            for(int i = 1; i <= n; i++)
                s[i] = 0;
            int p = 0;
            for(int r = l; r <= n; r++)
            {
                if(!s[a[r]])
                    p++;
                brute = (brute + 1ll * p * p % M) % M;
                s[a[r]] = 1;
            }
        }
        printf("%lld\n", brute);
        return 0;
    }

    buildtree(1, 1, n);
    solve(1, n);
    printf("%lld\n", ans);
    return 0;
}