求助:调了两天了没调出来,9个点RE,1个点AC
查看原帖
求助:调了两天了没调出来,9个点RE,1个点AC
394488
tzyt楼主2021/9/7 22:04
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400010;
int n, m, r, MOD;
int init_weight[MAXN], weight[MAXN];
vector<int> edge[2 * MAXN];
int dep[MAXN], son_size[MAXN], id[MAXN], chain_top[MAXN];
int dfs_order[MAXN], father[MAXN];
int tag[MAXN * 4], seg_tree[MAXN * 4];
int max_son[MAXN], node_cnt;
bool debug = false;
void add_edge(int x, int y)
{
    edge[x].push_back(y);
    edge[y].push_back(x);
}
namespace seg_t
{
    inline int left_son(int cur)
    {
        return cur * 2 + 1;
    }

    inline int right_son(int cur)
    {
        return cur * 2;
    }

    inline void record_tag(int cur_node, int left, int right, int val)
    {
        tag[cur_node] = tag[cur_node] + val;
        if (debug)
            printf("cur_node %d left %d right %d\n", cur_node, left, right);
        seg_tree[cur_node] = seg_tree[cur_node] + val * (right - left + 1);
        seg_tree[cur_node] %= MOD;
    }

    void push_up_sum(int cur_node)
    {
        seg_tree[cur_node] = seg_tree[left_son(cur_node)] + seg_tree[right_son(cur_node)];
        seg_tree[cur_node] %= MOD;
    }

    void push_down_sum(int cur_node, int left, int right)
    {
        int mid = (left + right) / 2;
        record_tag(left_son(cur_node), left, mid, tag[cur_node]);
        record_tag(right_son(cur_node), mid + 1, right, tag[cur_node]);
        tag[cur_node] = 0;
    }

    void build_seg_tree(int cur_node, int left, int right)
    {
        if (left == right)
        {
            seg_tree[cur_node] = weight[left];
            seg_tree[cur_node] %= MOD;
            return;
        }
        int mid = (left + right) / 2;
        build_seg_tree(left_son(cur_node), left, mid);
        build_seg_tree(right_son(cur_node), mid + 1, right);
        push_up_sum(cur_node);
    }

    int query_sum(int cur_node, int cur_left, int cur_right, int tar_left, int tar_right)
    {
        int ret = 0;
        if (tar_left <= cur_left && cur_right <= tar_right) //目标区间包含当前区间
        {
            seg_tree[cur_node] %= MOD;
            return seg_tree[cur_node];
        }
        int mid = (cur_left + cur_right) / 2;
        if (debug)
            printf("cur in qsum %d\n", cur_node);
        push_down_sum(cur_node, cur_left, cur_right);
        if (tar_left <= mid)
        {
            ret += query_sum(left_son(cur_node), cur_left, mid, tar_left, tar_right);
            ret %= MOD;
        }
        if (tar_right > mid)
        {
            ret += query_sum(right_son(cur_node), mid + 1, cur_right, tar_left, tar_right);
            ret %= MOD;
        }
        return ret % MOD;
    }

    void add_update(int cur_node, int cur_left, int cur_right, int tar_left, int tar_right, int val)
    {
        if (tar_left <= cur_left && cur_right <= tar_right)
        {
            int len = (cur_right - cur_left + 1);
            seg_tree[cur_node] += val * len;
            tag[cur_node] += val;
            return;
        }
        if (debug)
            printf("cur in segupd %d cl %d cr %d tl %d tr %d\n", cur_node, cur_left, cur_right, tar_left, tar_right);
        push_down_sum(cur_node, cur_left, cur_right); //加之前要把tag下放
        int mid = (cur_left + cur_right) / 2;
        if (tar_left <= mid)
        {
            add_update(left_son(cur_node), cur_left, mid, tar_left, tar_right, val);
        }
        if (tar_right > mid)
        {
            add_update(right_son(cur_node), mid + 1, cur_right, tar_left, tar_right, val);
        }
        push_up_sum(cur_node); //加好了要更新根节点信息
    }

}

//--------------------seg_tree end---------------------

int range_query(int node1, int node2)
{
    int ans = 0;
    while (chain_top[node1] != chain_top[node2])
    {
        if (dep[chain_top[node1]] < dep[chain_top[node2]])
        {
            swap(node1, node2); //node1这条链的深度更深
        }
        int ret = seg_t::query_sum(1, 1, n, id[chain_top[node1]], id[node1]);
        ans += ret;
        ans %= MOD;
        node1 = father[chain_top[node1]];
    }
    if (dep[node1] > dep[node2])
    {
        swap(node1, node2);
    }
    int ret = seg_t::query_sum(1, 1, n, id[node1], id[node2]);
    ans += ret;
    return ans % MOD;
}

void range_update(int node1, int node2, int val)
{
    val %= MOD;
    while (chain_top[node1] != chain_top[node2])
    {
        if (dep[chain_top[node1] < dep[chain_top[node2]]])
        {
            swap(node1, node2);
        }
        seg_t::add_update(1, 1, n, id[chain_top[node1]], id[node1], val);
        node1 = father[chain_top[node1]];
        if (debug)
            printf("id node1 %d node1 %d\n", id[node1], node1);
    }
    if (dep[node1] > dep[node2])
    {
        swap(node1, node2);
    }
    seg_t::add_update(1, 1, n, id[node1], id[node2], val);
}

int query_son(int cur)
{
    return seg_t::query_sum(1, 1, n, id[cur], id[cur] + son_size[cur] - 1);
}

void update_son(int cur, int val)
{
    seg_t::add_update(1, 1, n, id[cur], id[cur] + son_size[cur] - 1, val);
}

void find_maxson(int cur, int fa, int depth)
{
    dep[cur] = depth;
    father[cur] = fa;
    son_size[cur] = 1;
    int maxson = -1;
    for (auto nex : edge[cur])
    {
        if (nex == fa)
        {
            continue;
        }
        find_maxson(nex, cur, depth + 1);
        son_size[cur] += son_size[nex];
        if (son_size[nex] > maxson)
        {
            maxson = son_size[nex];
            max_son[cur] = nex;
        }
    }
    if (debug)
        printf("cur %d maxson %d sonsize %d\n", cur, max_son[cur], son_size[cur]);
}

void chain_proc(int cur, int chaintop)
{
    id[cur] = ++node_cnt;
    weight[node_cnt] = init_weight[cur];
    chain_top[cur] = chaintop;
    if (debug)
        printf("cur %d id %d chaintop %d\n", cur, id[cur], chain_top[cur]);
    if (!max_son[cur])
    {
        return;
    }
    chain_proc(max_son[cur], chaintop);
    for (auto nex : edge[cur])
    {
        if (nex == max_son[cur] || nex == father[cur])
        {
            continue;
        }
        chain_proc(nex, nex);
    }
}

int main()
{
    //debug = true;
    scanf("%d%d%d%d", &n, &m, &r, &MOD);
    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &init_weight[i]);
    }
    for (int i = 1; i <= n - 1; i++)
    {
        int from, to;
        scanf("%d%d", &from, &to);
        add_edge(from, to);
    }
    find_maxson(r, 0, 1);
    chain_proc(r, r);
    seg_t::build_seg_tree(1, 1, n);
    while (m--)
    {
        int MODE;
        int x, y, val;
        scanf("%d", &MODE);
        switch (MODE)
        {
        case 1:
        {
            scanf("%d%d%d", &x, &y, &val);
            range_update(x, y, val);
            break;
        }
        case 2:
        {
            scanf("%d%d", &x, &y);
            printf("%d\n", range_query(x, y));
            break;
        }
        case 3:
        {
            scanf("%d%d", &x, &val);
            update_son(x, val);
            break;
        }
        case 4:
        {
            scanf("%d", &x);
            printf("%d\n", query_son(x));
            break;
        }
        }
    }
    system("pause");
}
2021/9/7 22:04
加载中...