显然可以想到 O(n2) 的树形 dp(枚举谁是根之后遍历树),显然需要优化。
根据经验,CCF 的出题数据一半都是随机的,树的层数应该不会太小,所以换根是枚举变化的路径不多,考虑记忆化路径,下次走到相同的路径直接返回,不用再递归了。
实测快的飞起。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10;
int mod=1e9;
vector<int>e[N];
int dp[N];
int k[N];
unordered_map<int,int>mp[N];
void init(){
k[0]=1;
for(int i=1;i<N;i++){
k[i]=k[i-1]*i;
k[i]%=mod;
}
return;
}
void dfs(int id,int fa){
dp[id]=1;
int cnt=0;
//cout<<id<<endl;
for(int to:e[id]){
if(to==fa)continue;
int f=mp[id][to];
if(f){//直接调取
dp[id]*=f;
dp[id]%=mod;
}
else{
dfs(to,id);
dp[id]*=dp[to];
dp[id]%=mod;
mp[id][to]=dp[to];//记忆路径
}
cnt++;
}
//cout<<id<<endl;
dp[id]*=k[cnt];
dp[id]%=mod;
return;
}
signed main(){
init();
int n;
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
int ans=0;
for(int i=1;i<=n;i++){
dfs(i,0);
ans+=dp[i];
ans%=mod;
}
printf("%lld\n",ans);
}
~虽然经我在考场测试得,1e5的菊花图就炸了。~
如果是菊花图就会退化成 O(n2logn)。
但是我们不能拘泥于水数据里,必须找其他方法。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10;
int mod=1e9;
vector<int>e[N];
int dp[N];
int k[N];
unordered_map<int,int>mp[N];
unordered_map<int,bool>t[N];
void init(){
k[0]=1;
for(int i=1;i<N;i++){
k[i]=k[i-1]*i;
k[i]%=mod;
}
return;
}
int exgcd(int a,int b,int &x,int &y){
if(b==0){
x=1,y=0;
return a;
}
int ret=exgcd(b,a%b,y,x);
y-=a/b*x;
return ret;
}
int getInv(int a,int p){
int x,y;
int d=exgcd(a,p,x,y);
while(x<0)x+=p/d;
x%=p/d;
return x;
}
void dfs(int id,int fa){
dp[id]=1;
int cnt=0;
for(int to:e[id]){
if(to==fa)continue;
int f=t[id][to];
if(f){//直接调取
dp[id]*=mp[id][to];
dp[id]%=mod;
}
else{
dfs(to,id);
dp[id]*=dp[to];
dp[id]%=mod;
mp[id][to]=dp[to];//记忆路径
t[id][to]=1;
}
cnt++;
}
int p2=dp[id];
if(cnt)p2=(p2*k[cnt-1])%mod;
dp[id]*=k[cnt];
dp[id]%=mod;
for(int to:e[id]){
if(to==fa)continue;
int f=t[to][id];
if(f)continue;
int p=(p2*getInv(dp[to],mod))%mod;
p%=mod;
mp[to][id]=p;
t[to][id]=1;
}
return;
}
signed main(){
init();
int n;
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
int ans=0;
for(int i=1;i<=n;i++){
dfs(i,0);
ans+=dp[i];
ans%=mod;
}
printf("%lld\n",ans);
return 0;
}
这个不炸了。
欢迎编程界的各位大佬 hack 我的代码,如题解有任何问题请联系我。
ps:数学不够,dp 来凑。