如题(RT):楼主的程序在 SPOJ 上已经通过(原帖此处的“戳我”是提交记录链接,链接地址已丢失),但在洛谷提交时每次都显示 CE(编译错误)。求大佬帮忙,谢谢!QWQ
附代码如下(看不看都行,楼主认为代码本身应该没有什么问题):
/* Mo's algorithm on a tree (树上莫队) */
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros.
// Shorthand for the long long integer type.
#define ll long long
// Inclusive ascending loop: i runs from m up to and including n.
#define rep(i, m, n) for(int i = m; i <= n; i++)
// Inclusive descending loop: i runs from m down to and including n.
#define per(i, m, n) for(int i = m; i >= n; i--)
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
#define vi vector<int>
#define vll vector<ll>
// Container size as int. NOTE(review): the argument is not parenthesised,
// so only pass a plain name (e.g. sz(*p) would expand incorrectly).
#define sz(v) (int) v.size()
// Conventional "infinity" sentinel; not referenced in the visible code.
const int INF = 0x3f3f3f3f;
template<class Ty>
inline void read(Ty & X) {
    // Fast integer reader from stdin: skips non-digit characters, treats any '-'
    // seen while skipping as a sign, then accumulates decimal digits into X.
    // If input is exhausted before a digit appears, X is left as 0.
    X = 0;
    bool negative = false;
    // getchar() returns int so that EOF stays distinguishable from every byte;
    // storing it in a char made the skip loop spin forever at end of input.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        X = X * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) X = -X;
}
template<class Ty>
inline void write(Ty X) {
    // Writes integer X in decimal to stdout (no trailing separator).
    // Digits are peeled from the least-significant end into a buffer. For negative X,
    // X % 10 lies in [-9, 0] (truncating division), so each digit is negated on its own
    // and -X is never computed; negating the minimum value of a signed type would overflow.
    char buf[48];
    int n = 0;
    const bool negative = X < 0;
    do {
        const int digit = static_cast<int>(X % 10);
        buf[n++] = static_cast<char>('0' + (digit < 0 ? -digit : digit));
        X /= 10;
    } while (X != 0);
    if (negative) putchar('-');
    while (n > 0) putchar(buf[--n]);
}
const int maxn = 2e5 + 10;
// Highest binary-lifting level: pre[u][i] is kept for i = 0..maxt (2^20 > maxn).
const int maxt = 20;
// One preprocessed query: Euler-tour range [l, r], the LCA that must be added
// separately (0 when one endpoint is itself the LCA), and the original query index.
struct node {
    int l, r, lca, id;
} T[maxn];
// pre[u][i]  : 2^i-th ancestor of u (0 above the root)
// depth[u]   : depth of u (the root passed to dfs gets depth 1)
// first/last : Euler-tour positions where u is entered / left
// cnt[c]     : number of currently active vertices with compressed value c
// ans[id]    : answer for query id
// belong[p]  : Mo block index of Euler-tour position p
// pos[p]     : vertex stored at Euler-tour position p (each vertex appears twice)
// a[]        : raw input values, later the sorted distinct values
// vis[u]     : 1 while vertex u is counted in the current window
// len        : number of Euler-tour positions filled so far
// now        : number of distinct values among active vertices
// val[u]     : compressed (1-based rank) value of vertex u
int pre[maxn][maxt + 5], depth[maxn], first[maxn], last[maxn], cnt[maxn], ans[maxn], belong[maxn], pos[maxn], a[maxn], vis[maxn], len, now, val[maxn], N, M;
// Adjacency lists of the tree.
std::vector<int> G[maxn];
// Mo's query ordering: by block of the left end; inside a block, r ascends for odd
// blocks and descends for even ones so the right pointer sweeps back and forth.
// Arguments are const references, as std::sort's Compare requirements demand
// (a non-const reference parameter is not portable).
bool cmp(const node &a, const node &b) {
    const int blockA = belong[a.l];
    const int blockB = belong[b.l];
    if (blockA != blockB) return blockA < blockB;
    return (blockA & 1) ? a.r < b.r : a.r > b.r;
}
void dfs(int u, int k) {
depth[u] = depth[k] + 1, pre[u][0] = k;
pos[++len] = u;
first[u] = len;
for (int i = 1; i <= maxt; i++) {
pre[u][i] = pre[pre[u][i - 1]][i - 1];
}
for (int i = 0; i < (int) G[u].size(); i++) {
int v = G[u][i];
if (v == k) continue;
dfs(v, u);
}
pos[++len] = u;
last[u] = len;
}
// Lowest common ancestor of u and v via binary lifting over pre[][] and depth[].
int find(int u, int v) {
    // Let u be the deeper (or equally deep) endpoint.
    if (depth[u] < depth[v]) {
        const int tmp = u;
        u = v;
        v = tmp;
    }
    // Raise u to v's depth using the largest jumps that do not overshoot.
    for (int step = maxt; step >= 0; --step)
        if (depth[pre[u][step]] >= depth[v]) u = pre[u][step];
    if (u == v) return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int step = maxt; step >= 0; --step) {
        const int up = pre[u][step];
        const int vp = pre[v][step];
        if (up != vp) {
            u = up;
            v = vp;
        }
    }
    return pre[u][0];
}
// Toggles vertex ps in or out of the current window, keeping cnt[] and the
// distinct-value count `now` in sync.
void add(int ps) {
    int &c = cnt[val[ps]];
    if (!vis[ps]) {
        // Entering: a value not held by any active vertex adds a new distinct value.
        if (c == 0) ++now;
        ++c;
    } else {
        // Leaving: the value disappears once its last active holder is gone.
        --c;
        if (c == 0) --now;
    }
    vis[ps] ^= 1;
}
int main() {
    read(N), read(M);
    // Coordinate-compress vertex values: a[1..tot] becomes the sorted distinct
    // values and val[i] the 1-based rank of vertex i's value.
    rep(i, 1, N) read(a[i]), val[i] = a[i];
    sort(a + 1, a + N + 1);
    int tot = unique(a + 1, a + N + 1) - (a + 1);
    rep(i, 1, N) val[i] = lower_bound(a + 1, a + tot + 1, val[i]) - a;
    rep(i, 1, N - 1) {
        int u, v;
        read(u), read(v);
        G[u].pb(v), G[v].pb(u);
    }
    dfs(1, 0);
    // The Euler tour has 2N positions; split them into blocks of about sqrt(2N).
    // Clamp to 1 so a degenerate N cannot produce a zero block size (division by zero).
    int block = (int) sqrt((double) (2 * N));
    if (block < 1) block = 1;
    int bnum = ceil((double) 2 * N / block);
    rep(i, 1, bnum) rep(j, (i - 1) * block + 1, i * block) belong[j] = i;
    rep(i, 1, M) {
        int L, R;
        read(L), read(R);
        int lca = find(L, R);
        if (first[L] > first[R]) swap(L, R);
        // If L is the LCA, [first[L], first[R]] already covers the whole path.
        // Otherwise [last[L], first[R]] covers the path minus the LCA, which is
        // toggled in separately while answering.
        if (L == lca) T[i].l = first[L], T[i].r = first[R], T[i].lca = 0;
        else T[i].l = last[L], T[i].r = first[R], T[i].lca = lca;
        T[i].id = i;
    }
    // Queries occupy T[1..M]; T[0] is unused and must not take part in the sort.
    sort(T + 1, T + M + 1, cmp);
    int l = 1, r = 0;
    rep(i, 1, M) {
        int ql = T[i].l, qr = T[i].r, lca = T[i].lca;
        while (l < ql) add(pos[l++]);
        while (l > ql) add(pos[--l]);
        while (r > qr) add(pos[r--]);
        while (r < qr) add(pos[++r]);
        // Temporarily include the LCA for this answer only.
        if (lca) add(lca);
        ans[T[i].id] = now;
        if (lca) add(lca);
    }
    rep(i, 1, M) write(ans[i]), puts("");
    return 0;
}