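// Idea: deleting an element removes tmp inversions, where
// tmp = (#elements deleted later than it, positioned after it, with a smaller value)
//     + (#elements deleted later than it, positioned before it, with a larger value).
// Status: the sample passes, but every submission gets WA.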
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer type used for the accumulated answers.
typedef long long ll;
// Array capacities: 100000 entries plus a little slack for 1-based indexing.
const int N = 100000 + 10;
const int Q = 100000 + 10;
// n: number of elements (also the upper index bound of the Fenwick tree),
// m: number of deletions; k is not used by the code visible in this file.
int n;
int k;
int m;
// Fenwick tree storage, indexed by element value.
int tr[N];
// dy[v] = position at which value v was read.
int dy[N];
// res[t] collects the contribution attributed to deletion time t.
ll res[Q];
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x){
    const int negated = -x;
    return x & negated;
}
// Fenwick-tree point update: adds val to index x by walking upward through
// every tree node that covers x, stopping once the index exceeds n.
// x must be at least 1: lowbit(0) is 0, so index 0 would never advance.
void insert(int x, int val){
    while(x <= n){
        tr[x] += val;
        x += lowbit(x);
    }
}
// Fenwick-tree prefix query: returns the sum of everything added at
// indices 1..x (0 when x is 0).
int query(int x){
    int total = 0;
    while(x){
        total += tr[x];
        x -= lowbit(x);
    }
    return total;
}
// One element of the input sequence, viewed as a point in three dimensions.
struct pos{
    int x;    // deletion time (1..m if deleted; greater than m if never deleted)
    int y;    // position in the original sequence
    int z;    // value of the element
    int val;  // weight added to the Fenwick tree for this element
};
// flo holds the elements being processed; tmp is the merge scratch buffer.
pos tmp[N], flo[N];
// Ordering for the first pass: x descending, then y ascending, then z descending.
bool cmp(pos a, pos b){
    if(a.x > b.x) return true;
    if(a.x < b.x) return false;
    if(a.y < b.y) return true;
    if(a.y > b.y) return false;
    return a.z > b.z;
}
// Ordering for the second pass: x descending, then y descending, then z ascending.
bool cmp2(pos a, pos b){
    if(a.x > b.x) return true;
    if(a.x < b.x) return false;
    if(a.y > b.y) return true;
    if(a.y < b.y) return false;
    return a.z < b.z;
}
void cdq(int L, int R){
if(L == R) return;
int mid = (L + R) >> 1;
cdq(L, mid); cdq(mid + 1, R);
int l = L, r = mid + 1;
int len = 0;
while(l <= mid || r <= R){
if(r > R || l <= mid && flo[l].y <= flo[r].y){
insert(flo[l].z, flo[l].val);
tmp[++len] = flo[l];
l++;
} else {
tmp[++len] = flo[r];
res[flo[r].x] += query(n) - query(flo[r].z);
r++;
}
}
for(int i = L; i <= mid; i++)
insert(flo[i].z, -flo[i].val);
for(int i = 1; i <= len; i++)
flo[L - 1 + i] = tmp[i];
return;
}
void cdq2(int L, int R){
if(L == R) return;
int mid = (L + R) >> 1;
cdq2(L, mid); cdq2(mid + 1, R);
int l = L, r = mid + 1;
int len = 0;
while(l <= mid || r <= R){
if(r > R || l <= mid && flo[l].y >= flo[r].y){
insert(flo[l].z, flo[l].val);
tmp[++len] = flo[l];
l++;
} else {
tmp[++len] = flo[r];
res[flo[r].x] += query(flo[r].z);
r++;
}
}
for(int i = L; i <= mid; i++)
insert(flo[i].z, -flo[i].val);
for(int i = 1; i <= len; i++)
flo[L - 1 + i] = tmp[i];
return;
}
// Reads n and m, then n values, then m values to delete in order. For each
// deletion, prints the number of inversions present just before it.
// Assumes the n values are distinct and lie in 1..n, and that the deleted
// values are distinct members of the sequence (TODO confirm against the
// problem statement).
int main(){
// freopen("ze.txt", "r", stdin);
// freopen("cdq.txt", "w", stdout);
    if(scanf("%d%d", &n, &m) != 2) return 0;
    for(int i = 1; i <= n; i++){
        scanf("%d", &flo[i].z); dy[flo[i].z] = i;
        // x == 0 means "no deletion time assigned yet".
        flo[i].x = 0, flo[i].y = i, flo[i].val = 1;
    }
    for(int i = 1, x; i <= m; i++)
        scanf("%d", &x), flo[dy[x]].x = i;

    // Give every never-deleted element its own timestamp after m.
    // If all survivors shared one timestamp, both passes would count each
    // inversion between two survivors (the tie-breaks order by position in
    // opposite directions), doubling that part of the answer.
    // With distinct timestamps each inversion is attributed exactly once,
    // to whichever endpoint has the smaller timestamp.
    int last = m;
    for(int i = 1; i <= n; i++)
        if(flo[i].x == 0) flo[i].x = ++last;

    // The recursion requires a non-empty range.
    if(n >= 1){
        sort(flo + 1, flo + n + 1, cmp); cdq(1, n);
        sort(flo + 1, flo + n + 1, cmp2); cdq2(1, n);
    }

    // res[t] holds the inversions removed at time t; the suffix sum turns it
    // into the number of inversions present just before time t.
    for(int i = last; i >= 1; i--) res[i] += res[i + 1];
    for(int i = 1; i <= m; i++) printf("%lld\n", res[i]);
    return 0;
}