代码如下:
#include <cstdio>
using namespace std;
long long max(long long x,long long y)
{
if (x>y) return x;
return y;
}
long long min(long long x,long long y)
{
if (x>y) return y;
return x;
}
int n,a,i;
long long mod=998244353;
long long t,tj[200001*2],maxt,mint,ans;
int main()
{
scanf("%d",&n);
mint=2e9;
for (i=1;i<=n;i++)
{
scanf("%d",&a);
maxt=max(a,maxt);
mint=min(a,mint);
tj[a]++;
}
for (i=mint;i<=maxt;i++) tj[i]+=tj[i-1];
for (i=mint;i<=maxt;i++)
{
t=tj[i]-tj[i-1];
ans+=t*(t-1)/2%mod*(tj[min(i*2-1,maxt)]-t)%mod;
ans+=t*(t-1)*(t-2)/6%mod;
}
printf("%lld",ans);
}