rt 代码如下:
#include <cstdio>
#include <cctype>
#include <cmath>
using namespace std;
#define ll long long
inline int read(){
int res = 0, pdf = 0; char ch = getchar();
while(!isdigit(ch)) pdf = ch == '-', ch = getchar();
while(isdigit(ch)) res = (res<<3) + (res<<1) + (ch^48), ch = getchar();
return pdf ? -res : res;
}
inline void Print(ll x) {
if (x < 0) x = -x, putchar('-');
if (x < 10) putchar(x + '0');
else {
Print(x / 10);
putchar(x % 10 + '0');
}
}
inline void Swap(int &x, int &y) { int tmp = x; x = y; y = tmp; }
ll ans;
const int N = 1e5 + 100, M = log2(N) + 5;
int n, m;
int head[N], nex[N << 1], ver[N << 1], val[N << 1], tot = 1;
int vv[N], fa[N][M], depth[N], pre[N];
void Addedge(int x, int y) {
ver[++tot] = y;
nex[tot] = head[x];
head[x] = tot;
}
void ins(int x, int fat) {
fa[x][0] = fat; depth[x] = depth[fat] + 1;
for (int i = 1; i < M; ++i) {
fa[x][i] = fa[fa[x][i - 1]][i - 1];
}
for (int i = head[x]; i; i = nex[i]) {
int y = ver[i];
if (y == fat) continue;
ins(y, x);
}
}
int lca(int x, int y) {
if (depth[x] > depth[y]) Swap(x, y);
while (depth[x] < depth[y]) {
int p = log2(depth[y] - depth[x]);
y = fa[y][p];
}
if (x == y) return x;
for (int i = M - 1; i >= 0; --i) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
void dfs(int x, int fat) {
pre[x] = vv[x];
for (int i = head[x]; i; i = nex[i]) {
int y = ver[i];
if (y == fat) continue;
dfs(y, x);
pre[x] += vv[y];
}
}
int main() {
n = read(); m = read();
int xx, yy;
for (int i = 1; i < n; ++i) {
xx = read(); yy = read();
Addedge(xx, yy); Addedge(yy, xx);
}
ins(1, 0);
for (int i = 1; i <= m; ++i) {
xx = read(); yy = read();
++vv[xx]; ++vv[yy]; vv[lca(xx, yy)] -= 2;
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) {
if (!pre[i]) ans += (ll)(m);
else if (pre[i] == 1) ++ans;
}
Print(ans);
return 0;
}
救救孩子吧awa