求助树上莫队T了 Sp10707
查看原帖
求助树上莫队T了 Sp10707
139012
wrpwrp楼主2020/4/29 13:55

我本地跑了贼多数据一点事没有,拿题解对拍甚至跑得比题解快,但他T了....
求大佬帮忙看看吧...
代码:

#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

#define R register

const int MAXN=1e5+10;
const int MAXM=1e5+10;

inline int read() {
	char a=getchar(); int x=0,f=1;
	for(;a>'9'||a<'0';a=getchar()) if(a=='-') f=-1;
	for(;a>='0'&&a<='9';a=getchar()) x=x*10+a-'0';
	return x*f;
}

inline void Print(int x) {
	if(x>9) Print(x/10);
	putchar(x%10+'0');
}

inline void print(int x) { if(x<0) putchar('-'),x=-x; Print(x); putchar('\n'); }

int n,m;
int a[MAXN],b[MAXN],block[MAXN];

vector<int> Edge[MAXN];
inline void addedge(int x,int y) { Edge[x].push_back(y); }

struct Ques { int x,y,l,r,id,flg; } q[MAXM];
inline bool cmp(Ques x,Ques y) { return block[x.l]==block[y.l]?x.r<y.r:x.l<y.l; }

inline void init() {
	n=read(); m=read(); int len=sqrt(n*2);
	for(R int i=1;i<=n;i++) a[i]=b[i]=read();
	int x,y;
	for(R int i=1;i<n;i++) x=read(),y=read(),addedge(x,y),addedge(y,x);
	for(R int i=1;i<=m;i++) { scanf("%d%d",&q[i].x,&q[i].y); q[i].id=i; }
	for(R int i=1;i<=n*2;i++) block[i]=(i-1)/len+1; 
	sort(b+1,b+1+n); int N=unique(b+1,b+1+n)-b-1;
	for(R int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+N,a[i])-b;
	//for(R int i=1;i<=n;i++) printf("%d ",a[i]); putchar('\n');
}

int ord[MAXN],num,st[MAXN],ed[MAXN],dep[MAXN],fa[MAXN],top[MAXN],siz[MAXN],son[MAXN];

inline void dfs1(int x,int fx) {
	fa[x]=fx; dep[x]=dep[fx]+1;
	st[x]=++num; ord[num]=x;
	for(R int i=0;i<Edge[x].size();i++) {
		int y=Edge[x][i];
		if(y==fx) continue;
		dfs1(y,x); siz[x]+=siz[y];
		if(siz[y]>siz[son[x]]) son[x]=y;
	}
	ed[x]=++num; ord[num]=x;
}

inline void dfs2(int x,int topx) {
	top[x]=topx;
	if(son[x]) dfs2(son[x],topx);
	for(R int i=0;i<Edge[x].size();i++) {
		int y=Edge[x][i];
		if(y==fa[x]) continue;
		dfs2(y,y);
	}
}

inline int lca(int x,int y) {
	while(top[x]!=top[y]) 
		if(dep[top[x]]>dep[top[y]]) x=fa[top[x]];
		else y=fa[top[y]];
	return dep[x]<dep[y]?x:y;
}

int Ans[MAXM];

inline void pre() {
	dfs1(1,0); dfs2(1,0);
	//for(R int i=1;i<=n*2;i++)
	//	printf("%d ",ord[i]); putchar('\n');
	for(R int i=1;i<=m;i++) {
		int x=q[i].x,y=q[i].y;
		int Lca=lca(x,y);
		if(st[x]>st[y]) swap(x,y);
		if(Lca==x||Lca==y) {
			q[i].l=st[x];
			q[i].r=st[y];
		}
		else {
			q[i].l=ed[x];
			q[i].r=st[y];
			q[i].flg=Lca;
		}
		//printf("%d %d\n",q[i].l,q[i].r);
	}
	sort(q+1,q+1+m,cmp);
}

int ans=0;
int cnt[MAXN],used[MAXN];

inline void Add(int node) {
	cnt[a[node]]++;
	if(cnt[a[node]]==1) ans++;
}

inline void Del(int node) {
	cnt[a[node]]--;
	if(cnt[a[node]]==0) ans--;
}

inline void calc(int pos) {
	int node = ord[pos];
	if(used[node]) Del(node); else Add(node);
	used[node]^=1;
}

inline void solve() {
	int l=1,r=0;
	for(R int i=1;i<=m;i++) {
		int ql=q[i].l;
		int qr=q[i].r;
		//printf("%d %d %d\n",ql,qr,q[i].flg);
		while(l<ql) calc(l++);
		while(l>ql) calc(--l);
		while(r<qr) calc(++r);
		while(r>qr) calc(r--);
		//printf("%d\n",ans);
		if(q[i].flg&&cnt[a[q[i].flg]]==0) Ans[q[i].id]=ans+1;
		else Ans[q[i].id]=ans;
	}
	for(R int i=1;i<=m;i++)
		print(Ans[i]);
}

int main() {
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
	init();
	pre();
	solve();
	return 0;
}

对拍用的代码:

#include<cstdio>
#include<cmath>
#include<algorithm>
#include<vector>
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
using namespace std;
const int MAXN = 1e5 + 10;
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
int N, Q;
int belong[MAXN], block;
struct Query {
    int l, r, ID, lca, ans;
    bool operator < (const Query &rhs) const{
        return belong[l] == belong[rhs.l] ? r < rhs.r : belong[l] < belong[rhs.l];
    //    return belong[l] < belong[rhs.l];
    }
}q[MAXN];
vector<int>v[MAXN];
int a[MAXN], date[MAXN];
void Discretization() {
    sort(date + 1, date + N + 1);
    int num = unique(date + 1, date + N + 1) - date - 1;
    for(int i = 1; i <= N; i++) a[i] = lower_bound(date + 1, date + num + 1, a[i]) - date;    
}
int deep[MAXN], top[MAXN], fa[MAXN], siz[MAXN], son[MAXN], st[MAXN], ed[MAXN], pot[MAXN], tot;
void dfs1(int x, int _fa) {
    fa[x] = _fa; siz[x] = 1;
    st[x] = ++ tot; pot[tot] = x; 
    for(int i = 0; i < v[x].size(); i++) {
        int to = v[x][i];
        if(deep[to]) continue;
        deep[to] = deep[x] + 1;
        dfs1(to, x);
        siz[x] += siz[to];
        if(siz[to] > siz[son[x]]) son[x] = to;
    }
    ed[x] = ++tot; pot[tot] = x;
}
void dfs2(int x, int topfa) {
    top[x] = topfa;
    if(!son[x]) return ;
    dfs2(son[x], topfa);
    for(int i = 0; i < v[x].size(); i++) {
        int to = v[x][i];
        if(top[to]) continue;
            dfs2(to, to);
    }
}
int GetLca(int x, int y) {
    while(top[x] != top[y]) {
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return deep[x] < deep[y] ? x : y;
}
void DealAsk() {
    for(int i = 1; i <= Q; i++) {
        int x = read(), y = read();
        if(st[x] > st[y]) swap(x, y);
        int _lca = GetLca(x, y);
        q[i].ID = i;
        if(_lca == x) q[i].l = st[x], q[i]. r = st[y];
        else q[i].l = ed[x], q[i].r = st[y], q[i].lca = _lca;
    }
}
int Ans, out[MAXN], used[MAXN], happen[MAXN];
void add(int x) {
    if(++happen[x] == 1) Ans++;
}
void delet(int x) {
    if(--happen[x] == 0) Ans--;
}
void Add(int x) {
    used[x] ? delet(a[x]) : add(a[x]); used[x] ^= 1;
}
void Mo() {
    sort(q + 1, q + Q + 1);
    int l = 1, r = 0, fuck = 0;
    for(int i = 1; i <= Q; i++) {
        while(l < q[i].l) Add(pot[l]), l++, fuck++;
        while(l > q[i].l) l--, Add(pot[l]), fuck++;
        while(r < q[i].r) r++, Add(pot[r]), fuck++;
        while(r > q[i].r) Add(pot[r]), r--, fuck++;
        if(q[i].lca) Add(q[i].lca);
        q[i].ans = Ans;
        if(q[i].lca) Add(q[i].lca);
    }
    for(int i = 1; i <= Q; i++) out[q[i].ID] = q[i].ans;
    for(int i = 1; i <= Q; i++)
        printf("%d\n", out[i]);
}
int main() {
freopen("a.in","r",stdin);
freopen("b.out","w",stdout);
    N = read(); Q = read();
    //block = 1.5 * sqrt(2 * N) + 1;
    //block = pow(N, 0.66666666666);
    block = sqrt(N);
    for(int i = 1; i <= N; i++) a[i] = date[i] = read();
    for(int i = 1; i <= N * 2; i++) belong[i] = i / block + 1;
    Discretization();
    for(int i = 1; i <= N - 1; i++) {
        int x = read(), y = read();
        v[x].push_back(y); v[y].push_back(x);
    }
    deep[1] = 1; dfs1(1, 0);
    dfs2(1, 1);
/*    for(int i = 1; i <= N; i++)    
        for(int j = 1; j <= i - 1; j++)
            printf("%d %d %d\n", i, j, GetLca(i, j));*/
    DealAsk();
    Mo();
    return 0;
}

生成器:

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <ctime>
using namespace std;
int main() {
	freopen("a.in","w",stdout);
	srand(time(0));
	int n=40000,m=100000;
	printf("%d %d\n",n,m);
	for(int i=1;i<=n;i++) printf("%d\n",rand());
	for(int i=2;i<=n;i++) printf("%d %d\n",i,rand()%(i-1)+1);
	for(int i=1;i<=m;i++) 
		printf("%d %d\n",rand()%n+1,rand()%n+1);
	return 0;
}

checker:

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
int main() {
	system("g++ -o a a.cpp");
	system("g++ -o b b.cpp");
	system("g++ -o m m.cpp");
	int tot=0;
	while(1) {
		system("./m");
		system("./a");
		system("./b");
		if(system("diff a.out b.out")) break;
		else printf("Case: %d Accept! \n",++tot);
	}
	return 0;
}
2020/4/29 13:55
加载中...