代码:
#include<bits/stdc++.h>
using namespace std;
// Capacity shared by every global table below; any value or index used to
// subscript them must stay strictly below this.
constexpr int MAXN = 1000005;

// Flags consulted by main() before handing out a compressed index.
bool had[MAXN];
// n: number of values read from input; m: upper bound the values are compared against.
int n, m;
// num[k]: occurrence count for compressed index k; sorted by main() and read by solve().
int num[MAXN];
// zhq[v]: compressed index assigned to input value v.
int zhq[MAXN];
// Number of input values equal to 0.
int ze;
// Highest compressed index in use; main() sorts num[1..sum] and calls solve(sum, ze).
int sum;
// Counts chains of strictly decreasing indices that start at s, where the
// successive differences num[prev] - num[next] add up to exactly z. A chain is
// counted, and not extended further, as soon as its running total hits z.
//
// Expects num[1..s] sorted ascending (main() arranges this with sort). Walking
// the index downward then only makes the gap grow, so the scan stops at the
// first gap that already exceeds the remaining budget z.
int solve(int s, int z)
{
    int ways = 0;
    int i = s - 1;
    while (i >= 1)
    {
        const int gap = num[s] - num[i];
        if (gap > z)
            break;                        // every smaller i has an even larger gap
        ways += (gap == z) ? 1            // chain closes exactly here
                           : solve(i, z - gap);  // spend gap, keep extending from i
        --i;
    }
    return ways;
}
signed main()
{
cin >> n >> m;
int in;
for(int i = 1;i <= n;i++)
{
cin >> in;
if(in > m) had[in] = true;
if(in == 0) ze++;
else if(!had[zhq[in]]) zhq[in] = sum,had[zhq[in]] = true,num[sum]++;
else num[zhq[in]]++;
}
sort(num + 1,num + sum + 1);
cout << solve(sum,ze) << '\n';
return 0;
}
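zhq 是"值 → 紧凑下标"的转换表（离散化）：每个不同的非零输入值被映射到一个紧凑下标，这样 num 只需按不同值的个数占用空间，不会因为值本身很大而越界；比 m 大的输入直接忽略。注意 zhq 本身仍按输入值下标访问，所以输入值也必须小于数组大小。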