mxqz
查看原帖
mxqz
384214
esquigybcu楼主2022/1/20 16:58

测试点 #11 和 #15 超时（TLE）。

#include <stdio.h>
#include <string.h>
#include <algorithm>

using ll = long long;          // 64-bit type for modular products
constexpr int N = 300000 + 5;  // bound for vertex- and depth-indexed tables
constexpr int K = 50 + 5;      // bound for the exponent dimension of pre[][]
constexpr int MOD = 998244353; // modulus applied to every answer

inline constexpr int log2(int n) {return 31 - __builtin_clz(n);}

// Modular exponentiation: returns x^k mod MOD for k >= 0, with qpow(x, 0) == 1.
// Iterative square-and-multiply over the bits of k, least significant first.
inline ll qpow(ll x, int k)
{
    ll result = 1;
    ll base = x % MOD;
    for (; k > 0; k >>= 1)
    {
        if (k & 1)
            result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}

int pre[N][K]; // pre[x][y] = \sum_{i=0}^x i^y

inline void getpre()
{
    // pre[x][y] = pre[x - 1][y] + x^y

    for (int y = 0; y <= 50; y++)
        pre[0][y] = 0;
    for (int x = 0; x <= 3e5; x++)
        for (int y = 0; y <= 50; y++)
            pre[x][y] = (pre[x - 1][y] + qpow(x, y)) % MOD;
}

// One directed half of an undirected edge, stored in a forward-star
// adjacency list: head[u] is the index of the most recently added edge
// leaving u (main() initialises it to -1, meaning "no edges"), and next
// links to the previously added edge out of the same vertex.
struct edge
{
    int u;    // source vertex
    int v;    // destination vertex
    int next; // index of the next edge out of u, or -1
};

edge e[N << 1]; // room for two directed halves per undirected edge
int cnt;        // number of edges stored so far (next free slot in e)
int head[N];

inline void add_edge(int u, int v)
{
    e[cnt].u = u, e[cnt].v = v, e[cnt].next = head[u];
    head[u] = cnt++;
}

int depth[N], top[N][log2(N)];

inline void dfs(int u, int f)
{
    if (f == 0) depth[u] = 0;
    else depth[u] = depth[f] + 1;

    top[u][0] = f;
    for (int i = 1; i <= log2(depth[u]); i++)
        top[u][i] = top[top[u][i - 1]][i - 1];

    for (int i = head[u]; ~i; i = e[i].next)
        if (e[i].v != f)
            dfs(e[i].v, u);
}

inline int lca(int u, int v)
{
    if (depth[u] < depth[v])
        std::swap(u, v);
    while (depth[u] > depth[v])
        u = top[u][log2(depth[u] - depth[v])];
    if (u == v)
        return u;
    for (int i = log2(depth[u]); i >= 0; i--)
        if (top[u][i] != top[v][i])
            u = top[u][i], v = top[v][i];
    return top[u][0];
}

// Reads a tree rooted at vertex 1, then answers q queries "i j k": the sum,
// over every vertex x on the path between i and j, of depth(x)^k mod MOD.
int main()
{
    memset(head, -1, sizeof head); // -1 marks an empty adjacency list
    getpre();

    int n;
    scanf("%d", &n);
    for (int edges_left = n - 1; edges_left > 0; edges_left--)
    {
        int a, b;
        scanf("%d %d", &a, &b);
        add_edge(a, b);
        add_edge(b, a);
    }
    dfs(1, 0);

    int q;
    scanf("%d", &q);
    while (q--)
    {
        int i, j, k;
        scanf("%d %d %d", &i, &j, &k);
        const int l = lca(i, j);
        // pre[depth[x]][k] sums d^k over the root..x path (one vertex per
        // depth d). Path(i, j) = root..i + root..j - root..l - root..parent(l).
        ll ans = (ll)pre[depth[i]][k] + pre[depth[j]][k] - pre[depth[l]][k];
        if (l != 1) // the root has no parent prefix to remove
            ans -= pre[depth[l] - 1][k];
        ans %= MOD;
        if (ans < 0)
            ans += MOD;
        printf("%d\n", (int)ans);
    }
    return 0;
}
2022/1/20 16:58
加载中...