个人代码:
#include<bits/stdc++.h>
using namespace std;
const int N=100;
const int M=200;
int n,m;
vector<int> v[N+5];
int w[N+5];
int dp[N+5][M+5][2];
void dfs(int node,int father)
{
for(auto son:v[node])
{
if(son==father) continue;
dfs(son,node);
}
for(int j=m;j>=0;j--) dp[node][j][0]=dp[node][j][1]=w[node];
for(auto son:v[node])
{
if(son==father) continue;
for(int j=m;j>=0;j--)
{
dp[node][j][1]=max(dp[node][j][1],dp[node][0][0]+dp[son][j-1][1]);
for(int k=j-2;k>=0;k--)
{
dp[node][j][0]=max(dp[node][j][0],dp[node][j-k-2][0]+dp[son][k][0]);
dp[node][j][1]=max(dp[node][j][1],dp[node][j-k-1][0]+dp[son][k][1]);
}
}
}
}
void Clear()
{
for(int i=1;i<=n;i++)
{
v[i].clear();
}
}
void solve()
{
cin>>m;
Clear();
for(int i=1;i<=n;i++)
{
cin>>w[i];
}
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1,0);
cout<<max(dp[1][m][0],dp[1][m][1])<<"\n";
}
signed main()
{
while(scanf("%d",&n)!=EOF)
{
solve();
}
return 0;
}