线段树合并TLE60pts求调
查看原帖
线段树合并TLE60pts求调
1026365
yb_10032（楼主） 2025/7/3 14:14
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
// Capacity used for every per-vertex array in this file: 3e5 plus a little slack.
const int N = 3e5 + 10;
// One directed arc of an adjacency list stored in a flat array:
// `to` is the arc's endpoint, `next` is the index in e[] of the following arc
// in the same list (0 ends the list).
struct Edge
{
    int to, next;
} e[N << 1]; // 2N slots: main() stores every input edge in both directions
// head[u] is the index in e[] of the arc most recently added for u (0 when u has none);
// idx is the last used slot of e[], so real arcs occupy indices starting at 1.
int head[N], idx;
// Prepends the directed arc u -> v to u's adjacency list.
// The arc takes the next free slot of e[] and becomes the new list head.
inline void add(int u, int v)
{
    const int slot = ++idx;
    e[slot].to = v;
    e[slot].next = head[u]; // link to the previous head (0 if the list was empty)
    head[u] = slot;
}
// Both are read from stdin at the start of main(): n - 1 edges follow, then m queries.
int n, m;
// Filled in by dfs(): siz[u] is the number of vertices in u's subtree (u included);
// d[u] is u's depth, where the vertex dfs() starts from has depth 1 (d[0] stays 0).
int siz[N], d[N];
// A pool of dynamically created sum segment trees, addressed by node index
// (index 0 means "no node"). rt[v] holds the root index of the tree that
// belongs to vertex v. Every node stores the closed interval [l, r] it covers.
struct Segment_Tree
{
    struct Node
    {
        int ls, rs, l, r; // child indices (0 = absent) and the covered interval
        ll sum;           // sum of all values stored inside [l, r]
    } tr[N << 6];
    int rt[N], idx; // idx is the index of the last allocated node
    Segment_Tree() { idx = 0; }

    // Adds k at position pos in the tree rooted at x, where x covers [l, r].
    // x is taken by reference so that a node created here is linked into its
    // parent (or into rt[]). Existing nodes on the path are modified in place.
    inline void update(int &x, int l, int r, int pos, int k)
    {
        if (!x)
        {
            x = ++idx;
            tr[x] = {0, 0, l, r, 0};
        }
        tr[x].sum += k;
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        if (pos <= mid)
            update(tr[x].ls, l, mid, pos, k);
        else
            update(tr[x].rs, mid + 1, r, pos, k);
    }

    // Returns the root of a tree whose sums are those of x and y added together.
    // If either operand is empty the other one is returned as is, so the result
    // may share nodes with its inputs. Otherwise a fresh node is allocated and
    // neither x nor y is written to. The interval of the new node is copied
    // from x; y is expected to cover the same interval.
    inline int merge(int x, int y)
    {
        if (!x || !y)
            return x | y; // one of them is 0, so this is simply the other one
        int u = ++idx;
        tr[u] = tr[x];
        tr[u].sum += tr[y].sum;
        if (tr[u].l == tr[u].r)
            return u;
        tr[u].ls = merge(tr[x].ls, tr[y].ls);
        tr[u].rs = merge(tr[x].rs, tr[y].rs);
        return u;
    }

    // Returns the sum of the values stored at positions in [l, r] of the tree
    // rooted at x (0 for an empty tree or an empty range).
    inline ll query(int x, int l, int r)
    {
        if (!x)
            return 0;
        // Disjoint means the node ends before l OR starts after r. This has to
        // be ||: with && the test can never hold for l <= r, nothing is pruned,
        // and every query ends up visiting every node reachable from x.
        if (tr[x].r < l || r < tr[x].l)
            return 0;
        if (l <= tr[x].l && tr[x].r <= r)
            return tr[x].sum;
        return query(tr[x].ls, l, r) + query(tr[x].rs, l, r);
    }
} t;
// Depth-first pass from u, whose parent is fa (0 for the starting vertex).
// Fills d[u] and siz[u], records siz[u] - 1 at position d[u] in u's segment
// tree, and finally merges u's tree into the parent's tree.
inline void dfs(int u, int fa)
{
    siz[u] = 1;
    d[u] = d[fa] + 1;
    for (int arc = head[u]; arc != 0; arc = e[arc].next)
    {
        const int child = e[arc].to;
        if (child != fa)
        {
            dfs(child, u); // the child merges its own tree into rt[u] before returning
            siz[u] += siz[child];
        }
    }
    // siz[u] is final only now, so u's own contribution is added after the children.
    t.update(t.rt[u], 1, n, d[u], siz[u] - 1);
    if (fa == 0)
        return;
    t.rt[fa] = t.merge(t.rt[fa], t.rt[u]);
}
int main(int argc, char const *argv[])
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    dfs(1, 0);
    for (int i = 1; i <= m; i++)
    {
        int u, k;
        cin >> u >> k;
        cout << t.query(t.rt[u], d[u] + 1, min(d[u] + k, n)) + 1ll * min(d[u] - 1, k) * (siz[u] - 1) << endl;
    }
    cerr << 1000 * clock() / CLOCKS_PER_SEC << " ms" << endl;
    return 0;
}
2025/7/3 14:14
加载中...