萌新求助线段树模板
查看原帖
萌新求助线段树模板
235926
1kri楼主2020/6/1 13:03

WA+MLE,想知道为什么WA。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
int n,m,x,y,z,u[200005],v[200005],first[100005],nxt[200005];
int fa[100005][25],depth[100005],root[100005];
int ans[100005];
vector <int>c[100005];
struct node{
    int l,r,ans,mx;
    node(){
        l=r=ans=mx=0;
        return;
    }
}tree[3000005];
int cnt;
inline void pushup(int now){
    if (tree[tree[now].l].mx>=tree[tree[now].r].mx)tree[now].mx=tree[tree[now].l].mx,tree[now].ans=tree[tree[now].l].ans;
    else tree[now].mx=tree[tree[now].r].mx,tree[now].ans=tree[tree[now].r].ans;
    return;
}
inline void update(int &now,int nowl,int nowr,int pos,int val){
    tree[++cnt]=tree[now];
    now=cnt;
    if (nowl==nowr){
        tree[now].ans=nowl;
        tree[now].mx+=val;
        return;
    }
    int mid=(nowl+nowr)/2;
    if (pos<=mid)update(tree[now].l,nowl,mid,pos,val);
    else update(tree[now].r,mid+1,nowr,pos,val);
    pushup(now);
    return;
}
inline int merge(int x,int y,int nowl,int nowr){
    if (x==0)return y;
    if (y==0)return x;
    if (nowl==nowr){
        tree[x].mx+=tree[y].mx;
        tree[x].ans=nowl;
        return x;
    }
    int mid=(nowl+nowr)/2;
    merge(tree[x].l,tree[y].l,nowl,mid);
    merge(tree[x].r,tree[y].r,mid+1,nowr);
    pushup(x);
    return x;
}
inline void dfs1(int now,int f,int d){
    fa[now][0]=f,depth[now]=d;
    for (int i=1;i<=20;i++)fa[now][i]=fa[fa[now][i-1]][i-1];
    for (int i=first[now];i;i=nxt[i])
        if (v[i]!=f)dfs1(v[i],now,d+1);
    return;
}
inline void dfs2(int now,int f){
    for (int i=0,len=c[now].size();i<len;i++){
        if (c[now][i]>0)update(root[now],1,n,c[now][i],1);
        else update(root[now],1,n,c[now][i],-2);
    }
    for (int i=first[now];i;i=nxt[i]){
        if (v[i]==f)continue;
        dfs2(v[i],now);
        merge(root[now],root[v[i]],1,n);
    }
    if (tree[root[now]].mx>0)ans[now]=tree[root[now]].ans;
    else ans[now]=0;
    return;
}
inline int lca(int a,int b){
    if (depth[a]>depth[b])swap(a,b);
    for (int i=20;i>=0;i--)
        if (depth[fa[b][i]]>=depth[a])b=fa[b][i];
    if (a==b)return a;
    for (int i=20;i>=0;i--)
        if (fa[a][i]!=fa[b][i])a=fa[a][i],b=fa[b][i];
    return fa[a][0];
}
int main(){
    cin>>n>>m;
    for (int i=1;i<n;i++){
        scanf("%d%d",&u[i],&v[i]);
        nxt[i]=first[u[i]],first[u[i]]=i;
        u[i+n]=v[i],v[i+n]=u[i];
        nxt[i+n]=first[u[i+n]],first[u[i+n]]=i+n;
    }
    dfs1(1,0,1);
    for (int i=1;i<=m;i++){
        scanf("%d%d%d",&x,&y,&z);
        c[x].push_back(z);
        c[y].push_back(z);
        c[fa[lca(x,y)][0]].push_back(-z);
    }
    dfs2(1,0);
    for (int i=1;i<=n;i++)printf("%d\n",ans[i]);
    return 0;
}
2020/6/1 13:03
加载中...