思路:预处理每个深度的 k 次方前缀和,然后用树上差分回答每个询问。但不明白为什么很多测试点都错了,而且通过的恰好是最后 3 个点,非常奇怪。
#include <bits/stdc++.h>
#define int long long
#define F(i,j,k) for(int i=j;i<=k;i++)
#define add(x,y) g[x].push_back(y)
using namespace std;
const int N = 3e5 + 5;   // capacity of every node-indexed array (nodes are 1..n; index 0 is a sentinel)
int n, m;                // number of tree nodes / number of queries
int dep[N];              // dep[v]: depth of node v, filled by dfs()
int f[N][55];            // f[d][k]: prefix sums over depths of depth^k (mod p), filled by init()
int fa[N][25];           // fa[v][i]: 2^i-th ancestor of v, 0 when no such ancestor exists
std::vector<int> g[N];   // g[v]: neighbours of v (each undirected edge is stored both ways)
/// Minimal buffered I/O on stdin/stdout built on fread/fwrite.
/// Input: unsigned decimal integers separated by whitespace. Output: one integer per line.
/// Pending output is flushed when the buffer fills and when the object is destroyed.
struct FastIO
{
    static const int S = 1e7;  // size in bytes of the input buffer and of the output buffer
    int wpos;                  // number of bytes pending in wbuf
    char wbuf[S];              // output not yet written to stdout
    FastIO() : wpos(0) {}

    /// Returns the next byte of stdin as a value in [0, 255], or -1 once stdin is exhausted.
    inline int xchar()
    {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len)
        {
            pos = 0, len = fread(buf, 1, S, stdin);
            // fread returns 0 at end of input. Without this check we would hand out stale
            // bytes from the previous chunk (and eventually read past the end of buf),
            // which can append garbage digits to a final number lacking a trailing newline.
            if (len == 0)
                return -1;
        }
        return static_cast<unsigned char>(buf[pos++]);
    }

    /// Reads an unsigned decimal integer. Leading bytes <= 32 (whitespace/control) and
    /// bytes >= 128 (non-ASCII, e.g. a UTF-8 BOM) are skipped; reading stops at the first
    /// non-digit. Returns 0 if the first non-skipped byte is not a digit or input has ended.
    inline int xuint()
    {
        int c = xchar(), x = 0;
        // Stop skipping at end of input (-1) so an exhausted stream cannot loop forever.
        while (c != -1 && (c <= 32 || c >= 128))
            c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar())
            x = x * 10 + c - '0';
        return x;
    }

    /// Appends one byte to the output buffer, flushing to stdout first if the buffer is full.
    inline void wchar(int x)
    {
        if (wpos == S)
            fwrite(wbuf, 1, S, stdout), wpos = 0;
        wbuf[wpos++] = x;
    }

    /// Writes x in decimal (with a leading '-' when negative) followed by '\n'.
    inline void wint(int x)
    {
        // Work on the magnitude as unsigned so the most negative value is handled as well.
        unsigned long long u = x;
        if (x < 0)
        {
            wchar('-');
            u = 0ULL - u;
        }
        char s[24];
        int cnt = 0;
        do
        {
            s[cnt++] = static_cast<char>('0' + u % 10);
            u /= 10;
        } while (u);
        while (cnt--)
            wchar(s[cnt]);
        wchar('\n');
    }

    /// Flushes whatever output is still pending.
    ~FastIO()
    {
        if (wpos)
            fwrite(wbuf, 1, wpos, stdout), wpos = 0;
    }
} io;
constexpr int mod = 998244353;  // modulus applied to every power, prefix sum and answer
/// Sets dep[v] and fa[v][0] for u and for every node reachable from u without passing
/// through fath (u's parent; 0 for the root).
/// Uses an explicit stack instead of recursion: a path-shaped tree with ~3e5 nodes would
/// otherwise recurse ~3e5 levels deep and can overflow a default-sized call stack.
inline void dfs(int u,int fath){
    dep[u]=dep[fath]+1, fa[u][0]=fath;
    std::vector<int> stk;
    stk.push_back(u);
    while(!stk.empty()){
        int x=stk.back();
        stk.pop_back();
        for(auto v:g[x]){
            if(v==fa[x][0]) continue;  // do not walk back up to the parent
            dep[v]=dep[x]+1, fa[v][0]=x;
            stk.push_back(v);
        }
    }
}
/// Builds the lookup tables used by the queries:
///   f[i][j]  = (0^j + 1^j + ... + i^j) mod `mod`, for 1 <= j <= 50 and 0 <= i <= n;
///   fa[v][i] = 2^i-th ancestor of v (0 once the jump leaves the tree), for 1 <= i <= 20.
/// Column f[i][0] is only the seed 1 of the power recurrence, not a prefix sum, so
/// exponent 0 is not supported. Requires n to be set and dfs() to have filled fa[v][0].
inline void init(){
    // f[i][0] = i^0 = 1 seeds f[i][j] = i^j, so it must be set for EVERY i in 1..n.
    // Seeding only i = 1..50 left i^j = 0 for all i > 50, which silently dropped the
    // contribution of every node deeper than 50 from the prefix sums.
    F(i,1,n) f[i][0]=1;
    // f[i][j] = i^j mod p
    F(i,1,n) F(j,1,50) f[i][j]=1ll*f[i][j-1]*i%mod;
    // Prefix sums over i; row 0 stays 0 because 0^j = 0 for j >= 1.
    F(i,1,n) F(j,1,50) f[i][j]=(f[i][j]+f[i-1][j])%mod;
    // Binary lifting: the 2^i-th ancestor is the 2^(i-1)-th ancestor of the 2^(i-1)-th one.
    F(i,1,20) F(j,1,n) fa[j][i]=fa[fa[j][i-1]][i-1];
}
/// Lowest common ancestor of x and y using the binary-lifting table fa[][].
/// Accepts the arguments in either order; the deeper node is lifted first.
inline int lca(int x,int y){
    // The lifting below needs x to be at least as deep as y. This used to be an
    // undocumented precondition; enforce it here instead of relying on the caller.
    if(dep[x]<dep[y]) std::swap(x,y);
    // Lift x to y's depth. Jumps past the root land on sentinel 0, whose depth (-1,
    // set in main) is below every real node, so such jumps are never taken.
    for(int i=20;i>=0;i--) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    if(x==y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for(int i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
signed main(){
n=io.xuint();
F(i,2,n){
int u=io.xuint(),v=io.xuint();
add(u,v),add(v,u);
}
dep[0]=-1;
dfs(1,0);
init();
m=io.xuint();
while(m--){
int u=io.xuint(),v=io.xuint(),k=io.xuint();
if(dep[u]<dep[v])swap(u,v);
int l=lca(u,v);
int ans=f[dep[u]][k]+f[dep[v]][k]-f[dep[l]][k]-f[max(0ll,dep[l]-1)][k];
ans=(ans%mod+mod)%mod;
io.wint(ans);
}
return 0;
}