// The code is as follows:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
// Capacity of every per-vertex array: 500000 vertices plus a little slack.
constexpr int Maxn = 500000 + 100;

// One directed arc of the adjacency structure, stored as a linked-list node.
struct Edge
{
    int v;    // endpoint the arc leads to
    int nxt;  // index of the next arc leaving the same vertex, 0 terminates the list
};

// Arc pool; twice Maxn because every undirected edge is inserted in both directions.
Edge edge[Maxn * 2];

// head[u] is the index in `edge` of the most recently added arc leaving u (0 = none).
int head[Maxn];

// Number of arcs handed out so far; index 0 is never used so that 0 can mean "no arc".
int cnt = 0;
// Per-vertex bookkeeping filled in by Tree_Build and Tree_Decomposition.
struct Node
{
    int fa;          // parent vertex, assigned by Tree_Build to every vertex it reaches through an edge
    int depth = -1;  // distance from the root; -1 doubles as "not visited yet" in Tree_Build
    int hson;        // child with the largest subtree, -1 when the vertex has no child
    int top = 0;     // first vertex of the chain containing this one; 0 means "not assigned yet"
    int size;        // number of vertices in the subtree rooted here
    int intime;      // written by Tree_Build and later overwritten by Tree_Decomposition
    int outtime;     // timestamp taken when Tree_Build leaves the vertex
};

// Vertex table, indexed by vertex id.
Node node[Maxn];

// Counter for the visiting order assigned in Tree_Decomposition.
int dfn;

// Counter for the enter/leave timestamps assigned in Tree_Build.
int timestamp;

// Rank[k] is the vertex that received order number k in Tree_Decomposition.
int Rank[Maxn];
// Inserts the directed arc u -> v at the front of u's adjacency list.
void Add_edge(int u, int v)
{
    const int id = ++cnt;      // claim the next free slot of the arc pool
    edge[id].v = v;
    edge[id].nxt = head[u];    // the previous first arc becomes the successor
    head[u] = id;
}
void Tree_Build(int u, int dep)
{
node[u].depth = dep;
node[u].hson = -1;
node[u].size = 1;
node[u].intime = ++timestamp;
for(int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].v;
if(node[v].depth != -1) continue;
Tree_Build(v, dep + 1);
node[v].fa = u;
node[u].size += node[v].size;
if(node[u].hson == -1 || node[v].size > node[node[u].hson].size)
{
node[u].hson = v;
}
}
node[u].outtime = ++timestamp;
}
// Second pass: assigns every vertex the head of its chain and a visiting order.
// The heavy child continues the current chain; every other child starts its own.
//   u   - vertex being visited
//   top - head of the chain u belongs to
void Tree_Decomposition(int u, int top)
{
    const int order = ++dfn;
    node[u].top = top;
    node[u].intime = order;   // note: replaces the timestamp stored by Tree_Build
    Rank[order] = u;

    const int heavy = node[u].hson;
    if (heavy != -1)
    {
        // Visit the heavy child first so the chain receives consecutive order numbers.
        Tree_Decomposition(heavy, top);

        for (int e = head[u]; e != 0; e = edge[e].nxt)
        {
            const int child = edge[e].v;
            // A non-zero top means the vertex is already placed (the parent or the heavy child).
            if (child != heavy && node[child].top == 0)
            {
                Tree_Decomposition(child, child);
            }
        }
    }
}
// Lowest common ancestor of x and y, found by climbing chain by chain.
inline int Lca(int x, int y)
{
    for (;;)
    {
        const int topX = node[x].top;
        const int topY = node[y].top;
        if (topX == topY)
        {
            break;
        }
        // Lift whichever endpoint sits on the chain whose head is deeper;
        // on equal head depths y is the one that moves.
        if (node[topX].depth > node[topY].depth)
        {
            x = node[topX].fa;
        }
        else
        {
            y = node[topY].fa;
        }
    }
    // Both endpoints share a chain now, so the shallower one is the answer.
    if (node[x].depth < node[y].depth)
    {
        return x;
    }
    return y;
}
// Number of edges on the path between x and y.
inline int distance(int x, int y)
{
    const int meet = Lca(x, y);
    const int up = node[x].depth - node[meet].depth;    // edges from x up to the meeting vertex
    const int down = node[y].depth - node[meet].depth;  // edges from the meeting vertex down to y
    return up + down;
}
// Number of vertices and number of queries, read from standard input.
int n, q;

// Returns the neighbour of c that is the first step on the path from c to target.
// Requires target != c, and dist must equal distance(c, target).
// Walks chain by chain instead of testing every neighbour of c.
static int Neighbour_Toward(int c, int target, int dist)
{
    // target lies below c exactly when the path length equals the depth gap;
    // otherwise the path leaves c through its parent.
    if (node[target].depth - node[c].depth != dist)
    {
        return node[c].fa;
    }

    int cur = target;
    while (node[cur].top != node[c].top)
    {
        const int chainTop = node[cur].top;
        // A chain head hanging directly off c is the child we are looking for.
        if (node[chainTop].fa == c)
        {
            return chainTop;
        }
        cur = node[chainTop].fa;
    }
    // cur is on c's own chain and strictly below it, so the step goes to the heavy child.
    return node[c].hson;
}

// Number of vertices on neighbour's side once the edge between c and neighbour is removed.
static int Branch_Size(int c, int neighbour)
{
    if (node[c].fa == neighbour)
    {
        return n - node[c].size;   // everything outside c's subtree
    }
    return node[neighbour].size;   // neighbour is a child of c
}

// Answers one query (a, b, c): 0 when c is not on the path between a and b,
// otherwise n minus the sizes of the branches of c that lead towards a and towards b.
static int Answer_Query(int a, int b, int c)
{
    if (a == c && b == c)
    {
        return n;
    }

    const int d_ac = distance(a, c);
    const int d_cb = distance(c, b);
    if (d_ac + d_cb != distance(a, b))
    {
        return 0;
    }

    // c is on the path here, so whenever both branches exist they are distinct;
    // no extra check for coinciding branches is required.
    int answer = n;
    if (a != c)
    {
        answer -= Branch_Size(c, Neighbour_Toward(c, a, d_ac));
    }
    if (b != c)
    {
        answer -= Branch_Size(c, Neighbour_Toward(c, b, d_cb));
    }
    return answer;
}

// Reads the tree and the queries from standard input and prints one answer per line.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    if (!(cin >> n >> q))
    {
        return 0;   // nothing to process without the header line
    }

    for (int i = 1; i < n; ++i)
    {
        int u, v;
        cin >> u >> v;
        Add_edge(u, v);
        Add_edge(v, u);
    }

    const int root = 1;
    Tree_Build(root, 0);
    Tree_Decomposition(root, root);

    while (q--)
    {
        int a, b, c;
        cin >> a >> b >> c;
        cout << Answer_Query(a, b, c) << '\n';
    }
    return 0;
}