RT（如题）：这份代码只能拿到 80pts /kk
代码：
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;

// 64-bit type for path lengths (sums of many edge lengths).
using ll = long long;

// One stored edge. For edge1, `nxt` links the edges leaving the same vertex
// (0 terminates the list). `end` is the far endpoint and `dis` the length.
// `start` is only filled for the cycle-closing edges kept in edge2.
struct Edge {
    int nxt;
    int start;
    int end;
    int dis;
};

const int kMaxN = 1000007;    // per-vertex arrays
const int kMaxLen = 2000007;  // doubled-cycle arrays and both directions of tree edges

int tree_edge_cnt = 0, loop_edge_cnt = 0;
int root[kMaxN], head[kMaxN], fa[kMaxN], dot[kMaxN], id[kMaxN];
ll dis[kMaxN];
ll dis_[kMaxLen];
ll dp[kMaxN];
ll dp_[kMaxLen], sum[kMaxLen], temp[kMaxLen];
bool mark[kMaxN];
Edge edge1[kMaxLen], edge2[kMaxN];
deque<int> q;
// Resets the union-find forest: every vertex 1..n becomes its own representative.
inline void init(int n){
    for (int v = n; v >= 1; --v) root[v] = v;
}
// Reads the next non-negative decimal integer from stdin. Any non-digit
// characters before it are skipped; signs are not recognised.
// Returns 0 if end of input is reached before a digit is seen. The previous
// version looped forever there: it kept `ch` in a `char` and never tested EOF.
inline int read(){
    int ch = getchar();  // int, so EOF stays distinguishable from every byte
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    int ans = 0;
    while (ch >= '0' && ch <= '9'){
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans;
}
// Returns the union-find representative of x. Every vertex on the walked path
// is re-pointed directly at the representative (full path compression).
// Written iteratively: unions are not by rank or size, so a chain can grow to
// O(n) length before it is first compressed (e.g. edges i -> i+1). A
// recursive walk that deep may overflow the stack for n around 10^6.
int get_root(int x){
    int rep = x;
    while (root[rep] != rep) rep = root[rep];
    while (root[x] != rep){
        const int next = root[x];
        root[x] = rep;
        x = next;
    }
    return rep;
}
inline void add_tree_edge(int start, int end, int dis){
tree_edge_cnt++;
edge1[tree_edge_cnt].nxt = head[start];
head[start] = tree_edge_cnt;
edge1[tree_edge_cnt].end = end;
edge1[tree_edge_cnt].dis = dis;
}
inline void add_loop_edge(int start, int end, int dis){
loop_edge_cnt++;
edge2[loop_edge_cnt].start = start;
edge2[loop_edge_cnt].end = end;
edge2[loop_edge_cnt].dis = dis;
}
// Roots the tree stored in edge1 at u. For every other vertex x it reaches, it
// sets fa[x] to x's parent and dis[x] to the length of the edge to that parent.
// fa[u] = father; dis[u] is left untouched.
// Uses an explicit stack instead of recursion, so a path-shaped component with
// ~10^6 vertices cannot overflow the call stack. edge1 holds a forest, so the
// visiting order does not affect the resulting fa/dis values.
void dfs1(int u, int father){
    static std::vector<int> pending;  // reused between calls to avoid reallocations
    pending.clear();
    fa[u] = father;
    pending.push_back(u);
    while (!pending.empty()){
        const int x = pending.back();
        pending.pop_back();
        for (int i = head[x]; i != 0; i = edge1[i].nxt){
            const int y = edge1[i].end;
            if (y != fa[x]){
                fa[y] = x;
                dis[y] = edge1[i].dis;
                pending.push_back(y);
            }
        }
    }
}
ll dfs2(int u, int father){
for (register int i = head[u]; i != 0; i = edge1[i].nxt){
int x = edge1[i].end;
if (x != father && !mark[x]) dp[u] = max(dp[u], dfs2(x, u) + edge1[i].dis);
}
return dp[u];
}
int main(){
int n = read();
ll ans = 0;
init(n);
for (register int i = 1; i <= n; i++){
int u = read(), l = read();
if (i == u) l = 0;
if (get_root(i) != get_root(u)){
root[root[i]] = root[u];
add_tree_edge(i, u, l);
add_tree_edge(u, i, l);
} else {
add_loop_edge(i, u, l);
}
}
for (register int i = 1; i <= loop_edge_cnt; i++){
int u = edge2[i].start, v = edge2[i].end, dot_cnt = 0, pos = 0, t;
ll cur_ans = 0;
dfs1(u, 0);
for (register int j = v; j != 0; j = fa[j]){
dot_cnt++;
dot[dot_cnt] = j;
id[j] = dot_cnt;
mark[j] = true;
}
dis[u] = edge2[i].dis;
t = dot_cnt * 2;
for (register int j = 1; j <= dot_cnt; j++){
dfs2(dot[j], 0);
dis_[j + dot_cnt] = dis_[j] = dis[dot[j]];
dp_[j] = dp_[j + dot_cnt] = dp[dot[j]];
}
for (register int j = 1; j <= t; j++){
sum[j] = sum[j - 1] + dis_[j - 1];
temp[j] = dp_[j] - sum[j];
}
for (register int j = 1; j <= dot_cnt; j++){
ll max_val = 0, second_max_val = 0;
for (register int k = head[dot[j]]; k != 0; k = edge1[k].nxt){
ll t = dp[edge1[k].end] + edge1[k].dis;
if (max_val < t){
second_max_val = max_val;
max_val = t;
} else if (second_max_val < t){
second_max_val = t;
}
}
cur_ans = max(cur_ans, max_val + second_max_val);
}
while (!q.empty()) q.pop_back();
for (register int j = 1; j <= t; j++){
while (!q.empty() && temp[j] >= temp[q.back()]) q.pop_back();
if (j >= dot_cnt){
pos++;
while (!q.empty() && q.front() < pos) q.pop_front();
}
if (!q.empty()) cur_ans = max(cur_ans, dp_[j] + sum[j] + temp[q.front()]);
q.push_back(j);
}
ans += cur_ans;
for (register int j = v; j != 0; j = fa[j]) mark[j] = false;
}
cout << ans;
return 0;
}