我给每条长链分配 2 倍长度的空间就能 AC,只分配 1 倍空间却只有 40 分;可理论上 1 倍空间应该就够了,这是为什么?
// AC代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Vertex count, and the running answer accumulated by dfs_dp and printed by main.
ll n;
ll ans;
// Undirected adjacency lists (vertices are used as indices; vertex 0 serves as
// the root's parent sentinel in the calls made from main).
vector<ll> eg[5005];
// len[p]: number of vertices on the longest downward path starting at p
//   (len[0] stays 0, which dfs_build relies on for "no heavy child").
// son[p]: heavy child of p (child with the largest len), 0 if p is a leaf.
ll len[5005];
ll son[5005];
// Bump-allocated pool for dp; dp[p] points into buf1, t1 is the next free cell.
// dp pointers move left along a chain, so each chain needs 2*len[top] cells.
// NOTE(review): 10005 covers 2*n only for n <= 5000 -- confirm the bound.
ll buf1[10005];
ll *dp[5005];
ll *t1=buf1;
// Bump-allocated pool for g; g[p] points into buf2, t2 is the next free cell.
ll buf2[10005];
ll *g[5005];
ll *t2=buf2;
// First DFS pass over the subtree of p (whose parent is fa). It computes:
//   son[p] = child with the largest len; on ties the earliest child in eg[p]
//            is kept, and it stays 0 for a leaf.
//   len[p] = 1 + len[son[p]], i.e. vertices on the longest downward path from p.
// Relies on len[0] == 0 so that a leaf ends up with len 1.
void dfs_build(ll fa,ll p){
    for(auto it=eg[p].begin();it!=eg[p].end();++it){
        const ll child=*it;
        if(child!=fa){
            dfs_build(p,child);
            // strict comparison: an equally long later child does not replace son[p]
            if(len[child]>len[son[p]]) son[p]=child;
        }
    }
    len[p]=1+len[son[p]];
}
// Long-chain (long-path) DP over the subtree of p, whose parent is fa.
// Precondition: dp[p] and g[p] already point into reserved pool storage
// (set by the caller for chain tops, or inherited below for heavy children).
// g[p][i]: number of vertices in p's subtree exactly i edges below p
//   (g[p][0]=1 is p itself; each child c contributes g[c][i-1]).
// dp[p][i]: per-distance pair counts. From the merge rules below, it appears to
//   be the number of unordered pairs in p's subtree that one more vertex at
//   distance i from p would complete into a pairwise-equidistant triple.
//   NOTE(review): interpretation inferred from the recurrences -- confirm
//   against the problem statement.
// Every completed triple is added to the global ans.
void dfs_dp(ll fa,ll p){
g[p][0]=1;
if(son[p]){
// The heavy child reuses p's storage shifted by one level, so p inherits its
// arrays in O(1): g[p][i]==g[son][i-1] and dp[p][i]==dp[son][i+1].
dp[son[p]]=dp[p]-1;
g[son[p]]=g[p]+1;
dfs_dp(p,son[p]);
// inherited pairs whose completing vertex is p itself
ans+=dp[p][0];
}
for(ll v:eg[p]){
if(v==fa||v==son[p])continue;
// Light child v starts a new chain with fresh slices.
// dp needs 2*len[v] cells: going down the chain moves the dp pointer LEFT
// (dp[son]=dp[p]-1), yet each vertex p on it still uses dp[p][0..len[p]-1].
// The chain therefore occupies [base+1, base+2*len[v]-1], and dp[v] is placed
// in the middle. Reserving only len[v] cells for dp makes the chain's upper
// entries spill into the next chain's slice, so two chains share counters.
// g moves RIGHT (g[son]=g[p]+1) and only spans [base, base+len[v]-1], so len[v]
// cells would suffice for g; the second half reserved here is never touched.
dp[v]=t1+len[v],t1+=len[v]<<1;
g[v]=t2,t2+=len[v]<<1;
dfs_dp(p,v);
// Merge v into p. Statement order matters: the two ans updates and the dp
// update must read dp[p]/g[p] before v's contributions are folded in.
// pairs already on p's side, completed by a vertex of v's subtree i below p
for(ll i=1;i<=len[v];i++)ans+=dp[p][i]*g[v][i-1];
// pairs inside v's subtree, completed by a vertex on p's side i-1 below p
for(ll i=1;i<len[v];i++)ans+=g[p][i-1]*dp[v][i];
// new pairs: one vertex from p's side, one from v's subtree, both i below p
for(ll i=1;i<=len[v];i++)dp[p][i]+=g[p][i]*g[v][i-1];
// v's pending pairs are one step closer when viewed from p
for(ll i=1;i<len[v];i++)dp[p][i-1]+=dp[v][i];
// finally add v's depth counts to p's
for(ll i=1;i<=len[v];i++)g[p][i]+=g[v][i-1];
}
}
signed main(){
ios::sync_with_stdio(false);
cin>>n;
for(ll i=1;i<n;i++){
ll x,y;cin>>x>>y;
eg[x].push_back(y);
eg[y].push_back(x);
}
dfs_build(0,1);
dp[1]=t1+len[1],t1+=len[1]<<1;
g[1]=t2,t2+=len[1]<<1;
dfs_dp(0,1);
cout<<ans;
return 0;
}
// 40pts 代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Vertex count, and the running answer accumulated by dfs_dp and printed by main.
ll n;
ll ans;
// Undirected adjacency lists (vertices are used as indices; vertex 0 serves as
// the root's parent sentinel in the calls made from main).
vector<ll> eg[5005];
// len[p]: number of vertices on the longest downward path starting at p
//   (len[0] stays 0, which dfs_build relies on for "no heavy child").
// son[p]: heavy child of p (child with the largest len), 0 if p is a leaf.
ll len[5005];
ll son[5005];
// Bump-allocated pool for dp; dp[p] points into buf1, t1 is the next free cell.
// Because dp pointers move left along a chain (dp[son]=dp[p]-1), each chain
// needs 2*len[top] cells, i.e. 2*n in total.
// NOTE(review): 10005 covers 2*n only for n <= 5000 -- confirm the bound.
ll buf1[10005];
ll *dp[5005];
ll *t1=buf1;
// Bump-allocated pool for g; g[p] points into buf2, t2 is the next free cell.
// g pointers move right along a chain, so len[top] cells per chain suffice.
ll buf2[10005];
ll *g[5005];
ll *t2=buf2;
// First DFS pass over the subtree of p (whose parent is fa). It computes:
//   son[p] = child with the largest len; on ties the earliest child in eg[p]
//            is kept, and it stays 0 for a leaf.
//   len[p] = 1 + len[son[p]], i.e. vertices on the longest downward path from p.
// Relies on len[0] == 0 so that a leaf ends up with len 1.
void dfs_build(ll fa,ll p){
    for(auto it=eg[p].begin();it!=eg[p].end();++it){
        const ll child=*it;
        if(child!=fa){
            dfs_build(p,child);
            // strict comparison: an equally long later child does not replace son[p]
            if(len[child]>len[son[p]]) son[p]=child;
        }
    }
    len[p]=1+len[son[p]];
}
void dfs_dp(ll fa,ll p){
g[p][0]=1;
if(son[p]){
dp[son[p]]=dp[p]-1;
g[son[p]]=g[p]+1;
dfs_dp(p,son[p]);
ans+=dp[p][0];
}
for(ll v:eg[p]){
if(v==fa||v==son[p])continue;
dp[v]=t1+len[v],t1+=len[v];
g[v]=t2,t2+=len[v];
dfs_dp(p,v);
for(ll i=1;i<=len[v];i++)ans+=dp[p][i]*g[v][i-1];
for(ll i=1;i<len[v];i++)ans+=g[p][i-1]*dp[v][i];
for(ll i=1;i<=len[v];i++)dp[p][i]+=g[p][i]*g[v][i-1];
for(ll i=1;i<len[v];i++)dp[p][i-1]+=dp[v][i];
for(ll i=1;i<=len[v];i++)g[p][i]+=g[v][i-1];
}
}
signed main(){
ios::sync_with_stdio(false);
cin>>n;
for(ll i=1;i<n;i++){
ll x,y;cin>>x>>y;
eg[x].push_back(y);
eg[y].push_back(x);
}
dfs_build(0,1);
dp[1]=t1+len[1],t1+=len[1];
g[1]=t2,t2+=len[1];
dfs_dp(0,1);
cout<<ans;
return 0;
}