为啥 是wa,求赐教。
查看原帖
为啥 是wa,求赐教。
331161
Orang楼主2020/7/5 00:34
//linux下代码里不能有index,否则编译会报错.
#include<cstdio>
#include<ctype.h>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 40100
#define M 100100
using namespace std;
int head[N],fa[N][20],dep[N],b[N<<1],st[N],ed[N],num,lg[N],used[N],a[N],tmp[N],pos[N],cnt[N],ans[M],res,k;

struct Node{
    int to,next;
    Node(){}
    Node(int to,int next):to(to),next(next){}
}edge[N<<1];

struct seq{
    int l,r,lca,id;
    seq(){}
    seq(int l,int r,int lca,int id):l(l),r(r),lca(lca),id(id){}
    bool operator < (const seq &x)const{
        return pos[l]^pos[x.l]?l<x.l:(pos[l]&1?r<x.r:r>x.r);
    }
}q[M];

inline int read(){
	int x=0,f=1;char ch=getchar();
	while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
	while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
	return x*f;
}

void init(int n){
    memset(head,-1,sizeof(head));
    lg[0]=-1;
    for(int i=1;i<=n;i++)
        lg[i]=lg[i>>1]+1;
}

void discretization(int n){
    sort(tmp+1,tmp+1+n);
    int len=unique(tmp+1,tmp+1+n)-tmp-1;
    for(int i=1;i<=n;i++)
        a[i]=lower_bound(tmp+1,tmp+1+len,a[i])-tmp;
}


void addEdge(int from,int to){
    edge[k]=Node(to,head[from]);
    head[from]=k++;
    edge[k]=Node(from,head[to]);
    head[to]=k++;
}

void dfs(int u){
    st[u]=++num;
    b[num]=u;
    for(int i=1;i<=lg[dep[u]];i++)
        fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v!=fa[u][0]){
            fa[v][0]=u;
            dep[v]=dep[u]+1;
            dfs(v);
        }
    }
    ed[u]=++num;
    b[num]=u;
}

int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    while(dep[x]>dep[y])
        x=fa[x][lg[dep[x]-dep[y]]];
    if(x == y) return x;
    for(int i=lg[dep[x]];i>=0;i--){
        if(fa[x][i]!=fa[y][i]){
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    return fa[x][0];
}

void add(int x){
    if(!cnt[x]++) res++;
}

void sub(int x){
    if(!--cnt[x]) res--;
}

void calc(int x){
    if(!used[x]) add(a[x]);
    else sub(a[x]);
    used[x]^=1;
}

int main(){
    int n,m,u,v,siz,_lca;
    n=read(),m=read();
    init(n);
    for(int i=1;i<=n;i++)
        a[i]=tmp[i]=read();
    discretization(n);
    for(int i=1;i<=n-1;i++){
        u=read(),v=read();
        addEdge(u,v);
    }
    dfs(1);
    siz=sqrt(num);//siz块大小 是括号序列的平方根,因为每条询问的lca l r记录的是括号序列的下标
    for(int i=1;i<=num;i++)
        pos[i]=(i-1)/siz+1;
    for(int i=1;i<=m;i++){
        u=read(),v=read();
        _lca=lca(u,v);
        if(st[u]>st[v]) swap(u,v);
        if(_lca != u) q[i]=seq(ed[u],st[v],_lca,i);
        else q[i]=seq(st[u],st[v],0,i);
    }
    sort(q+1,q+1+m);
    int l=1,r=0;
    for(int i=1;i<=m;i++){
        while(l<q[i].l) calc(b[l++]);
        while(l>q[i].l) calc(b[--l]);
        while(r<q[i].r) calc(b[++r]);
        while(r>q[i].r) calc(b[r--]);
        if(q[i].lca) calc(q[i].lca);
        ans[q[i].id]=res;
        if(q[i].lca) calc(q[i].lca);
    }
    for(int i=1;i<=m;i++)
        printf("%d\n",ans[i]);
    return 0;
}

2020/7/5 00:34
加载中...