倍增LCA 80pts 求助
查看原帖
倍增LCA 80pts 求助
1101954
Lucas002楼主2025/2/3 22:11

如题(Rt):测试点 #8、#9 超时(TLE),Hack 数据也全部超时。

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
// Capacity of every per-node and per-query array (indices are used from 1).
const int MAXN = 3e5+10;

// Modulus for all arithmetic. Declared const (nothing in the file assigns to
// it) so the compiler can treat it as a compile-time constant when optimizing
// the many `% mod` operations instead of emitting a runtime 64-bit division.
const long long mod = 998244353;

// Adjacency lists of the undirected tree.
std::vector<int> G[MAXN];

// dep[x]: depth of node x as assigned by dfs().
int dep[MAXN];

// sum[x][k]: value stored by sumcal() for node x and exponent k.
long long sum[MAXN][60];

// fa[x][i]: the 2^i-th ancestor of x (0 when it does not exist).
// Entries are node ids, so int is wide enough, and dfs()/lca() only index
// levels 0..20, so 21 columns suffice (2^20 > MAXN).
int fa[MAXN][21];

// mark[k]: number of queries that use exponent k.
long long mark[100];
// refer[1..cnt]: the distinct exponents that occur in the queries.
long long refer[100];
long long cnt = 0;

// Stored queries: endpoints (eachi, eachj) and exponent (eachk).
long long eachi[MAXN];
long long eachj[MAXN];
long long eachk[MAXN];
// Modular exponentiation: returns a^b reduced modulo the file-level `mod`.
//   a  base; intended to be non-negative. It is reduced modulo `mod` up
//      front so that no intermediate product can overflow long long
//      (the original multiplied by the unreduced `a`).
//   b  exponent; any b <= 0 yields 1 % mod.
// Iterative square-and-multiply: O(log b) multiplications, no recursion.
// Requires (mod - 1)^2 to fit in long long.
long long quickpow(long long a,long long b) {
	long long result = 1 % mod;
	long long base = a % mod;
	while (b > 0) {
		// Fold in the current power of the base when this bit of b is set.
		if (b & 1) result = result * base % mod;
		base = base * base % mod;
		b >>= 1;
	}
	return result;
}
// Fills dep[] and the binary-lifting table fa[][] for the tree in G.
//   u  node to start from (the root of the part being processed)
//   f  parent of u; main() passes 0 for the overall root
//   d  depth assigned to u
// For every node x reached from u without going back through f:
//   dep[x]   = d + (distance from u to x)
//   fa[x][0] = parent of x, fa[x][i] = fa[fa[x][i-1]][i-1] for i = 1..20
// Node 0 is never visited, so its zero-initialized row makes every
// "ancestor above the root" resolve to 0 (assumes real node ids start at 1).
//
// The traversal uses an explicit stack instead of recursion: the recursive
// form needs one call frame per tree level and can exhaust the call stack on
// a path-shaped tree with hundreds of thousands of nodes.
void dfs(int u,int f,int d) {
	fa[u][0] = f;
	dep[u] = d;

	std::vector<int> pending;
	pending.push_back(u);
	while (!pending.empty()) {
		const int node = pending.back();
		pending.pop_back();

		// The parent link and depth were written when the node was pushed, and
		// the parent's own row is already complete, so the row can be built now.
		const int parent = static_cast<int>(fa[node][0]);
		for (int level = 1; level <= 20; ++level) {
			fa[node][level] = fa[fa[node][level-1]][level-1];
		}

		const std::vector<int>& adj = G[node];
		for (std::size_t idx = 0; idx < adj.size(); ++idx) {
			const int child = adj[idx];
			if (child == parent) continue;  // do not walk back up the tree
			fa[child][0] = node;
			dep[child] = dep[node] + 1;
			pending.push_back(child);
		}
	}
}
// Lowest common ancestor of u and v by binary lifting.
// Reads the dep[] and fa[][] tables (levels 0..20) filled in by dfs().
int lca(int u,int v) {
	// Arrange for u to be the endpoint that is at least as deep as v.
	if (dep[v] > dep[u]) {
		const int tmp = u;
		u = v;
		v = tmp;
	}

	// Phase 1: raise u until it sits at the same depth as v, trying the
	// largest jumps first; a jump is taken only if it does not overshoot.
	int level = 20;
	while (level >= 0) {
		if (dep[u] - (1 << level) >= dep[v]) {
			u = fa[u][level];
		}
		--level;
	}
	if (u == v) {
		return v;  // v was an ancestor of (or equal to) the original u
	}

	// Phase 2: lift both nodes together for as long as their ancestors
	// differ; afterwards they are distinct children of the answer.
	for (level = 20; level >= 0; --level) {
		if (fa[u][level] == fa[v][level]) continue;
		u = fa[u][level];
		v = fa[v][level];
	}
	return fa[u][0];
}
// Fills sum[x][k] for u and every node below it:
//   sum[x][k] = (sum[parent of x][k] + dep[x]^k) % mod
//   u  node to start from
//   f  parent of u; its sum[f][k] entry is used as the starting value
//      (main() passes 0, whose entry is never written and stays 0)
//   k  exponent applied to each node's depth
// Requires dep[] to be filled in already (done by dfs()).
//
// Nodes are processed parent-before-child with an explicit stack rather than
// recursion: this routine runs once per distinct exponent, and the recursive
// form needs one call frame per tree level, which can exhaust the call stack
// on a path-shaped tree.
void sumcal(int u,int f,int k) {
	// Parallel stacks: nodes[i] is waiting to be processed, parents[i] is the
	// node it was reached from.
	std::vector<int> nodes;
	std::vector<int> parents;
	nodes.push_back(u);
	parents.push_back(f);

	while (!nodes.empty()) {
		const int node = nodes.back();
		const int parent = parents.back();
		nodes.pop_back();
		parents.pop_back();

		// The parent's entry is final by the time a child is popped.
		sum[node][k] = (sum[parent][k] + quickpow(dep[node], k)) % mod;

		const std::vector<int>& adj = G[node];
		for (std::size_t idx = 0; idx < adj.size(); ++idx) {
			if (adj[idx] != parent) {
				nodes.push_back(adj[idx]);
				parents.push_back(node);
			}
		}
	}
}
int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	int n;cin >> n;
	for (int i = 1;i<=n-1;i++) {
		int a,b;cin >> a >> b;
		G[a].push_back(b);
		G[b].push_back(a);
	}
	dfs(1,0,0);
	int m;cin >> m;
	for (int i = 1;i<=m;i++) {
		int a,b,k;cin >> a >> b >> k;
		eachi[i] = a;eachj[i] = b;eachk[i] = k; 
		mark[k]++;
	}
	for (int i = 1;i<=50;i++) {
		if (mark[i]) refer[++cnt] = i;
	}
	for (int i = 1;i<=cnt;i++) {
		sumcal(1,0,refer[i]);
	}
	for (int i = 1;i<=m;i++) {
		int p = lca(eachi[i],eachj[i]);
		long long sumi = sum[eachi[i]][eachk[i]];long long sumj = sum[eachj[i]][eachk[i]];
		long long sumfap = sum[fa[p][0]][eachk[i]];long long valp = sum[p][eachk[i]]-sum[fa[p][0]][eachk[i]];
		long long f = (sumi%mod+sumj%mod)%mod;long long s = (2*sumfap%mod+valp%mod)%mod;
		long long ans = (f%mod-s%mod+mod)%mod;
		cout << ans << "\n";
	}
	return 0;
}
2025/2/3 22:11
加载中...