全WA，在线等，悬114514关
查看原帖
全WA，在线等，悬114514关
778382
wch666 楼主 2025/2/7 09:27
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 5e5 + 5;       // capacity of every per-node array
const int mod = 998244353;      // modulus for all answers
int n, m, s;                    // n = node count, m = query count (s is unused)
// dep[u]: depth of u (the root ends up at depth 0).
// f[u][j]: 2^j-th ancestor of u. Only j = 0..21 is ever read or written,
// so 22 columns suffice (the old 50 wasted ~112 MB with int = long long).
int dep[maxn], f[maxn][22];
// sum[u][k], k = 1..50: sum of dep[y]^k over y on the root->u path, mod `mod`.
// Index 50 is both written and queried, so each row needs 51 entries;
// the previous [50] made every k = 50 access out of bounds.
int sum[maxn][51];
std::vector<int> vec[maxn];     // adjacency lists of the tree
// Appends v to u's adjacency list, i.e. stores the directed edge u -> v.
// main() calls this once per direction for each undirected tree edge.
void add(int u, int v)
{
	auto &neighbours = vec[u];
	neighbours.push_back(v);
}
// Binary exponentiation: returns base^power modulo `mod` (power >= 0).
// The running result and the squared base are reduced modulo `mod` after
// every multiplication, so no intermediate exceeds mod^2 < 2^63. The old
// version never reduced the left operand and overflowed long long (UB),
// which produced wrong answers.
int fast(int base,int power)
{
	base %= mod;
	if (base < 0)
		base += mod;  // normalise negative inputs into [0, mod)
	int result = 1 % mod;
	while (power > 0)
	{
		if (power & 1)
			result = 1LL * result * base % mod;
		base = 1LL * base * base % mod;
		power >>= 1;
	}
	return result;
}
// Traverses the subtree hanging from u (whose parent is fa; the root is
// passed as its own parent) and, for every node x visited, fills:
//   dep[x]     = dep[parent] + 1
//   f[x][0]    = parent
//   sum[x][k]  = (sum[parent][k] + dep[x]^k) mod `mod`, for k = 1..50
// so sum[x][k] is the root->x prefix sum of depth powers.
// Writes sum[x][50], so sum's rows must hold at least 51 entries.
// An explicit stack replaces recursion: a chain of ~1e5+ nodes would
// otherwise risk overflowing the call stack.
void dfs(int u, int fa)
{
	std::vector<std::pair<int, int>> pending;  // (node, parent) still to visit
	pending.push_back(std::make_pair(u, fa));
	while (!pending.empty())
	{
		int x = pending.back().first;
		int p = pending.back().second;
		pending.pop_back();

		dep[x] = dep[p] + 1;
		f[x][0] = p;

		// Build dep[x]^k incrementally instead of calling a pow routine per k:
		// O(50) per node, and each product stays below mod^2.
		int d = dep[x] % mod;
		int pw = 1;
		for (int k = 1; k <= 50; k++)
		{
			pw = 1LL * pw * d % mod;
			sum[x][k] = (sum[p][k] + pw) % mod;
		}

		for (int v : vec[x])
			if (v != p)
				pending.push_back(std::make_pair(v, x));
	}
}
int lca(int a, int b)
{
	if(a == b)
		return a;
	if(f[a][0] == f[b][0])
		return f[a][0];
	if(dep[a] < dep[b])
		swap(a, b);
	while(dep[a] > dep[b])
		a = f[a][__lg(dep[a] - dep[b])];
	if(a == b)
		return a;
	if(f[a][0] == f[b][0])
		return f[a][0];
	for(int j = 21; j >= 0; j--)
	{
		if(f[a][j] != f[b][j])
		{
			a = f[a][j];
			b = f[b][j];
		}
	}
	return f[a][0];
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
	cin>>n;
	for(int i = 1; i < n; i++)
	{
		int u, v;
		cin>>u>>v;
		add(u, v);
		add(v, u);
	}
	dep[1] = -1;
	dfs(1, 1);
	for(int j = 1; j <= 21; j++)
		for(int i = 1; i <= n; i++)
			f[i][j] = f[f[i][j - 1]][j - 1];
	cin>>m;
	while(m--)
	{
		int a, b, k;
		cin>>a>>b>>k;
		int p = lca(a, b);
		int sum1 = (sum[a][k] + sum[b][k] + mod) % mod;
		int sum2 = (sum[p][k] + sum[f[p][0]][k] + mod) % mod;
		cout<<(sum1 - sum2 + mod) % mod<<endl; 
	}
	return 0;
}
2025/2/7 09:27
加载中...