重金50元求解大佬帮忙看一下我代码中哪里出现问题了!调了2天硬是没有调出来
查看原帖
重金50元求解大佬帮忙看一下我代码中哪里出现问题了!调了2天硬是没有调出来
291976
quanjun楼主2022/12/8 13:56

首先,我的想法是这样的:

使用一个 splay tree 来维护这些冰山的信息。

对于每次操作的 xxyy

  1. 先将所有节点的权值增加 xx(这步操作只需要修改根节点的信息,然后进行一下懒惰标记即可,以后会 push_down 下去的)
  2. 然后增加一个权值为 yy 的点(如果权值为 yy 的点本身就存在的话,就将其数量增加 11
  3. 先将权值 >0\gt 0 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 0\le 0
  4. 再将权值 k\le k 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 >k\gt k 的,这些节点需要拆解为若干个权值为 kk 的节点和若干个权值为 11 的节点,操作过程为:
    1. 首先需要为每一个节点维护两个信息:
      • sum 表示以该节点为根节点的子树中所有节点权值与数量的乘积之和,比如:如果该节点所在子树中有 22 个权值为 33 的点,66 个权值为 55 的点,则该节点的 sum 为 2×3+6×5=362 \times 3 + 6 \times 5 = 36
      • cnt 表示以该节点为根节点的自述中所有节点的数量,比如:如果该绩点所在子树中有 22 个权值为 33 的点,66 个权值为 55 的点,则该节点的 cnt 值为 2+6=82 + 6 = 8
    2. 然后,先计算出根节点的右儿子的 sum 值和 cnt 值,可以发现,这些节点中的每一个都会被分成 11 个权值为 kk 的节点以及(剩余的)若干个权值为 11 的节点,所以我们需要做的事情是:
      1. 删除根节点的右子树
      2. 插入 cnt 个权值为 kk 的节点
      3. 插入 sum - cnt * k 个权值为 11 的节点

首先先贴一下我的完整代码(由于比较多我就放到洛谷剪贴板里面了),完整代码链接:https://www.luogu.com.cn/paste/75szc379

为了方便大佬查看,我在这篇帖子最后也发了完整的代码。

具体实现时,首先由于我经常会用到变量小 k,所以我把题目描述中的 k 开成了大K(即 K)—— 即冰山的最大大小。

然后我定义了一个结构体来维护 splay tree 里面的节点信息,如下:

struct Node {
    int s[2], p;    // s[0] 左儿子 s[1] 右儿子 p 父节点
    long long v,    // 冰山体积
              num,  // 冰山个数
              cnt,  // 子树包含冰山个数
              sum,  // 子树包含冰山体积之和
              flag; // 懒惰标记

    Node() {};
    Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];

其中:

  • s[0]s[1] 分别表示左儿子和右儿子节点的编号(如果没有则为 00
  • v 表示当前节点的权值(对应的就是冰山的体积)
  • num 表示体积为 v 的冰山有多少个
  • cnt 表示以当前节点为根节点的子树中包含的冰山的总数
  • sum 表示以当前节点为根节点的子树中包含的冰山的总体积
  • flag 是懒惰标记,它表示当前节点的所有子节点的体积需要整体增加的量

这里需要注明的是,我自己做懒惰标记的习惯是:修改当前节点信息的同时进行懒惰标记,然后将懒惰标记传给子节点的时候也会修改子节点的信息,因为写线段树的时候都是这么写的就习惯了。(因为有的大佬可能习惯是 pushdown 的时候更新当前节点,但是我已经习惯更新当前节点并进行懒惰标记,然后 pushdown 的时候一方面更新子节点,另一方面把懒惰标记传给子节点)

然后因为具体实现是经常需要进行形如 a = (a + b) % MOD; 的操作,所以我添加了一个 add 函数方便写:

void add(long long &a, long long b) {
    a = (a + b % MOD) % MOD;
}

然后就是比较重要的 push up 和 push down 操作了。

push up

push_up 需要将子节点的信息更新到当前节点,只需要更新一下 cnt 和 sum

主要操作就是:

tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;

需要取一下模。

push down

push_down 需要将当前节点的懒惰标记(flag,对应的是冰山体积的整体增量)传递给子节点并且同时更新子节点。对于一个节点来说,当传递了一个值为 tmp 的体积增量时:

  • flag 会增加 tmp
  • v 会增加 tmp
  • sum 会增加 cnt * tmp(因为该节点的子树中所有节点对应的冰山体积都会增加 tmp

主要操作是:

void t_flag(int x, long long tmp) {
    if (x) {
        tr[x].flag += tmp;
        tr[x].v += tmp;
        add(tr[x].sum, tr[x].cnt * tmp);
    }
}

void push_down(int x) {
    if (tr[x].flag) {
        t_flag(tr[x].s[0], tr[x].flag);
        t_flag(tr[x].s[1], tr[x].flag);
        tr[x].flag = 0;
    }
}

旋转和 splay 操作

这部分的操作基本没有改动过,之前用它们解决过 AcWing 上的 splay 例题(原来的帖子:https://www.acwing.com/file_system/file/content/whole/index/content/7428637/

对应的代码(多了一个 f_s(p, u, k) 函数用来认亲(u 是 p 的儿子,其中 k = 0 表示 u 是 p 的左儿子;k = 1 表示 u 是 p 的右儿子):

void f_s(int p, int u, bool k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

插入操作

ins(v, num) 表示的是加入 num 个体积为 v 的冰山:

void ins(long long v, long long num) {
    int u = root, p = 0;
    while (u && tr[u].v != v) {
        push_down(u);
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (!u) {
        tr[u = ++idx] = Node(v, p);
        if (p) tr[p].s[v > tr[p].v] = u;
    }
    else
        push_down(u);   // 如果不是新建的节点,需要push_down一下
    add(tr[u].num, num);
    add(tr[u].cnt, num);
    add(tr[u].sum, num * v);
    splay(u, 0);
}

调整

这里有一个 check 函数,主要用来判断当前节点对应的冰山体积是否合法:

bool check(int u) {
    return tr[u].v > 0 && tr[u].v <= K;
}

然后就是 get1() 函数和 get2() 函数了。

get1() 函数对应的就是我上面说的第 3 步操作:

  1. 先将权值 >0\gt 0 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 0\le 0
void get1() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v > 0) {
            x = u;
            u = tr[u].s[0];
        }
        else u = tr[u].s[1];
    }
    if (x) {
        splay(x, 0);
        tr[x].s[0] = 0;
        push_up(x);
    }
}

get2() 函数对应的就是我上面说的第 4 步操作:

  1. 再将权值 k\le k 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 >k\gt k 的,这些节点需要拆解为若干个权值为 kk 的节点和若干个权值为 11 的节点

(具体操作的描述可以看最上面哈)

void get2() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v <= K) {
            x = u;
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    if (x) {
        splay(x, 0);
        int y = tr[x].s[1];
        if (y) {
            long long cnt = tr[y].cnt, sum = tr[y].sum;
            tr[x].s[1] = 0;
            push_up(x);
            ins(K, cnt);
            ins(1, (sum - cnt * K % MOD + MOD) % MOD);
        }
    }
}

完整的代码(只有 60 分)

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5;
const long long MOD = 998244353;

int n, m;
long long K;
struct Node {
    int s[2], p;    // s[0] 左儿子 s[1] 右儿子 p 父节点
    long long v,    // 冰山体积
              num,  // 冰山个数
              cnt,  // 子树包含冰山个数
              sum,  // 子树包含冰山体积之和
              flag; // 懒惰标记

    Node() {};
    Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];
int root, idx;

void add(long long &a, long long b) {
    a = (a + b % MOD) % MOD;
}

void push_up(int x) {
    tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
    tr[x].cnt %= MOD;
    tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;
    tr[x].sum %= MOD;
    tr[x].sum = (tr[x].sum + MOD) % MOD;
}

void t_flag(int x, long long tmp) {
    if (x) {
        tr[x].flag += tmp;
        tr[x].v += tmp;
        add(tr[x].sum, tr[x].cnt * tmp);
    }
}

void push_down(int x) {
    if (tr[x].flag) {
        t_flag(tr[x].s[0], tr[x].flag);
        t_flag(tr[x].s[1], tr[x].flag);
        tr[x].flag = 0;
    }
}

void f_s(int p, int u, bool k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

void ins(long long v, long long num) {
    int u = root, p = 0;
    while (u && tr[u].v != v) {
        push_down(u);
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (!u) {
        tr[u = ++idx] = Node(v, p);
        if (p) tr[p].s[v > tr[p].v] = u;
    }
    else
        push_down(u);   // 如果不是新建的节点,需要push_down一下
    add(tr[u].num, num);
    add(tr[u].cnt, num);
    add(tr[u].sum, num * v);
    splay(u, 0);
}

bool check(int u) {
    return tr[u].v > 0 && tr[u].v <= K;
}

void get1() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v > 0) {
            x = u;
            u = tr[u].s[0];
        }
        else u = tr[u].s[1];
    }
    if (x) {
        splay(x, 0);
        tr[x].s[0] = 0;
        push_up(x);
    }
}

void get2() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v <= K) {
            x = u;
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    if (x) {
        splay(x, 0);
        int y = tr[x].s[1];
        if (y) {
            long long cnt = tr[y].cnt, sum = tr[y].sum;
            tr[x].s[1] = 0;
            push_up(x);
            ins(K, cnt);
            ins(1, (sum - cnt * K % MOD + MOD) % MOD);
        }
    }
}

int main() {
    scanf("%d%d%lld", &n, &m, &K);
    for (int i = 0; i < n; i++) {
        int v;
        scanf("%d", &v);
        ins(v, 1);
    }
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        t_flag(root, x);
        ins(y, 1);
        get1();
        get2();
        printf("%lld\n", tr[root].sum);
    }
    return 0;
}

这就是我的主要思路了,请教各位大佬帮忙看一下哪里出现问题了,万分感谢!

对于第1位解答我的疑问的大佬,我将发送 5050 元红包作为酬谢,谢谢!

2022/12/8 13:56
加载中...