这是我自己出的一道题(私题)的代码,
但不明白为什么会出现下面这个问题:
读入样例时,用 cin 读入 type 数组能得到正确答案;
而改用快读(read 函数)读入 type 数组,得到的结果却是错的。
输入样例:
10 11
0 1 0 1 0 1 0 1 0 1
1 2
1 3
1 4
2 5
2 6
3 7
4 8
5 9
9 10
5 6
2 7
输出样例:
6050
代码
/*
Work by: Suzt_ilymics
Knowledge: ??
Time: O(??)
*/
#include<iostream>
#include<cstdio>
using namespace std;
// Array capacities and the modulus every accumulated value is reduced by.
constexpr int MAXN = 100000 + 5;  // capacity of per-vertex arrays
constexpr int MAXM = 200000 + 5;  // capacity of the edge pool (add() stores each undirected edge twice)
constexpr int mod = 1000000007;

// One directed edge in a forward-star adjacency list.
struct edge {
    int to;   // target vertex
    int w;    // edge weight
    int nxt;  // index of the next edge leaving the same vertex; 0 ends the list
} e[MAXM];

int head[MAXN];   // head[u]: index of the most recently added edge leaving u (0 = none)
int num_edge;     // number of edges stored so far; edges are numbered from 1
int n, m;         // vertex count and edge count read from input
int dis[MAXN];    // distance from vertex 1, mod `mod`
int fath[MAXN];   // disjoint-set parent pointers
int f[MAXN];      // per-vertex distance sums computed by the rerooting pass
int siz[MAXN];    // per-subtree counts of vertices whose type differs from vertex 1
bool type[MAXN];  // vertex colour, normalized so that type[1] == 0
// Fast integer reader over stdin (getchar-based).
// Skips every non-digit character before the number; any '-' among them makes
// the result negative. Then consumes digits up to (and including) the first
// non-digit character.
// Returns 0 if end of input is reached before a digit is found. Before this
// change the skip loop never terminated on EOF.
int read(){
    int w = 1, s = 0;
    int ch = getchar();  // int, not char: EOF must stay distinguishable from byte values
    while((ch < '0' || ch > '9') && ch != EOF){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}
//bool read1(){
// bool s = 0;
// char ch = getchar();
// while(ch >= '0' && ch <= '9') s = s* 10 + ch - '0', ch = getchar();
// return s;
//}
void add(int from, int to, int w){
e[++num_edge].to = to;
e[num_edge].w = w;
e[num_edge].nxt = head[from];
head[from] = num_edge;
}
// Disjoint-set lookup: returns the representative of x's set and re-points
// every vertex on the walked path directly at that representative.
int find(int x){
    int root = x;
    while(fath[root] != root) root = fath[root];
    while(x != root){
        const int up = fath[x];
        fath[x] = root;
        x = up;
    }
    return root;
}
// First tree pass below x. It sets dis[child] = dis[x] + edge weight (mod `mod`)
// for every vertex in x's subtree. It also adds each child's finished siz into
// siz[x], so siz ends up holding subtree sums. `fa` is the vertex x was reached
// from; the edge back to it is skipped.
void dfs(int x, int fa){
    for(int id = head[x]; id != 0; id = e[id].nxt){
        const int child = e[id].to;
        if(child != fa){
            dis[child] = (dis[x] + e[id].w) % mod;
            dfs(child, x);
            siz[x] += siz[child];
        }
    }
}
// Rerooting pass. On entry, f[x] is the sum (mod `mod`) of distances from x to
// every vertex counted in siz, i.e. every vertex whose type differs from
// vertex 1. Moving across edge (x, v, w) to the child v:
//   - the siz[1] - siz[v] counted vertices outside v's subtree get w farther;
//   - the siz[v] counted vertices inside v's subtree get w closer.
void dfs2(int x, int fa){
    for(int i = head[x]; i; i = e[i].nxt){
        int v = e[i].to;
        if(v == fa) continue;
        // Multiply in 64-bit. A count (up to n) times a weight (up to mod - 1)
        // overflows a 32-bit int, and signed overflow is undefined behaviour.
        int jia = (int)((1LL * (siz[1] - siz[v]) * e[i].w % mod + mod) % mod);
        int jian = (int)(1LL * siz[v] * e[i].w % mod);
        f[v] = ((f[x] + jia - jian) % mod + mod) % mod;
        dfs2(v, x);
    }
}
int main()
{
    n = read(), m = read();

    // Normalize colours so that vertex 1 always has type 0: if vertex 1 was
    // given type 1, every type is flipped. Afterwards type[i] == 1 means
    // "different from vertex 1".
    // The flip flag must always be initialized. Reading an uninitialized bool
    // is undefined behaviour, and its stack garbage differed between the cin
    // and read() versions, which is why the answers differed.
    const bool flip = read() != 0;
    type[1] = 0;
    for(int i = 2; i <= n; ++i){
        type[i] = (read() != 0) != flip;
        // Seed subtree counts: after dfs, siz[v] is the number of vertices in
        // v's subtree whose type differs from vertex 1.
        siz[i] = type[i];
    }

    // Scan edges in input order and keep each one that joins two different
    // components. The i-th input edge carries weight 2^i mod `mod`.
    for(int i = 1; i <= n; ++i) fath[i] = i;
    int x = 2, cnt = 0;
    for(int i = 1, u, v; i <= m; ++i){
        u = read(), v = read();
        int uf = find(u), vf = find(v);
        if(uf != vf){
            fath[uf] = vf;
            add(u, v, x), add(v, u, x);
            if(++cnt == n - 1) break;
        }
        x = x * 2 % mod;  // 2 * (mod - 1) still fits in int
    }

    dfs(1, 1);
    // f[1]: sum of distances from vertex 1 to every vertex of the other type.
    for(int i = 1; i <= n; ++i){
        if(type[i]) f[1] = (f[1] + dis[i]) % mod;
    }
    dfs2(1, 1);

    // Each type-0 vertex contributes its distance sum to all type-1 vertices.
    int ans = 0;
    for(int i = 1; i <= n; ++i){
        if(!type[i]) ans = (ans + f[i]) % mod;
    }
    printf("%d", ans);
    return 0;
}