树剖+线段树 WA 40pts 求调
查看原帖
树剖+线段树 WA 40pts 求调
681591
sintle（楼主） 2025/6/24 13:26
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity shared by the node arrays and the doubled (undirected) edge list.
constexpr ll N = 600005;

ll n , m;      // number of tree nodes / number of query paths
ll num;        // running DFS timestamp advanced by dfs2
ll cnt = 0;    // number of directed edges stored so far

// Chained adjacency list: h[u] is the first edge of u, nxt[e] the next one,
// to[e] / val[e] the target node and weight of directed edge e.
ll to[N] , val[N] , h[N] , nxt[N];

// Tree data filled by dfs1 / dfs2:
// fa = parent, dep = depth, len = weighted distance from the root,
// siz = subtree size, son = heavy child, top = head of the heavy chain,
// dfn = DFS order index, pos = inverse of dfn.
ll fa[N] , dep[N] , len[N] , siz[N] , son[N] , top[N] , dfn[N] , pos[N];

// tag[v] = weight of the edge (fa[v], v), written by dfs1.
ll tag[N];

ll ans = 0;              // not referenced by the visible code
ll max1 = 0 , max2 = 0;  // max1: longest query path length seen; max2 is unused
ll u1 , v1;              // endpoints of the longest query path

// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in that skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit.
//
// Fixes vs. the original:
//  - the accumulator is long long (it was int, so values above INT_MAX overflowed),
//  - ch is an int and EOF is checked, so a truncated input no longer spins forever.
long long read()
{
    int f = 1 , ch = getchar();
    long long res = 0;
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') res = res * 10 + (ch - '0') , ch = getchar();
    return f * res;
}

// Segment-tree node covering dfn positions [l, r].
//   maxn - maximum value stored in this range (recomputed from children by pushup)
//   tag  - pending chmax value not yet pushed to the children (0 means none)
struct seg
{
    ll l , r;
    ll maxn;
    ll tag;
};
seg s[4 * N];

// Closed dfn interval [l, r] holding part of a query path's edges.
struct node
{
    ll l;
    ll r;
};
// Scratch buffer of intervals for the path currently being processed;
// tot is the number of entries in use.
node q[2 * N];
ll tot = 0;

// Inserts the undirected edge u -- v with weight t as two directed entries
// in the chained adjacency list (h / nxt / to / val).
void add(ll u , ll v , ll t)
{
    auto link = [t](ll from , ll target)
    {
        ++cnt;
        to[cnt] = target;
        val[cnt] = t;
        nxt[cnt] = h[from];
        h[from] = cnt;
    };
    link(u , v);
    link(v , u);
}

// First DFS from u with parent f: fills fa, dep, siz, len (weighted distance
// from the root), son (child with the largest subtree) and tag (weight of
// the edge from each child to u).
void dfs1(ll u , ll f)
{
    fa[u] = f;
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    for(int e = h[u] ; e != 0 ; e = nxt[e])
    {
        const int child = to[e];
        if(child == f) continue;
        const int w = val[e];
        tag[child] = w;
        len[child] = len[u] + w;
        dfs1(child , u);
        siz[u] += siz[child];
        if(siz[son[u]] < siz[child]) son[u] = child;
    }
}

// Second DFS: numbers nodes (dfn / pos) visiting the heavy child first, so
// every heavy chain occupies a contiguous dfn range. It sets top[] for the
// children of u; top[u] itself must already be set by the caller.
void dfs2(ll u)
{
    dfn[u] = ++num;
    pos[num] = u;
    const ll heavy = son[u];
    if(heavy == 0) return;
    top[heavy] = top[u];
    dfs2(heavy);
    for(int e = h[u] ; e != 0 ; e = nxt[e])
    {
        const int child = to[e];
        if(child == fa[u] || child == heavy) continue;
        top[child] = child;
        dfs2(child);
    }
}

// Pushes a pending chmax tag of node pos down to both children, then clears it.
void pushdown(ll pos)
{
    const ll lazy = s[pos].tag;
    if(lazy == 0) return;
    for(ll child = pos << 1 ; child <= (pos << 1 | 1) ; ++child)
    {
        s[child].maxn = max(s[child].maxn , lazy);
        s[child].tag = max(s[child].tag , lazy);
    }
    s[pos].tag = 0;
}

// Point query: value stored at dfn position id. Descends from node u and
// pushes pending tags along the way.
ll query(ll u , ll id)
{
    ll cur = u;
    while(s[cur].l != s[cur].r)
    {
        pushdown(cur);
        const ll mid = (s[cur].l + s[cur].r) >> 1;
        cur = (id <= mid) ? (cur << 1) : (cur << 1 | 1);
    }
    return s[cur].maxn;
}

// Lowest common ancestor via heavy-chain jumps: repeatedly lift the endpoint
// whose chain head is deeper until both lie on the same chain; the shallower
// one is then the LCA.
ll LCA(ll u , ll v)
{
    for(; top[u] != top[v] ; u = fa[top[u]])
        if(dep[top[u]] < dep[top[v]]) swap(u , v);
    return dep[u] < dep[v] ? u : v;
}

// Builds the segment tree for positions [l, r] rooted at pos, with every
// value and tag set to 0.
void build(ll pos , ll l , ll r)
{
    seg &cur = s[pos];
    cur.l = l;
    cur.r = r;
    cur.maxn = cur.tag = 0;
    if(l < r)
    {
        const ll mid = (l + r) >> 1;
        build(2 * pos , l , mid);
        build(2 * pos + 1 , mid + 1 , r);
    }
}

// Recomputes node pos's maximum from its two children.
void pushup(ll pos)
{
    const seg &lc = s[pos << 1];
    const seg &rc = s[pos << 1 | 1];
    s[pos].maxn = (lc.maxn < rc.maxn) ? rc.maxn : lc.maxn;
}

// Range chmax: raise every stored value in dfn positions [l, r] to at least num.
// An empty range (l > r) is a no-op. Callers such as update() pass one whenever
// two path intervals are adjacent. Returning early avoids a useless O(log n)
// descent with pushdown/pushup.
void supdate(ll pos , ll l , ll r , ll num)
{
    if(l > r) return;
    if(s[pos].l > r || s[pos].r < l) return;
    if(s[pos].l >= l && s[pos].r <= r)
    {
        s[pos].maxn = max(s[pos].maxn , num);
        s[pos].tag = max(s[pos].tag , num);
        return;
    }
    pushdown(pos);
    supdate(pos << 1 , l , r , num);
    supdate(pos << 1 | 1 , l , r , num);
    pushup(pos);
}

// Processes one query path u -- v of length num.
// Edge (fa[x], x) is represented by position dfn[x]. The path's edges are
// collected as dfn intervals, and chmax(num) is then applied to every position
// OUTSIDE those intervals, i.e. to every edge this path does not use.
// When the endpoints meet at the same node the final interval is
// [dfn[u] + 1, dfn[u]], which is empty and covers nothing.
void update(ll u , ll v , ll num)
{
    auto push_interval = [](ll lo , ll hi)
    {
        ++tot;
        q[tot].l = lo;
        q[tot].r = hi;
    };
    for(; top[u] != top[v] ; u = fa[top[u]])
    {
        if(dep[top[u]] < dep[top[v]]) swap(u , v);
        push_interval(dfn[top[u]] , dfn[u]);
    }
    if(dep[u] > dep[v]) swap(u , v);
    // Skip u itself: it is the LCA and its parent edge is not on the path.
    push_interval(dfn[u] + 1 , dfn[v]);

    sort(q + 1 , q + tot + 1 , [](const node &a , const node &b) { return a.l < b.l; });

    // Update the gaps: before the first interval, between consecutive ones, after the last.
    if(q[1].l > 1) supdate(1 , 1 , q[1].l - 1 , num);
    for(ll k = 2 ; k <= tot ; ++k) supdate(1 , q[k - 1].r + 1 , q[k].l - 1 , num);
    if(q[tot].r < n) supdate(1 , q[tot].r + 1 , n , num);
    tot = 0;
}

// Evaluates zeroing each edge on the longest query path u -- v and returns
// the smallest resulting maximum path length.
// For edge (fa[x], x) the candidate is max(max1 - tag[x], query(1, dfn[x])):
//  - max1 - tag[x] is the longest path shortened by that edge's weight,
//  - query(1, dfn[x]) is the longest query path that avoids the edge
//    (update() chmaxes each path's length onto every edge outside it).
// Returns 0 when u == v (the path has no edges).
//
// Bug fix: the walk must step ONE edge at a time (u = fa[u]). The original
// jumped a whole heavy chain (u = fa[top[u]]), which skipped every edge
// inside a chain and could overshoot the LCA onto edges not on the path.
ll find_ans(ll u , ll v)
{
    if(u == v) return 0;
    ll best = 0x3f3f3f3f3f3f3f3f;
    while(u != v)
    {
        // The deeper endpoint cannot be the LCA, so its parent edge lies on the path.
        if(dep[u] < dep[v]) swap(u , v);
        best = min(best , max(max1 - tag[u] , query(1 , dfn[u])));
        u = fa[u];
    }
    return best;
}

signed main()
{
    n = read(); m = read();
    for(ll i = 1 , u , v , t ; i < n ; i++) u = read() , v = read() , t = read() , add(u , v , t);
    dfs1(1 , 0); top[1] = 1; dfs2(1); build(1 , 1 , n);
    while(m--)
    {
        ll u , v; u = read(); v = read();
        ll num = len[u] + len[v] - 2 * len[LCA(u , v)];
        update(u , v , num);
        if(num > max1) max1 = num , u1 = u , v1 = v;
    }
    printf("%lld\n", find_ans(u1 , v1));
    return 0;
}
2025/6/24 13:26
加载中...