做法和题解第三篇思路差不多,但是过不了样例。
Code:
#include<bits/stdc++.h>
using namespace std;int p,inv2,inv4,inv6;const int Max=4641588;int phi[Max+5],_sum[Max+5];
bool visited[Max+5];int prime[365162],cnt;inline void get_sum(){phi[1]=1;for(register int i=2;i<=Max;++i)
{if(!visited[i]){prime[++cnt]=i;phi[i]=i-1;}int tmp;for(register int j=1;j<=cnt&&(tmp=i*prime[j])<=Max;
++j){visited[tmp]=1;if(i%prime[j]){phi[tmp]=phi[i]*(prime[j]-1);}else{phi[tmp]=phi[i]*prime[j];break;}}}
for(register int i=1;i<=Max;++i){_sum[i]=(_sum[i-1]+1ll*i*i%p*phi[i])%p;}return;}unordered_map<int,int>M;
int f(long long x){if(x<=Max)return _sum[x];if(M[x])return M[x];int res=1ll*x*x%p*(x+1)%p*(x+1)%p*inv4%p;
for(register long long l=2,r;l<=x;l=r+1){r=x/(x/l);
res=(res-((r*r%p*inv2%p)-((l-1)*(l-1)%p*inv2%p)+p)%p*f(x/l)+p)%p;}return res;}
inline int ksm(int x,int y){int res=1;while(y){if(y&1)res=1ll*res*x%p;x=1ll*x*x%p;y>>=1;}return res;}
inline int g(long long x){return x%p*(x+1)%p*(2*x+1)%p*inv6%p;}signed main(){int ans;long long n;
scanf("%d%lld",&p,&n);inv2=ksm(2,p-2);inv4=ksm(4,p-2);inv6=ksm(6,p-2);get_sum();
for(register long long l=1,r;l<=n;l=r+1){r=n/(n/l);ans=(ans+1ll*g((n/l)%p)*(f(r)-f(l-1)+p)%p)%p;}
printf("%d\n",ans);return 0;}
如果您不想看这一坨史山代码,可以看经过格式化后的:
#include <bits/stdc++.h>
using namespace std;
int p, inv2, inv4, inv6;
const int Max = 4641588;
int phi[Max + 5], _sum[Max + 5];
bool visited[Max + 5];
int prime[365162], cnt;
inline void get_sum() {
phi[1] = 1;
for (register int i = 2; i <= Max; ++i) {
if (!visited[i]) {
prime[++cnt] = i;
phi[i] = i - 1;
}
int tmp;
for (register int j = 1; j <= cnt && (tmp = i * prime[j]) <= Max; ++j) {
visited[tmp] = 1;
if (i % prime[j]) {
phi[tmp] = phi[i] * (prime[j] - 1);
} else {
phi[tmp] = phi[i] * prime[j];
break;
}
}
}
for (register int i = 1; i <= Max; ++i) {
_sum[i] = (_sum[i - 1] + 1ll * i * i % p * phi[i]) % p;
}
return;
}
unordered_map<int, int> M;
int f(long long x) {
if (x <= 4641588)
return _sum[x];
if (M[x])
return M[x];
int res = 1ll * x * x % p * (x + 1) % p * (x + 1) % p * inv4 % p;
for (register long long l = 2, r; l <= x; l = r + 1) {
r = x / (x / l);
res =
(res - ((r * r % p * inv2 % p) - ((l - 1) * (l - 1) % p * inv2 % p) + p) % p * f(x / l) + p) % p;
}
return res;
}
inline int ksm(int x, int y) {
int res = 1;
while (y) {
if (y & 1)
res = 1ll * res * x % p;
x = 1ll * x * x % p;
y >>= 1;
}
return res;
}
inline int g(long long x) { return x * (x + 1) % p * (2 * x + 1) % p * inv6 % p; }
signed main() {
int ans;
long long n;
scanf("%d%lld", &p, &n);
inv2 = ksm(2, p - 2);
inv4 = ksm(4, p - 2);
inv6 = ksm(6, p - 2);
get_sum();
for (register long long l = 1, r; l <= n; l = r + 1) {
r = n / (n / l);
ans = (ans + 1ll * g((n / l) % p) * (f(r) - f(l - 1) + p) % p) % p;
}
printf("%d\n", ans);
return 0;
}