QAQ左偏树 MLE+WA 78pts 求大佬指教
查看原帖
QAQ左偏树 MLE+WA 78pts 求大佬指教
393190
aldol_reaction楼主2021/4/3 21:04

码风很清晰的qwq,不知道哪里的问题会MLE,想的我心态炸了qwq

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define NDEBUG
#include <assert.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;

inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    return x * f;
}

#define endl '\n'
#define rd read()
#define pb push_back
#define mst(a, b) memset((a), (b), sizeof(a));
#define inf 0x3f3f3f3f
#define linf 0x3f3f3f3f3f3f3f3f
#define mod ((int)1e9+7)
#define maxn (int)(2e5+5)

int n, m, rt;
ll ans;
int l[maxn], fa[maxn];
bool kill[maxn];
vector<int> e[maxn];

struct node {int x, val, dist, ls, rs; ll tot;} p[maxn]; //tot是以node为根的树中的薪金总和, val是以node为根的树中的雇佣人数总和,x是node结点的薪金

int find(int x) {return fa[x] == x ? x : fa[x] = find(fa[x]);}

int merge(int x, int y) {
    if(!x || !y) return x + y;
    if(p[x].x < p[y].x)
        swap(x, y);
    p[x].tot += p[y].tot;
    p[x].val += p[y].val;
    p[x].rs = merge(p[x].rs, y);
    if(p[p[x].ls].dist < p[p[x].rs].dist)
        swap(p[x].ls, p[x].rs);
    p[x].dist = p[p[x].rs].dist + 1;
    return x;
}

void inp() {
    cin >> n >> m;
    for(int i = 1; i <= n; ++i) {
        int u = rd;
        e[u].pb(i);
        p[i].x = rd;
        p[i].tot = p[i].x;
        p[i].val = 1;
        l[i] = rd;
        fa[i] = i;
    }
}

int del(int x) {
    if(kill[x]) return -1;
    x = find(x);
    kill[x] = true;
    fa[p[x].ls] = fa[p[x].rs] = fa[x] = merge(p[x].ls, p[x].rs);
    p[x].dist = p[x].ls = p[x].rs = 0;
    return p[x].x;
}

void dfs(int u) {
    int rt = u;
    if(p[rt].x <= m && e[u].size() == 0) {
        ans = max(ans, (ll)l[u]);
    }
    for(unsigned int i = 0; i < e[u].size(); ++i) {
        bool flag = false;
        int v = e[u][i];
        dfs(v);
        rt = find(rt), v = find(v);
        rt = fa[rt] = fa[v] = merge(rt, v);
        while(p[rt].tot > m) {
            int x = del(rt);
            if(x == 0 || x == -1 || p[rt].x == p[rt].tot) {
                flag = true;
                break;
            }
            p[rt].tot -= x;
            --p[rt].val;
        }
        if(flag) continue;
        ll ans1 = (ll)p[rt].val * l[u];
        ans = max(ans, ans1);
    }
}

int main() {
    inp();
    dfs(0);
    cout << ans;




















































    return 0;
}
2021/4/3 21:04
加载中...