记录出发点个数的数组开不开long long的问题
查看原帖
记录出发点个数的数组开不开long long的问题
449265
wind_whisper楼主2021/2/23 00:11

我记录出发点个数的数组不开longlong只能有40分,开longlong就切了...可是出发点不应该最多只有N*N个吗?为什么会爆int呢?迷惑ing qaq

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=805;
#define ll long long
int m,t,n;
ll a,b;
ll cnt=0;
ll tot=0;
ll fa[N*N],num[N*N];
ll keynum[N*N];//keynum[i]表示以i点为根节点的集合中出发点的个数
ll mx=0;
ll high[N][N],jd[N][N];
struct node{
	ll x,y,v;
}p[N*N];
bool cmp(node a,node b){
  	return a.v<b.v;
}
int hash(int x,int y){
  	return N*x+y;
}
int find(int x){
  	if(fa[x]==x) return x;
	else return fa[x]=find(fa[x]);
}
void merge(int x,int y){//把y接到x上
  	if(find(x)==find(y)) return;
  	num[fa[x]] += num[fa[y]];
  	keynum[fa[x]] += keynum[fa[y]];
  	fa[fa[y]]=fa[x];
//  printf()
}
int main(){
//	freopen("P3101_2.in","r",stdin);
//  freopen("a.out","w",stdout);
	scanf("%lld%lld%lld",&m,&n,&t);
	for(int i=1;i<=m;i++){
		for(int j=1;j<=n;j++){
			scanf("%lld",&high[i][j]);
    		num[hash(i,j)]=1;
      		fa[hash(i,j)]=hash(i,j);
    	}
  	}
  	for(int i=1;i<=m;i++){
    	for(int j=1;j<=n;j++){
      		scanf("%lld",&jd[i][j]);
      		keynum[hash(i,j)]=jd[i][j];
      		if(jd[i][j]) tot++;
    	}
	}
  	for(int i=1;i<=m-1;i++){//记录边(写的有些啰嗦。。。) 
    	for(int j=1;j<=n-1;j++){
      		p[++cnt]=(node){hash(i,j),hash(i,j+1),abs(high[i][j]-high[i][j+1])};//右边的边
      		p[++cnt]=(node){hash(i,j),hash(i+1,j),abs(high[i][j]-high[i+1][j])};//下边的边
    	}
    	p[++cnt]=(node){hash(i,n),hash(i+1,n),abs(high[i][n]-high[i+1][n])};
  	}
  	for(int j=1;j<=n-1;j++){
    	p[++cnt]=(node){hash(m,j),hash(m,j+1),abs(high[m][j]-high[m][j+1])};
  	}
  	sort(p+1,p+1+cnt,cmp);
  	long long ans=0;
//  printf("hello,world! %d\n",tot);
//  for(int i=1;i<=cnt;i++){
//	    printf("%d %d %d %d %d\n",p[i].x/N,p[i].x%N,p[i].y/N,p[i].y%N,p[i].v);
//  }
//  for(int i=1;i<=n;i++){
//    	for(int j=1;j<=n;j++){
//      		printf("%d ",keynum[N*i+j]);
//    	}
//  	printf("\n");
//  }
  	for(int i=1;tot;i++){//不断加边到tot减到0为止
    	a=find(p[i].x);
    	b=find(p[i].y);
    	merge(a,b);
//    	printf("x1=%d y1=%d x2=%d y2=%d v=%d num=%d keynum=%d\n",
//		p[i].x/N,p[i].x%N,p[i].y/N,p[i].y%N,p[i].v,num[a],keynum[a]);
    	if(num[a]>=t){
      		ans += keynum[a] * p[i].v;
      		tot -= keynum[a];
      		keynum[a]=0;
       	}
  	}
  	printf("%lld",ans);
	return 0;
}
/*
2 2 3
10 0
30 10
1 0
0 0
*/
2021/2/23 00:11
加载中...