// 求求了,救救孩子吧 (author's note, roughly: "Please, someone help me out.")
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#define ll long long
#define inf 0x3f
using namespace std;
// Capacity for node ids (nodes are addressed 1..n).
const int maxn = 100005;

// n: number of nodes read by main(); k: selects the k1 (k == 1) or k2 solver.
int n, k;
// Last edge id handed out by add().
int tot;
// head[x]: id of the first edge leaving x (0 terminates the list).
int head[maxn];
// Per-edge arrays, sized 2 * maxn because main() stores every input edge in
// both directions: w = weight, v = target node, nxt = next edge of the same source.
int w[maxn << 1];
int v[maxn << 1];
int nxt[maxn << 1];
// Per-node scratch: distance from the last bfs() source, incoming edge id
// on that search, dp value, and visited flag for the dp.
int dis[maxn];
int pre[maxn];
int f[maxn];
int vis[maxn];
// Work queue shared by bfs(); empty between calls.
std::queue<int> q;
// Prepends a directed edge x -> y with weight 1 to x's adjacency list.
// Edge ids are handed out consecutively from ++tot; main() starts tot at 1
// and adds each road as (a,b) then (b,a), so an edge and its reverse have
// ids e and e ^ 1 (k2::change relies on this).
void add(int x, int y) {
    const int e = ++tot;
    nxt[e] = head[x];
    head[x] = e;
    v[e] = y;
    w[e] = 1;
}
// Queue-based shortest-distance search from s over the current weights w[].
// Fills dis[] (nodes never reached keep the memset pattern 0x3f3f3f3f) and
// pre[] (id of the edge used to reach each node, 0 for s). Returns the node
// in 1..n with the largest dis, the smallest id winning ties. Returns s when
// n < 1; the original read an uninitialized variable in that case.
// Assumes all weights are non-negative; it is only called before k2::change
// sets some weights to -1.
int bfs(int s) {
    memset(dis, inf, sizeof(dis));
    dis[s] = 0;
    pre[s] = 0;
    q.push(s);
    while (!q.empty()) {
        const int x = q.front();
        q.pop();
        for (int i = head[x]; i; i = nxt[i]) {
            const int y = v[i];
            // Relax only on strict improvement, so a node is never re-queued
            // for an equal distance (identical results on a tree).
            if (dis[y] > dis[x] + w[i]) {
                dis[y] = dis[x] + w[i];
                pre[y] = i;
                q.push(y);
            }
        }
    }
    int point = s;  // defined result even if the scan below never runs
    int best = -1;  // every dis[] value is >= 0 here
    for (int i = 1; i <= n; i++) {
        if (dis[i] > best) {
            best = dis[i];
            point = i;
        }
    }
    return point;
}
namespace k1 {
// Length, in edges, of a longest path. Runs bfs() twice: first from node 1,
// then from the farthest node found (assuming the input graph is a tree,
// which is what main() builds from n - 1 edges).
int diameter() {
    const int endpoint = bfs(bfs(1));
    return dis[endpoint];
}
// Solver for k == 1. Prints 2*(n-1) - D + 1, where D = diameter().
// Presumably: walking every edge twice costs 2*(n-1); one extra road of
// length 1 replaces the D-edge return along the longest path.
void solve() {
    const int longest = diameter();
    const int answer = 2 * (n - 1) - longest + 1;
    cout << answer << endl;
}
}
namespace k2 {
// Far endpoint of the longest path found by diameter(); change() walks
// pre[] back from it. Filled inside solve(). The original computed it in a
// namespace-scope initializer, which ran before main() had read the input.
int point = 0;
// Best path weight found by dp() on the reweighted tree. It starts at 0 so
// the empty path (an extra road that is a self-loop) is always available;
// starting below 0 gave wrong answers when every edge lies on the first
// path (e.g. a chain).
int ans = 0;

// Runs bfs() twice to find a longest path (assuming the graph is a tree),
// stores its far endpoint in point, and returns its length in edges.
int diameter() {
    point = bfs(bfs(1));
    return dis[point];
}

// Sets the weight of every edge on that path, in both directions, to -1.
// It follows pre[] from point back to the bfs source (pre == 0 there) and
// uses the pairing reverse(e) == e ^ 1 set up by main(). Presumably this is
// so a second path that reuses the first shortcut's edges gains nothing.
// Leaves point at the source node.
void change() {
    for (; pre[point]; point = v[pre[point] ^ 1]) {
        w[pre[point]] = w[pre[point] ^ 1] = -1;
    }
}

// Tree DP for the maximum-weight path, rooted at u.
// f[x] becomes the best weight of a downward path starting at x. It stays
// >= 0 because f starts zeroed and is only raised via max. ans collects the
// best path that bends at any node.
// NOTE(review): recursion depth equals the tree height (up to n on a chain);
// deep inputs may need a bigger stack or an iterative version.
void dp(int u) {
    vis[u] = 1;
    for (int i = head[u]; i; i = nxt[i]) {
        const int y = v[i];
        if (vis[y]) continue;
        dp(y);
        // Best path through u: its best branch so far plus the branch via y.
        ans = max(ans, f[u] + f[y] + w[i]);
        f[u] = max(f[u], f[y] + w[i]);
    }
}

// Solver for k == 2. Prints 2n - D - ans, which is
// 2*(n-1) - (D - 1) - (ans - 1): D is the longest path, and ans is the best
// path after that path's edges are set to -1.
void solve() {
    const int D = diameter();
    change();
    ans = 0;
    dp(1);
    cout << 2 * n - D - ans << endl;
}
}
// Reads n and k, then n - 1 undirected edges "a b", and dispatches to the
// k == 1 solver or to the k2 solver for any other k.
// The trailing system("pause") was removed. It is Windows-only, needs
// <cstdlib> (not included), and on other systems prints an error or blocks.
int main() {
    // Only iostreams are used (no C stdio), so unsyncing is safe and speeds
    // up reading ~2n integers.
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> k;
    // Edge ids start at 2, so each (a,b)/(b,a) pair is {e, e ^ 1}.
    tot = 1;
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    if (k == 1) {
        k1::solve();
    } else {
        k2::solve();
    }
    return 0;
}