// Author's note (translated): only tests #2 and #11 were accepted; read the editorial but could not debug the rest.
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
// Capacity of the node pool tr[] and of the input array a[].
constexpr int N = 1000015;
// Not referenced by the code in this file.
constexpr int inf = 1000000007;
constexpr int mod = 1000000;
constexpr int I_LOVE_CCF = 1;
// Fixed seed, so node priorities (and thus tree shapes) are reproducible.
mt19937 rnd (114514);
// n = sequence length, m = number of operations (both read in main).
int n, m;
// root = current treap root; idx = number of nodes handed out (tr[0] is the empty node).
int root, idx;
// Global scratch ints; split_rk's parameters and several helpers' locals shadow these names.
int x, y, z;
// a[1..n]: the initial sequence (expected to hold 0/1 values).
int a[N];
// One treap node. tr[0] is the shared empty node: push_up reads its (all-zero)
// fields as the aggregates of an empty subtree.
struct FHQ {
    int l, r;        // left / right child index (0 = none)
    int size;        // number of nodes in this subtree
    int upd;         // pending assignment for the children: 1 -> all 1, -1 -> all 0, 0 -> none
    int fan;         // pending flip (0 <-> 1) for the children
    int sum;         // number of 1s in this subtree
    int l0, r0;      // longest run of 0s touching the left / right end of the subtree
    int l1, r1;      // longest run of 1s touching the left / right end of the subtree
    int val;         // this node's own element value
    int rd;          // random priority compared in merge
    int ans1, ans0;  // longest run of 1s / 0s anywhere in the subtree
} tr[N];
// Read one decimal integer from stdin, store it in n and return it.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If EOF is hit before any digit, the result is 0
// (previously this looped forever).
inline int read (int &n) {
    int x = 0, f = 1;
    // int, not char: getchar() may return EOF, and isdigit() is only defined
    // for EOF and values representable as unsigned char.
    int ch = getchar ();
    while (ch != EOF && ! isdigit (ch)) {
        if (ch == '-') f = -1;
        ch = getchar ();
    }
    while (ch != EOF && isdigit (ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar ();
    }
    n = x * f;
    return n;
}
int get_node (int val) {
tr[++ idx].val = val;
tr[idx].rd = rnd ();
tr[idx].size = 1;
tr[idx].sum = val;
tr[idx].l1 = tr[idx].r1 = tr[idx].ans1 = val;
tr[idx].l0 = tr[idx].r0 = tr[idx].ans0 = !val;
return idx;
}
void push_up (int u) {
tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + 1;
tr[u].sum = tr[tr[u].l].sum + tr[tr[u].r].sum + tr[u].val;
tr[u].l0 = tr[tr[u].l].l0, tr[u].l1 = tr[tr[u].l].l1;
tr[u].r0 = tr[tr[u].r].r0, tr[u].r1 = tr[tr[u].r].r1;
if (tr[u].val && tr[tr[u].l].size == tr[tr[u].l].sum)
tr[u].l1 = max (tr[u].l1, tr[tr[u].l].sum + tr[tr[u].r].l1 + 1);
if (tr[u].val && tr[tr[u].r].size == tr[tr[u].r].sum)
tr[u].r1 = max (tr[u].r1, tr[tr[u].r].sum + tr[tr[u].l].r1 + 1);
if (tr[u].val == 0 && tr[tr[u].l].sum == 0)
tr[u].l0 = max (tr[u].l0, tr[tr[u].l].size + tr[tr[u].r].l0 + 1);
if (tr[u].val == 0 && tr[tr[u].r].sum == 0)
tr[u].r0 = max (tr[u].r0, tr[tr[u].r].size + tr[tr[u].l].r0 + 1);
tr[u].ans0 = max (tr[tr[u].l].ans0, tr[tr[u].r].ans0);
tr[u].ans1 = max (tr[tr[u].l].ans1, tr[tr[u].r].ans1);
if (tr[u].val) tr[u].ans1 = max (tr[u].ans1, tr[tr[u].l].r1 + tr[tr[u].r].l1 + 1);
if (!tr[u].val) tr[u].ans0 = max (tr[u].ans0, tr[tr[u].l].r0 + tr[tr[u].r].l0 + 1);
}
void push_upd (int now, int k) {
tr[now].val = k;
if (k == 1) tr[now].sum = tr[now].size;
else tr[now].sum = 0;
tr[now].l1 = tr[now].r1 = tr[now].ans1 = tr[now].sum;
tr[now].l0 = tr[now].r0 = tr[now].ans0 = tr[now].size - tr[now].sum;
if (k == 1) tr[now].upd = 1;
else tr[now].upd = -1;
}
void push_fan (int now) {
tr[now].sum = tr[now].size - tr[now].sum;
tr[now].val ^= 1;
swap (tr[now].l1, tr[now].l0);
swap (tr[now].r1, tr[now].r0);
swap (tr[now].ans1, tr[now].ans0);
tr[now].fan ^= 1;
tr[now].upd = -tr[now].upd;
}
void push_down (int now) {
if (tr[now].upd) {
int x = tr[now].upd == 1 ? 1 : 0;
if (tr[now].l) {
push_upd (tr[now].l, x);
}
if (tr[now].r) {
push_upd (tr[now].r, x);
}
tr[now].upd = 0;
}
if (tr[now].fan) {
if (tr[now].l) {
push_fan (tr[now].l);
}
if (tr[now].r) {
push_fan (tr[now].r);
}
tr[now].fan = 0;
}
}
// Split the subtree rooted at `now` by rank: its first k elements go to x and
// the rest go to y (either may come back as 0). Pending tags are pushed down
// along the split path and every node on it is re-aggregated.
void split_rk (int now, int k, int &x, int &y) {
    if (now == 0) {
        x = 0;
        y = 0;
        return ;
    }
    push_down (now);
    const int left_size = tr[tr[now].l].size;
    if (k > left_size) {
        // `now` and its whole left subtree belong to the first part.
        x = now;
        split_rk (tr[now].r, k - left_size - 1, tr[now].r, y);
    } else {
        // `now` and its whole right subtree belong to the second part.
        y = now;
        split_rk (tr[now].l, k, x, tr[now].l);
    }
    push_up (now);
}
// Concatenate treaps u and v (every element of u precedes every element of v)
// and return the new root. The root with the smaller priority `rd` wins.
// Only the node that becomes the root at each level has a child pointer
// rewritten, so only it needs its pending tags pushed down first. The other
// tree keeps its tags: its own aggregates are already correct.
int merge (int u, int v) {
    if (!u || !v) return u | v;
    if (tr[u].rd < tr[v].rd) {
        push_down (u);
        tr[u].r = merge (tr[u].r, v);
        push_up (u);
        return u;
    }
    push_down (v);
    tr[v].l = merge (u, tr[v].l);
    push_up (v);
    return v;
}
int build (int l, int r) {
if (l > r) return 0;
int mid = l + r >> 1;
int u = get_node (a[mid]);
tr[u].l = build (l, mid - 1);
tr[u].r = build (mid + 1, r);
push_up (u);
return u;
}
// Assign f (0 or 1) to every element at ranks l..r.
// Uses its own locals for the split pieces, like the other range helpers,
// instead of overwriting the global x / y / z.
void update01 (int l, int r, int f) {
    int lo, mid, hi;
    split_rk (root, r, lo, hi);      // lo = ranks [1, r], hi = the rest
    split_rk (lo, l - 1, lo, mid);   // lo = ranks [1, l - 1], mid = [l, r]
    if (mid) push_upd (mid, f);      // never tag the shared empty node 0
    root = merge (merge (lo, mid), hi);
}
// Flip (0 <-> 1) every element at ranks l..r.
void updatefan (int l, int r) {
    int before, range, after;
    split_rk (root, r, before, after);      // before = ranks [1, r]
    split_rk (before, l - 1, before, range); // range = ranks [l, r]
    push_fan (range);
    root = merge (merge (before, range), after);
}
// Number of 1s among ranks l..r. The tree is reassembled before returning.
int query_sum (int l, int r) {
    int pre, seg, suf;
    split_rk (root, r, pre, suf);
    split_rk (pre, l - 1, pre, seg);
    const int ones = tr[seg].sum;
    root = merge (merge (pre, seg), suf);
    return ones;
}
// Length of the longest run of consecutive 1s among ranks l..r.
int query_ans (int l, int r) {
    int head, body, tail;
    split_rk (root, r, head, tail);
    split_rk (head, l - 1, head, body);
    const int best = tr[body].ans1;
    root = merge (merge (head, body), tail);
    return best;
}
signed main () {
read (n), read (m);
rep (i, 1, n) read (a[i]);
root = build (1, n);
while (m --) {
int op, l, r;
read (op), read (l), read (r);
++ l, ++ r;
if (op == 0) update01 (l, r, 0);
else if (op == 1) update01 (l, r, 1);
else if (op == 2) updatefan (l, r);
else if (op == 3) printf ("%d\n", query_sum (l, r));
else printf ("%d\n", query_ans (l, r));
push_down (root);
}
return 0;
}