已ac,但是对re发问
查看原帖
已ac,但是对re发问
831588
WanderFreeFish楼主2025/8/4 16:57

这代码 c++11 和 c++14 就会 RE,换成 c++17 就过了,调了两个小时,求解释。

#include <algorithm>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define left tr[root].ls
#define right tr[root].rs

const int MAXN = 5e4 + 10;
const int inf = 2147483647;

inline int read () {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch ^ 48);
        ch = getchar();
    }
    return x * f;
}
inline void write (int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

int a[MAXN];

struct AVLtree {
    struct node {
        int ls, rs, val, height, cnt, sz;
    };
    int tot, allroot;
    
    std::vector <node> tr;
    
    AVLtree () : tr(1) {
    	allroot = tot = 0;
    	tr[0] = {0, 0, 0, 0, 0, 0};
    }
    
    int newnode (int val) {
    	tot++;
        tr.push_back({0, 0, val, 1, 1, 1});
        return tot;
    }
    
    int getbalance (int root) {
        return tr[left].height - tr[right].height;
    }
    
    void push_up (int root) {
        tr[root].height = std::max(tr[left].height, tr[right].height) + 1;
        tr[root].sz = tr[left].sz + tr[right].sz + tr[root].cnt;
    }
    
    int L (int root) {
        int R = right;
        right = tr[R].ls, tr[R].ls = root;
        push_up(root); push_up(R);
        return R;
    }
    
    int R (int root) {
        int L = left;
        left = tr[L].rs, tr[L].rs = root;
        push_up(root); push_up(L);
        return L;
    }
    
    int balance (int root) {
        int factor = getbalance(root);
        if (factor > 1) {
            if (getbalance(left) < 0)
                left = L(left);
            return R(root);
        }
        else if (factor < -1) {
            if (getbalance(right) > 0)
                right = R(right);
            return L(root);
        }
        else return root;
    }
    
    int findmin (int root) {
        while (left)
        	root = left;
        
        return root;
    }
    
    int insert (int root, int val) {
        if (root == 0) return newnode(val);
        else if (tr[root].val > val)
            left = insert(left, val);
        else if (tr[root].val < val)
            right = insert(right, val);
        else {
            tr[root].cnt++, tr[root].sz++;
            return root;
        }
        push_up(root);
        return balance(root);
    }
    
    int del (int root, int val) {
        if (root == 0) return 0;
        if (tr[root].val > val)
            left = del(left, val);
        else if (tr[root].val < val)
            right = del(right, val);
        else {
            if (tr[root].cnt > 1) {
                tr[root].cnt--, tr[root].sz--;
                return root;
            }
            else if (left == 0 || right == 0)
                return left | right;
            
            int minnode = findmin(right);
            tr[root].val = tr[minnode].val;
            tr[root].cnt = 1;
            right = del(right, tr[minnode].val);
        }
        
        push_up(root);
        return balance(root);
    }
    
    int rank (int root, int val) {
        if (root == 0) return 0;
        else if (tr[root].val < val)
            return tr[left].sz + tr[root].cnt + rank(right, val);
        else return rank(left, val);
    }
    
    int kth (int root, int k) {
        while (root) {
            if (tr[left].sz >= k)
                root = left;
            else if (tr[left].sz + tr[root].cnt >= k)
                return tr[root].val;
            else
                k -= tr[left].sz + tr[root].cnt, root = right;
        }
        return inf;
    }
    
    int pre (int root,int val) {
        int res = -inf;
        while (root) {
            if (tr[root].val < val) res = std::max(res, tr[root].val), root = right;
            else root = left;
        }
        return res;
    }
    
    int suf (int root, int val) {
        int res = inf;
        while (root) {
            if (tr[root].val > val)
                res = std::min(res, tr[root].val), root = left;
            else root = right;
        }
        
        return res;
    }
};

#define mid ((l + r) >> 1)
#define ls (root << 1)
#define rs (root << 1 | 1)
#define lson ls, l, mid
#define rson rs, mid + 1, r

int n, m;

struct segment_tree {
	std::vector <AVLtree> tr;
	
	segment_tree (int l = 0) : tr(l << 3) {}
	
	void insert (int root, int l, int r, int pos, int val) {
		tr[root].allroot = tr[root].insert(tr[root].allroot, val);
		if (l == r) return;
		else if (pos <= mid) insert(lson, pos, val);
		else insert(rson, pos, val);
	}
	
	void del (int root, int l, int r, int pos, int val) {
		tr[root].allroot = tr[root].del(tr[root].allroot, val);
		if (l == r) return;
		else if (pos <= mid) del(lson, pos, val);
		else del(rson, pos, val);
	}
	
	int rank (int root, int l, int r, int ql, int qr, int val) {
		if (ql <= l && r <= qr)
			return tr[root].rank(tr[root].allroot, val);
		else if (qr <= mid) return rank(lson, ql, qr, val);
		else if (ql > mid) return rank(rson, ql, qr, val);
		else return rank(lson, ql, qr, val) + rank(rson, ql, qr, val);
	}
	
	int kth (int ql, int qr, int k) {
		int l = 0, r = 1e8;
		
		while (l <= r) {
			int x = l + r >> 1;
			if (rank(1, 1, n, ql, qr, x + 1) < k)
				l = x + 1;
			else r = x - 1;
		}
		
		return l;
	}
	
	int pre (int root, int l, int r, int ql, int qr, int val) {
		if (ql <= l && r <= qr)
			return tr[root].pre(tr[root].allroot, val);
		
		if (qr <= mid)
			return pre(lson, ql, qr, val);
		else if (ql > mid) return pre(rson, ql, qr, val);
		else return std::max(pre(lson, ql, qr, val), pre(rson, ql, qr, val));
	}
	
	int suf (int root, int l, int r, int ql, int qr, int val) {
		if (ql <= l && r <= qr)
			return tr[root].suf(tr[root].allroot, val);
		
		if (qr <= mid) return suf(lson, ql, qr, val);
		else if (ql > mid) return suf(rson, ql, qr, val);
		else return std::min(suf(lson, ql, qr, val), suf(rson, ql, qr, val));
	}
}tree;

int main () {
	
	n = read(); m = read();
	tree = segment_tree (n + 10);
	
	for (int i = 1; i <= n; i++) {
		a[i] = read();
		tree.insert(1, 1, n, i, a[i]);
	}
	
	while (m--) {
		int opt, l, r, k, pos; opt = read();
		if (opt == 1) {
			l = read(), r = read(), k = read();
			write(tree.rank(1, 1, n, l, r, k) + 1); putchar('\n');
		}
		else if (opt == 2) {
			l = read(), r = read(), k = read();
			write(tree.kth(l, r, k)); putchar('\n');
		}
		else if (opt == 3) {
			pos = read(), k = read();
			tree.del(1, 1, n, pos, a[pos]);
			tree.insert(1, 1, n, pos, k);
			a[pos] = k;
		}
		else if (opt == 4) {
			l = read(), r = read(), k = read();
			write(tree.pre(1, 1, n, l, r, k)); putchar('\n');
		}
		else {
			l = read(), r = read(), k = read();
			write(tree.suf(1, 1, n, l, r, k)); putchar('\n');
		}
	}
	
	
	return 0;
}
2025/8/4 16:57
加载中...