萌新求助,保龄了
查看原帖
萌新求助,保龄了
355559
FutureThx楼主2021/5/5 20:22

求指错或者给个小型的Hack数据

#include <iostream>
#include <vector>
#include <queue>
#include <cmath>
using namespace std;
#define MAX_N 300010
#define int long long
#define mod 998244353
int n;
struct Graph{
    vector<int> next;
    int dep = 0;
}node[MAX_N];
int f[MAX_N][60],pre[MAX_N][60],t;
void bfs(int s){
    queue<int> q;
    t = (int)(log(n) / log(2)) + 1;
    node[s].dep = 1;
    for(int i = 0;i <= t;i++)f[s][i] = s;
    for(int k = 0;k <= 50;k++)
        pre[s][k] = 1;
    q.push(s);
    while(!q.empty()){
        int u = q.front();
        q.pop();
        for(int i = 0;i < node[u].next.size();i++){
            int v = node[u].next[i];
            if(node[v].dep > 0)continue;
            node[v].dep = node[u].dep + 1;
            f[v][0] = u;
            int sum = 1;
            for(int k = 0;k <= 50;k++){
                pre[v][k] = (pre[u][k] + sum) % mod;
                sum %= mod;
                sum *= (node[v].dep - 1) % mod;
            }
            for(int j = 1;j <= t;j++)
                f[v][j] = f[f[v][j-1]][j-1];
            q.push(v);
        }
    }
}
int LCA(int x,int y){
    int t = (int)(log(n) / log(2)) + 1;
    if(node[x].dep > node[y].dep) 
       swap(x,y);
    for(int i = t;i >= 0;i--)
        if(node[f[y][i]].dep >= node[x].dep)
           y = f[y][i];
    if(x == y) return x;
    for(int i = t;i >= 0;i--)
        if(f[x][i] != f[y][i])
           x = f[x][i],y = f[y][i];
    return f[x][0];
}
int query(int a,int b,int k){
    // cout << pre[a][k] << " " << pre[b][k] << " " << pre[LCA(a,b)][k] << endl;
    return ((pre[a][k] + pre[b][k]) % mod + mod - pre[LCA(a,b)][k] - pre[f[LCA(a,b)][0]][k]) % mod;
}
signed main(){
    cin >> n;
    for(int i = 1;i < n;i++){
        int u,v;
        cin >> u >> v;
        node[u].next.push_back(v);
        node[v].next.push_back(u);
    }
    bfs(1);
    int q;
    cin >> q;
    // cout << pre[1][5] << " " << pre[4][5] << " " << pre[1][5] << " " << pre[f[1][0]][5] << endl;
    while(q--){
        int u,v,k;
        cin >> u >> v >> k;
        cout << query(u,v,k) << endl;
    }
    return 0;
}
/*
10
1 2
2 3
3 4
9 5
6 7
6 8
2 9
2 10
6 2
3
3 8 50
2 4 50
9 3 50
0^45
1^45
2^45
3^45
4^45

*/
2021/5/5 20:22
加载中...