并查集合并的时候,考虑让编号大的根挂到编号小的根下方(即总以编号小的根作为新根):
// 55-point version: the root with the smaller index becomes the new root
// (when ra == rb the else branch just reassigns f[ra] = ra, a no-op).
void merge(int a, int b) {
int ra = find(a);
int rb = find(b);
if (ra < rb) f[rb] = ra;
else f[ra] = rb;
}
如上代码只有 55 分;
但是把合并时的 if (ra < rb) 判断去掉,改成如下写法:
// Accepted version: always hang a's root under b's root, regardless of index.
void merge(int a, int b) {
int ra = find(a);
int rb = find(b);
f[ra] = rb;
}
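就 AC 了。
一开始百思不得其解,其实原因不在 merge 本身,而在最后的计数方式:
- 结点含义:列变量 y 为结点 y,其取反为 y+m;行变量 x 为结点 2m+x,其取反为 2m+n+x。每条约束都以一对"镜像"的 merge 加入(u 与 w 合并,同时 u 的取反与 w 的取反合并),所以连通块总是成对出现 {C, C'}。每一对(不与常量 0/1 相连时)贡献一个自由度,答案是 2^(自由连通块的对数)。
- 原先的计数只在"正"结点(列 2..m、行 2m+2..2m+n)中数根,这要求每一对连通块里恰好有一个的根是正结点。
- 用 f[ra] = rb 时,镜像的两次合并总是把互为镜像的根挂到互为镜像的根下面,所以 C 和 C' 的根也互为镜像,恰好一个是正结点,计数正确。但这只是碰巧依赖了合并方向。
- 用"编号小的当根"时,根是块内最小编号,而正列结点 1..m 的编号最小(行结点都大于 2m)。若连通块 C 同时含有列结点 y1 和 y2 的取反 y2+m,那么 C 含正列结点 y1,C' 含正列结点 y2,两者的根都是正结点。这一对就被数了两次,答案偏大,于是只有 55 分。
稳妥的写法:统计所有自由连通块(正、负结点都算)的根的个数,再除以 2。这样与 merge 选谁当根无关。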
完整代码如下:
#include <bits/stdc++.h>
using namespace std;
// Every `int` below and in the rest of the file is really `long long`;
// this is why the entry point is declared `signed main()`.
#define int long long
// Capacity bound for the per-constraint arrays and the union-find array.
// NOTE(review): assumes n, m, k all stay below maxn -- confirm against limits.
const int maxn = 1e6 + 100;
// Modulus applied to the final count (used by qpow and main).
const int mod = 1e9;
// Union-find parent array; solve() initialises and uses indices 1..2m+2n+2.
int f[4 * maxn + 10];
// Constraint i refers to cell (x[i], y[i]) with value z[i] (presumably the
// cell's fixed bit -- main treats z of cell (1,1) as that cell's value);
// e[i] is the parity solve() derives from it. n, m, k are read from input;
// OO is the value of cell (1,1), or -1 when no constraint covers that cell.
int x[maxn], y[maxn], z[maxn], e[maxn], n, m, k, OO;
// Returns the representative (root) of x's set in the global parent array f,
// compressing the path so every node visited points directly at the root.
// Written iteratively: merge() does no union-by-rank/size, so a parent chain
// can be as long as the number of nodes before it is first compressed, and a
// recursive find could overflow the call stack on such a chain.
int find(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    // Second pass: re-point every node on the path straight at the root.
    while (f[x] != root) {
        int next = f[x];
        f[x] = root;
        x = next;
    }
    return root;
}
// Joins the sets containing a and b. Of the two roots, the one with the
// smaller index is kept as the representative of the merged set.
void merge(int a, int b) {
    const int root_a = find(a);
    const int root_b = find(b);
    if (root_b <= root_a) {
        f[root_a] = root_b;
    } else {
        f[root_b] = root_a;
    }
}
// Binary exponentiation (square-and-multiply): returns a^b modulo `mod`.
int qpow(int a, int b) {
    int acc = 1;
    int base = a;
    for (; b != 0; b /= 2) {
        if (b & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
// Counts, modulo `mod`, the assignments consistent with all k constraints
// when cell (1,1) holds the value OO.
//
// Union-find node layout (indices into f):
//   y            (1..m)            : column variable y
//   y + m        (m+1..2m)         : complement of column variable y
//   2m + x       (2m+1..2m+n)      : row variable x
//   2m + n + x   (2m+n+1..2m+2n)   : complement of row variable x
//   zero = 2m+2n+1, one = zero + 1 : the constants 0 and 1
// A constraint with x, y >= 2 relates row x and column y through parity e[i];
// a constraint in row 1 or column 1 pins a single variable to a constant.
//
// Returns 0 if the constraints are contradictory, otherwise
// 2^(number of free component pairs) modulo `mod`.
int solve(int OO) {
    // Effective parity per constraint with the corner value OO folded in.
    // NOTE(review): the extra ^1 for even/even cells presumably comes from
    // the problem's parity rule -- confirm against the statement.
    for (int i = 1; i <= k; ++i) {
        if (x[i] == 1 || y[i] == 1)
            e[i] = z[i];
        else if (x[i] % 2 == 0 && y[i] % 2 == 0)
            e[i] = (z[i] ^ OO ^ 1);
        else
            e[i] = (z[i] ^ OO);
    }
    for (int i = 1; i <= 2 * m + 2 * n + 2; ++i) {
        f[i] = i;
    }
    int zero = 2 * m + 2 * n + 1;
    int one = zero + 1;
    for (int i = 1; i <= k; ++i) {
        if (x[i] == 1 && y[i] == 1) {
            // The corner is fixed to OO; a constraint demanding a different
            // value there makes the instance unsatisfiable (previously such a
            // constraint was silently ignored).
            if (z[i] != OO) return 0;
            continue;
        }
        int a = x[i] + m + m;
        int b = y[i];
        // Each relation is added as a mirrored pair of merges: u with w, and
        // complement(u) with complement(w). The counting below relies on it.
        if (x[i] == 1) {
            merge(b, zero + e[i]);
            merge(b + m, zero + (1 ^ e[i]));
        } else if (y[i] == 1) {
            merge(a, zero + e[i]);
            merge(a + n, zero + (1 ^ e[i]));
        } else if (e[i] == 1) {
            merge(a + n, b);
            merge(a, b + m);
        } else if (e[i] == 0) {
            merge(a, b);
            merge(a + n, b + m);
        }
    }
    // No node may share a component with its own complement.
    if (find(zero) == find(one)) return 0;
    for (int i = 2; i <= m; ++i) {
        if (find(i) == find(i + m)) return 0;
    }
    for (int i = 2; i <= n; ++i) {
        if (find(2 * m + i) == find(2 * m + n + i)) return 0;
    }
    // Since merges come in mirrored pairs, components come in pairs {C, C'}
    // (C != C' after the checks above); each pair not attached to a constant
    // is one free bit. Count the roots of ALL free components -- variables
    // and complements alike -- and halve. This does not depend on which node
    // merge() picks as root. Counting roots only among non-complement nodes
    // double-counts a pair whenever both C and C' are rooted at
    // non-complement nodes, which the smaller-index-root rule can produce.
    const int ro = find(zero);
    const int re = find(one);
    int free_roots = 0;
    auto count_free_root = [&](int v) {
        const int r = find(v);
        if (r == v && r != ro && r != re) ++free_roots;
    };
    for (int i = 2; i <= m; ++i) {
        count_free_root(i);
        count_free_root(i + m);
    }
    for (int i = 2; i <= n; ++i) {
        count_free_root(2 * m + i);
        count_free_root(2 * m + n + i);
    }
    return qpow(2, free_roots / 2);
}
// Reads n, m, k and the k constraints (x, y, z), then prints the number of
// valid fillings modulo `mod`. If some constraint covers cell (1,1), its
// value is used directly; otherwise both possible corner values are summed.
signed main() {
    cin >> n >> m >> k;
    OO = -1;  // sentinel: cell (1,1) not yet seen among the constraints
    for (int i = 1; i <= k; ++i) {
        scanf("%lld %lld %lld", &x[i], &y[i], &z[i]);
        const bool is_corner = (x[i] == 1 && y[i] == 1);
        if (is_corner) OO = z[i];
    }
    const int total = (OO == -1) ? (solve(0) + solve(1)) : solve(OO);
    cout << total % mod << endl;
    return 0;
}