As the title says: TLE on tests #8 and #9, and every hack test case times out as well.
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

// Capacity of every per-vertex / per-query array (indices are 1-based in main).
const int MAXN = 300000 + 10;

// Modulus applied to all power sums and answers.
long long mod = 998244353;

// Adjacency lists of the undirected tree read in main.
vector <int> G[MAXN];

// dep[v]: depth assigned to v by dfs (main starts the root at depth 0).
int dep[MAXN];

// sum[v][k]: filled by sumcal as sum[parent][k] + dep[v]^k (mod `mod`).
long long sum[MAXN][60];

// fa[v][i]: binary-lifting table built by dfs; fa[v][0] is the parent of v.
long long fa[MAXN][30];

// mark[k]: number of queries that use exponent k.
long long mark[100];

// refer[1..cnt]: the distinct exponents (within 1..50) that occur in the queries.
long long refer[100];
long long cnt = 0;

// Query i asks about the path eachi[i] .. eachj[i] with exponent eachk[i].
long long eachi[MAXN];
long long eachj[MAXN];
long long eachk[MAXN];
// Computes a^b modulo the global `mod` with iterative square-and-multiply.
//
// a: base. It is reduced into [0, mod) before any multiplication, so every
//    product is below mod * mod; this stays inside long long as long as
//    mod * mod does. Negative bases are mapped to their non-negative residue.
// b: exponent. Any b <= 0 yields 1 % mod (the empty product).
// Returns the power as a value in [0, mod).
long long quickpow(long long a,long long b) {
    long long base = a % mod;
    if (base < 0) base += mod;  // '%' keeps the sign of a; normalise the residue

    long long result = 1 % mod;
    while (b > 0) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return result;
}
void dfs(int u,int f,int d) {
fa[u][0] = f;dep[u] = d;
for (int i = 1;i<=20;i++) fa[u][i] = fa[fa[u][i-1]][i-1];
for (int i = 0;i<G[u].size();i++) {
if (G[u][i] == f) continue;
else {
dfs(G[u][i],u,d+1);
}
}
return;
}
// Returns the lowest common ancestor of u and v, using the depths in dep[]
// and the binary-lifting table fa[][] (levels 0..20).
int lca(int u,int v) {
    // Make u the vertex with the larger (or equal) depth.
    if (dep[v] > dep[u]) swap(u,v);

    // Lift u, largest jump first, while it does not rise above v's depth.
    int level = 20;
    while (level >= 0) {
        if (dep[u] - (1 << level) >= dep[v]) u = fa[u][level];
        --level;
    }
    if (u == v) return u;

    // Lift both vertices together whenever their 2^level ancestors differ;
    // afterwards they sit directly below the common ancestor.
    for (level = 20; level >= 0; --level) {
        if (fa[u][level] == fa[v][level]) continue;
        u = fa[u][level];
        v = fa[v][level];
    }
    return fa[u][0];
}
void sumcal(int u,int f,int k) {
long long now = quickpow(dep[u],k);
sum[u][k] = (sum[f][k]+now)%mod;
for (int i = 0;i<G[u].size();i++) {
if (G[u][i] != f) {
sumcal(G[u][i],u,k);
}
}
return;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n;cin >> n;
for (int i = 1;i<=n-1;i++) {
int a,b;cin >> a >> b;
G[a].push_back(b);
G[b].push_back(a);
}
dfs(1,0,0);
int m;cin >> m;
for (int i = 1;i<=m;i++) {
int a,b,k;cin >> a >> b >> k;
eachi[i] = a;eachj[i] = b;eachk[i] = k;
mark[k]++;
}
for (int i = 1;i<=50;i++) {
if (mark[i]) refer[++cnt] = i;
}
for (int i = 1;i<=cnt;i++) {
sumcal(1,0,refer[i]);
}
for (int i = 1;i<=m;i++) {
int p = lca(eachi[i],eachj[i]);
long long sumi = sum[eachi[i]][eachk[i]];long long sumj = sum[eachj[i]][eachk[i]];
long long sumfap = sum[fa[p][0]][eachk[i]];long long valp = sum[p][eachk[i]]-sum[fa[p][0]][eachk[i]];
long long f = (sumi%mod+sumj%mod)%mod;long long s = (2*sumfap%mod+valp%mod)%mod;
long long ans = (f%mod-s%mod+mod)%mod;
cout << ans << "\n";
}
return 0;
}