Code:
#include<bits/stdc++.h>
using namespace std;
// Shared state for the tree DP below. All arrays are indexed by node id
// (1..n) or by edge slot; capacity presumably covers n <= 6000 — confirm
// against the problem bounds.
const int MAX = 6000 + 5;
int h[MAX];      // h[u]: slot in e[] of the first edge leaving u (0 = no edges)
int cnt = 1;     // next free slot in e[]; slot 0 is kept as the list terminator
int n;           // number of nodes read from input
int r[MAX];      // r[i]: weight read for node i
int f[MAX][2];   // f[u][0] / f[u][1]: best subtree total with u excluded / included
int root;        // node that never appeared as a child
bool vis[MAX];   // vis[i]: node i was seen as a child of some other node
// Reads one decimal integer (optionally preceded by '-') from stdin.
// Any characters that are neither digits nor '-' are skipped first.
// Returns 0 if end of input is reached before a digit is found.
inline int read()
{
    int x = 0;
    bool neg = false;
    // Keep getchar()'s result as int: stored in a char, EOF may be
    // indistinguishable from a valid byte (e.g. where char is unsigned).
    int ch = getchar();
    // Skip separators, but stop at EOF; without this check the loop would
    // spin forever once the input is exhausted.
    while (ch != EOF && (ch < '0' || ch > '9') && ch != '-')
        ch = getchar();
    if (ch == '-')
    {
        neg = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -x : x;
}
// One edge of the adjacency lists kept in e[]. Edges sharing a source are
// chained through `next`; the chain ends at slot 0.
struct line
{
    int v;     // endpoint the edge leads to
    int next;  // slot in e[] of the next edge with the same source (0 = end)
};
line e[MAX];
// Adds a directed edge u -> v by pushing it onto the front of u's edge list.
// The new edge takes slot cnt in e[], which then advances by one.
void Add(int u, int v)
{
    // Standard aggregate initialization. The original `(line){...}` compound
    // literal is a GNU extension, not standard C++.
    e[cnt] = line{v, h[u]};
    h[u] = cnt++;
}
// Post-order DP over the subtree rooted at u.
//   f[u][0]: best total when u is left out  (each child may be in or out)
//   f[u][1]: best total when u is taken     (every child must be left out)
void dfs(int u)
{
    int without_u = 0;
    int with_u = r[u];
    for (int slot = h[u]; slot != 0; slot = e[slot].next)
    {
        const int child = e[slot].v;
        dfs(child);
        without_u += std::max(f[child][0], f[child][1]);
        with_u += f[child][0];
    }
    f[u][0] = without_u;
    f[u][1] = with_u;
}
// Input, as consumed below: n; then n node weights r[1..n]; then pairs "u v"
// meaning v is the parent of u. Prints the best total reported by dfs for the
// parentless node.
int main()
{
    n = read();
    for (int i = 1; i <= n; i++)
        r[i] = read();
    // A rooted tree on n nodes has exactly n - 1 parent links. Reading only
    // those leaves any trailing "0 0" sentinel unconsumed. It also avoids
    // reading past end of input when no sentinel is present.
    for (int i = 1; i < n; i++)
    {
        int u = read(), v = read();
        if (u < 1 || u > n || v < 1 || v > n)
            continue;  // ignore malformed pairs instead of indexing out of range
        Add(v, u);
        vis[u] = 1;
    }
    // The root is the node never listed as a child (the last one wins if several).
    for (int i = 1; i <= n; i++)
        if (!vis[i])
            root = i;
    dfs(root);
    cout << max(f[root][0], f[root][1]) << '\n';
    return 0;
}