85分求助~~
查看原帖
85分求助~~
115888
Jozky楼主2021/8/18 22:00

85分，第14到16个测试点错了，自己看了半天不知道哪里有问题，求助。

//#pragma optimize("Ofast")
#include <bits/stdc++.h>
// Compile-time capacities and sentinels.
constexpr int MAXN = 4000005;   // capacity of every vertex-indexed array
constexpr int MAXK = 10000007;  // capacity of t[], which is indexed by path weight 0..K
constexpr int inf = 20000007;   // "no path recorded" value (also passed to memset as fill byte)
using namespace std;
using ll = long long;

int N, K;   // number of vertices, required total path weight
// One directed half of an undirected weighted edge, stored in adj[].
struct edge
{
    int v;   // vertex at the other end
    ll w;    // edge weight
    edge(int to = 0, ll len = 0) : v{to}, w{len} {}
};

vector<edge> adj[MAXN];   // adjacency lists (main fills them with 1-based vertex ids)

// Per-vertex scratch state shared by the divide-and-conquer passes.
int vis[MAXN];   // set to 1 once the vertex has been chosen as a centroid
int sz[MAXN];    // subtree size in dfs_rt; size of a centroid's child piece in work
int dep[MAXN];   // number of edges from the current centroid
ll d[MAXN];      // total weight from the current centroid

int rt;          // centroid reported by the most recent dfs_rt call
int cnt;         // vertex counter incremented by dfs1

// t[x] = fewest edges over recorded centroid paths of weight exactly x.
int t[MAXK];
int ans = inf;   // best (smallest) edge count found so far
void dfs_rt(int u, int fa, int tot)
{
    //++cnt;
    sz[u]= 1;
    int v, w, n= 0;
    for (int k= 0; k < adj[u].size(); k++) {
        v= adj[u][k].v;
        if (v == fa || vis[v])
            continue;
        dfs_rt(v, u, tot);
        sz[u]+= sz[v];
        n= max(n, sz[v]);
    }
    n= max(n, tot - sz[u]);
    if (2 * n <= tot)
        rt= u;
}

void dfs1(int u, int fa) //更新答案
{
    ++cnt;
    if (K >= d[u])
        ans= min(ans, dep[u] + t[K - d[u]]);
    int v, w;
    for (int k= 0; k < adj[u].size(); k++) {
        v= adj[u][k].v;
        w= adj[u][k].w;
        if (v == fa || vis[v])
            continue;

        d[v]= d[u] + w;
        dep[v]= dep[u] + 1;
        dfs1(v, u);
    }
}

void dfs2(int u, int fa, int flag) //更新t
{
    if (K >= d[u]) {
        if (flag == 1)
            t[d[u]]= min(t[d[u]], dep[u]);
        else
            t[d[u]]= inf;
    }

    int v;
    for (int k= 0; k < adj[u].size(); k++) {
        v= adj[u][k].v;
        if (v == fa || vis[v])
            continue;
        dfs2(v, u, flag);
    }
}
void work(int u, int fa, int tot)
{
    dfs_rt(u, fa, tot);
    u= rt;
    vis[u]= 1;
    d[u]= 0;

    t[0]= 1;
    int v, w;
    for (int k= 0; k < adj[u].size(); k++) {
        v= adj[u][k].v;
        w= adj[u][k].w;
        if (vis[v])
            continue;
        cnt= 0;
        d[v]= w; //路径
        dep[v]= 1;
        dfs1(v, u);
        sz[v]= cnt;
        dfs2(v, u, +1);
    }

    dfs2(u, 0, -1);

    for (int k= 0; k < adj[u].size(); k++) {
        v= adj[u][k].v;
        if (vis[v])
            continue;
        work(v, u, sz[v]);
    }
}

// Read a weighted tree of N vertices and the target weight K. Print the
// fewest edges on a path of total weight exactly K, or -1 if none exists.
int main()
{
    scanf("%d%d", &N, &K);

    // memset fills bytes: each t[] entry becomes 0x07070707, far above any
    // possible edge count. One fill is enough; work() seeds t[0] itself.
    memset(t, inf, sizeof t);

    for (int i = 1; i < N; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        // Shift ids up by one (input presumably 0-based) so vertices are
        // 1..N and 0 can act as the "no parent" value passed below.
        ++u;
        ++v;
        adj[u].emplace_back(v, w);
        adj[v].emplace_back(u, w);
    }

    work(1, 0, N);

    // A simple path has at most N - 1 edges, so anything >= N means "not found".
    printf("%d\n", ans >= N ? -1 : ans);
    return 0;
}
2021/8/18 22:00
加载中...