打了个暴力,LCA+快速幂,不知为什么WA
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int n,m,idx,t,res,ans;
int f[300010][30],head[300010],dep[300010];
struct node{
int nxt,to;
}edge[1000010];
void add(int u,int v)
{
edge[++idx].nxt=head[u];
edge[idx].to=v;
head[u]=idx;
}
int quickpower(int base,int power,int k)//快速幂赛高
{
res=1;
while(power>0)
{
if(power&1)
{
res*=base;
res%=k;
}
base*=base;
base%=mod;
power>>=1;
}
return res%k;
}
void dfs(int now,int fath)//预处理
{
dep[now]=dep[fath]+1;
f[now][0]=fath;
for(int i=1;i<=t;i++)
{
f[now][i]=f[f[now][i-1]][i-1];
}
for(int i=head[now];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==fath)
continue;
dfs(v,now);
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
for(int i=t;i>=0;i--)
{
if(dep[f[x][i]]>=dep[y])
{
x=f[x][i];
if(x==y)
return x;
}
}
for(int i=t;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
signed main()
{
cin>>n;
t=(int)(log(n)/log(2))+1;
for(int i=1;i<=n-1;i++)
{
int u,v;
scanf("%lld%lld",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
cin>>m;
while(m--)
{
ans=0;
int x,y,k;
scanf("%lld%lld%lld",&x,&y,&k);
int lca=LCA(x,y);//求出LCA,然后求和
for(int i=dep[x]-1;i>=dep[lca]-1;i--)//x=1,lca=1,y=4;
{
ans+=quickpower(i,k,mod);
ans%=mod;
}
for(int i=dep[y]-1;i>dep[lca]-1;i--)//避免多加一次lca的值
{
ans+=quickpower(i,k,mod);
ans%=mod;
}
printf("%lld\n",ans);
}
return 0;
}