刚学OI,97分求助
查看原帖
刚学OI,97分求助
19951
Reaepita楼主2020/3/31 10:18

subtask2 的第一个测试点 TLE（超时）了

#include<bits/stdc++.h>
using namespace std;
const long long inf=1000000000000000000LL;  // 10^18 sentinel; not referenced elsewhere in this program
namespace T3
{
    // Third tree: weighted distance queries answered with an Euler tour and a
    // sparse table over it (LCA = shallowest vertex between two first visits).
    const int maxn=1e5+10;
    struct edge{int v,nxt;long long w;}e[maxn<<1];
    int head[maxn],ecnt=1;

    // Store undirected edge (u,v) of weight w as two arcs, each prepended to
    // its endpoint's adjacency list.
    void add(int u,int v,long long w)
    {
        ++ecnt;
        e[ecnt]=edge{v,head[u],w};
        head[u]=ecnt;
        ++ecnt;
        e[ecnt]=edge{u,head[v],w};
        head[v]=ecnt;
    }

    // rmq[i][0]: vertex at Euler-tour position i; rmq[i][k]: shallowest vertex
    // among positions [i, i+2^k). dfn[u]: first tour position of u.
    // Log[x]: floor(log2 x), with Log[0] = -1.
    int rmq[maxn<<1][22],dfn[maxn],idx=0,Log[maxn<<1];
    int dep[maxn];       // number of edges from vertex 1
    long long dis[maxn]; // weighted distance from vertex 1

    // Euler-tour DFS from u (entered from par): u is appended to the tour on
    // entry and again after returning from each child.
    void dfs(int u,int par)
    {
        dfn[u]=++idx;
        rmq[idx][0]=u;
        for(int arc=head[u];arc!=-1;arc=e[arc].nxt)
        {
            const int child=e[arc].v;
            if(child==par)continue;
            dep[child]=dep[u]+1;
            dis[child]=dis[u]+e[arc].w;
            dfs(child,u);
            rmq[++idx][0]=u;
        }
    }

    // Return whichever of a, b is shallower (b when the depths tie).
    int cmp(const int &a,const int &b){return dep[b]<=dep[a]?b:a;}

    // Reset all adjacency lists; must be called before add().
    void pre()
    {
        std::fill(head,head+maxn,-1);
        ecnt=1;
    }

    // Run the Euler tour from vertex 1, then fill the log and sparse tables.
    void init()
    {
        dfs(1,0);
        Log[0]=-1;
        for(int x=1;x<=idx;x++)Log[x]=Log[x/2]+1;
        for(int k=1;k<22;k++)
        {
            const int half=1<<(k-1);
            for(int i=1;i+2*half-1<=idx;i++)
                rmq[i][k]=cmp(rmq[i][k-1],rmq[i+half][k-1]);
        }
    }

    // LCA of u and v: shallowest vertex between their first tour positions.
    int lca(int u,int v)
    {
        int lo=dfn[u],hi=dfn[v];
        if(lo>hi)std::swap(lo,hi);
        const int k=Log[hi-lo+1];
        return cmp(rmq[lo][k],rmq[hi-(1<<k)+1][k]);
    }

    // Weighted distance between u and v in this tree (valid after init()).
    long long Dis(int u,int v)
    {
        const int anc=lca(u,v);
        return dis[u]+dis[v]-2*dis[anc];
    }
}
namespace T2
{
    // Second tree. Besides Euler-tour + sparse-table LCA, it answers one query
    // per T1 divide step: build a virtual tree over the collected vertices and
    // run a two-colour farthest-pair DP on it.
    const int maxn=1e5+10;
    struct edge{int v,nxt;long long w;}e[maxn<<1];
    int head[maxn],ecnt=1;
    // Store undirected edge (u,v,w) as two arcs prepended to the head lists.
    // Uses standard aggregate init instead of the GNU-only compound literal.
    void add(int u,int v,long long w)
    {
        e[++ecnt]=edge{v,head[u],w};
        head[u]=ecnt;
        e[++ecnt]=edge{u,head[v],w};
        head[v]=ecnt;
    }
    // rmq[i][0]: vertex at Euler-tour position i; rmq[i][j]: shallowest vertex
    // among positions [i, i+2^j). dfn[u]: first tour position of u.
    int rmq[maxn<<1][22],dfn[maxn],idx=0,Log[maxn<<1];
    int dep[maxn];       // number of edges from vertex 1
    long long dis[maxn]; // weighted distance from vertex 1
    // Euler-tour DFS: u is recorded on entry and after each child returns.
    void dfs(int u,int pre)
    {
        rmq[++idx][0]=u,dfn[u]=idx;
        for(int i=head[u];~i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(v==pre)continue;
            dis[v]=dis[u]+e[i].w;
            dep[v]=dep[u]+1;
            dfs(v,u),rmq[++idx][0]=u;
        }
    }
    // Shallower of a and b (b on ties).
    int cmp(const int &a,const int &b){return dep[a]<dep[b]?a:b;}
    // Reset adjacency lists; call before add().
    void pre()
    {
        memset(head,-1,sizeof(head)),ecnt=1;
    }
    // LCA via the sparse table; valid after init().
    int lca(int u,int v)
    {
        u=dfn[u],v=dfn[v];
        if(u>v)swap(u,v);
        int len=Log[v-u+1];
        return cmp(rmq[u][len],rmq[v-(1<<len)+1][len]);
    }
    // Per-query vertex data written by T1::solve: val[u] is added into the
    // pair score; col[u] is u's side (0/1), or -1 when u is not in the query.
    long long val[maxn];
    int col[maxn];
    // Orders vertices by first Euler-tour position (DFS order).
    bool cmps(const int &a,const int &b){return dfn[a]<dfn[b];}
    // Virtual-tree child lists. G[u] belongs to the current query only when
    // vis[u]==tim; otherwise it still holds data from an older query.
    vector<int>G[maxn];
    int vis[maxn],tim;
    // Build LCA structures from vertex 1 and mark every vertex as unused.
    void init()
    {
        dfs(1,0),Log[0]=-1;
        memset(col,-1,sizeof(col));
        for(int i=1;i<=idx;i++)Log[i]=Log[i>>1]+1;
        for(int j=1;j<=21;j++)
        for(int i=1;i+(1<<j)-1<=idx;i++)rmq[i][j]=cmp(rmq[i][j-1],rmq[i+(1<<j-1)][j-1]);
    }
    // Make G[u] valid for query tim: the first time u is touched in this
    // query, discard whatever children an earlier query left behind.
    void touch(int u)
    {
        if(vis[u]!=tim)vis[u]=tim,G[u].clear();
    }
    // Add virtual-tree edge u -> v (v becomes a child of u).
    void addedge(int u,int v)
    {
        touch(u);
        G[u].push_back(v);
    }
    // Pair score of (u,v) before the T2 LCA term is subtracted:
    // dis[u]+dis[v]+val[u]+val[v]+T3 distance. Vertex 0 means "none" -> 0.
    long long calc(int u,int v){if(!u||!v)return 0;return dis[u]+dis[v]+val[u]+val[v]+T3::Dis(u,v);}
    int p[maxn];
    long long ans=0;
    // Best pair (u,v) of a vertex set with its score dis; u==0 marks an
    // empty set.
    struct data
    {
        int u,v;
        long long dis;
        data(){u=v=dis=0;}
        data(int U,int V){u=U,v=V,dis=calc(u,v);}
        data(int U,int V,long long D){u=U,v=V,dis=D;}
        friend bool operator < (const data &a,const data &b){return a.dis<b.dis;}
        // Merge two sets: keep the better existing pair or one of the four
        // cross pairs of their endpoints. This presumably relies on the
        // "farthest pair of a union lies among the endpoints" lemma, which
        // needs non-negative weights -- confirm against the problem bounds.
        friend data operator + (const data &a,const data &b)
        {
            if(!a.u)return b;
            if(!b.u)return a;
            data ret=max(a,b);
            ret=max(ret,max(data(a.u,b.v),data(a.v,b.u)));
            ret=max(ret,max(data(a.v,b.v),data(a.u,b.u)));
            return ret;
        }
    }dp[maxn][2];
    // Pair the colour-0 set already merged into u with child v's colour-1
    // set (and vice versa); their T2 LCA is u, hence the -2*dis[u].
    void getans(int u,int v)
    {
        ans=max(ans,calc(dp[u][0].u,dp[v][1].u)-2*dis[u]);
        ans=max(ans,calc(dp[u][0].u,dp[v][1].v)-2*dis[u]);
        ans=max(ans,calc(dp[u][0].v,dp[v][1].u)-2*dis[u]);
        ans=max(ans,calc(dp[u][0].v,dp[v][1].v)-2*dis[u]);
        ans=max(ans,calc(dp[u][1].u,dp[v][0].u)-2*dis[u]);
        ans=max(ans,calc(dp[u][1].u,dp[v][0].v)-2*dis[u]);
        ans=max(ans,calc(dp[u][1].v,dp[v][0].u)-2*dis[u]);
        ans=max(ans,calc(dp[u][1].v,dp[v][0].v)-2*dis[u]);
    }
    // DP over the virtual tree rooted at u: dp[u][c] ends up holding the best
    // pair among colour-c vertices in u's virtual subtree.
    void calc(int u)
    {
        // Leaves of the virtual tree (and a lone root) never receive addedge()
        // as a parent, so their G[] could still list children from an earlier
        // query. Recursing into those revisits stale subtrees over and over,
        // which is what caused the TLE.
        touch(u);
        dp[u][0]=dp[u][1]=data();
        if(col[u]==0)dp[u][0]=data(u,u,0);
        if(col[u]==1)dp[u][1]=data(u,u,0);
        for(auto v:G[u])
        {
            calc(v);
            getans(u,v);
            dp[u][0]=dp[u][0]+dp[v][0];
            dp[u][1]=dp[u][1]+dp[v][1];
        }
    }
    int stk[maxn],top=0;
    // Answer one query over lst[1..k] (1-indexed): the best cross-colour pair
    // score in T2 distance + val + T3 distance. ans starts at 0, so the
    // result is never negative.
    long long solve(int *lst,const int &k)
    {
        if(k<=0)return 0;   // nothing collected: no pair can be formed
        tim++,ans=0;
        for(int i=1;i<=k;i++)p[i]=lst[i];
        sort(p+1,p+1+k,cmps);
        top=0;
        if(p[1]!=1)stk[top=1]=1;   // the virtual tree is always rooted at 1
        for(int i=1;i<=k;i++)
        {
            int u=p[i];
            if(top==0){stk[++top]=u;continue;}
            int f=lca(stk[top],u);
            // Pop stack vertices that are not ancestors of u, linking each to
            // the vertex below it (or to f).
            while(top>1&&dep[f]<dep[stk[top-1]])addedge(stk[top-1],stk[top]),top--;
            if(dep[f]<dep[stk[top]])addedge(f,stk[top]),top--;
            if(!top||stk[top]!=f)stk[++top]=f;
            stk[++top]=u;
        }
        // Link the remaining stack into a chain hanging from the root.
        if(top)while(--top)addedge(stk[top],stk[top+1]);
        calc(1);
        return ans;
    }
}
long long ans{0};   // best answer over all T1 divide steps (printed by main)
int n{};            // number of vertices in each tree, read first in main
// First tree. It is binarised (extra virtual vertices numbered n+1..all) and
// then processed by edge-centroid divide and conquer; each divide step hands
// the real vertices on both sides of the cut edge to T2::solve.
namespace T1
{
    const int maxn=2e5+10;
    // Original adjacency: neighbour ids and the matching edge weights.
    vector<int>G[maxn];
    vector<long long>W[maxn];
    // Binarised tree as arc lists; arcs 2k and 2k+1 form undirected edge k.
    struct edge{int v,nxt;long long w;}e[maxn<<1];
    int head[maxn],ecnt=1;
    // Largest vertex id in use (main sets it to n; grows with virtual vertices).
    int all;
    // Record one input edge of the original tree, in both directions.
    void adde(int u,int v,long long w)
    {
        G[u].push_back(v),W[u].push_back(w);
        G[v].push_back(u),W[v].push_back(w);
    }
    // Add an undirected edge to the binarised tree as two paired arcs.
    void add(int u,int v,long long w)
    {
        e[++ecnt]=(edge){v,head[u],w},head[u]=ecnt;
        e[++ecnt]=(edge){u,head[v],w},head[v]=ecnt;
    }
    // Binarise the subtree of u (entered from pre). The first child is
    // attached to u directly; each further child hangs off a new virtual
    // vertex chained to the previous one by a 0-weight edge, so every vertex
    // ends up with at most three neighbours.
    void rebuild(int u,int pre)
    {
        int lst=-1;
        for(int i=0;i<G[u].size();i++)
        {
            int v=G[u][i];
            if(v==pre)continue;
            rebuild(v,u);
            if(lst==-1)add(u,v,W[u][i]),lst=u;
            else 
            {
                int now=++all;
                add(lst,now,0);
                add(now,v,W[u][i]),lst=now;
            }
        }
    }
    // Build the binarised tree, rooted at vertex 1.
    void init()
    {
        memset(head,-1,sizeof(head)),ecnt=1;
        rebuild(1,0);
    }
    // vis[k]: undirected edge k has already been cut. siz: subtree sizes.
    int vis[maxn],siz[maxn<<2];
    // ns: size of the current component; mi/id: best split value and its arc.
    int ns,mi,id;
    // Subtree sizes within the current component (cut edges are skipped).
    void getsiz(int u,int pre)
    {
        siz[u]=1;
        for(int i=head[u];~i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(vis[i>>1]||v==pre)continue;
            getsiz(v,u);
            siz[u]+=siz[v];
        }
    }
    // Find the arc whose removal minimises the larger remaining part and
    // store it in id (solve() initialises id to -1 beforehand).
    void getedge(int u,int pre)
    {
        for(int i=head[u];~i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(vis[i>>1]||v==pre)continue;
            getedge(v,u);
            int now=max(siz[v],ns-siz[v]);
            if(now<mi)id=i,mi=now;
        }
    }
    // Per-step buffers of collected real vertices: colour, id, distance.
    int col[maxn],stk[maxn],top;
    long long len[maxn];
    // Push every real vertex (id <= n) reachable from u without crossing a
    // cut edge, with colour c and its distance dist from the start vertex.
    void Get(int u,int pre,long long dist,int c)
    {
        if(u<=n)col[++top]=c,stk[top]=u,len[top]=dist;
        for(int i=head[u];~i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(vis[i>>1]||v==pre)continue;
            Get(v,u,dist+e[i].w,c);
        }
    }
    // Edge-divide the component containing u:
    //  1. cut the balanced edge;
    //  2. tag both sides in T2 (val += distance to the cut, col = side);
    //  3. add the cut edge's weight to T2::solve's best cross pair and fold
    //     the result into the global ans;
    //  4. undo the tags and recurse into both sides.
    void solve(int u)
    {
        getsiz(u,0);
        ns=siz[u],mi=1e9,id=-1;
        getedge(u,0);
        if(id==-1)return ;
        int rt=id>>1;vis[rt]=1;
        top=0;
        int L=e[id].v,R=e[id^1].v; 
        Get(L,R,0,0),Get(R,L,0,1);
        for(int i=1;i<=top;i++)
        {
            T2::val[stk[i]]+=len[i];
            T2::col[stk[i]]=col[i];
        }
        ans=max(ans,T2::solve(stk,top)+e[id].w);
        for(int i=1;i<=top;i++)
        {
            T2::val[stk[i]]-=len[i];
            T2::col[stk[i]]=-1;
        }
        solve(L),solve(R);
    }
}
int main()
{
    scanf("%d",&n);
    T1::all=n;
    for(int i=1;i<n;i++)
    {
        int u,v;long long w;
        scanf("%d%d%lld",&u,&v,&w);
        T1::adde(u,v,w);
    }
    T1::init();
    T2::pre();
    for(int i=1;i<n;i++)
    {
        int u,v;long long w;
        scanf("%d%d%lld",&u,&v,&w);
        T2::add(u,v,w);
    }
    T2::init();
    T3::pre();
    for(int i=1;i<n;i++)
    {
        int u,v;long long w;
        scanf("%d%d%lld",&u,&v,&w);
        T3::add(u,v,w);
    }
    T3::init();
    T1::solve(1);
    printf("%lld\n",ans);
}
2020/3/31 10:18
加载中...