洛谷 P2607 [ZJOI2008]骑士

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <iostream>
using namespace std;
// Shared state for the first solution. Vertices and edge slots are numbered
// from 1; the value 0 means "no edge" in the adjacency lists.
using LL = long long;

// Number of elements in every array below (valid indices are 0 .. N - 1).
constexpr int N = 1000002;

int n;      // vertex count read by main()
int root;   // vertex that dp() will not descend into; chosen by dfs()

int h[N];   // h[u]: newest edge slot leaving u, 0 when u has no edge
int e[N];   // e[i]: vertex that edge slot i points to
int ne[N];  // ne[i]: edge slot that preceded i in the same list, 0 at the end
int idx;    // last edge slot handed out by add()

int w[N];   // w[v]: value read for vertex v
int fa[N];  // fa[v]: the vertex u read together with w[v] (edge u -> v)

bool vis[N];  // set by dfs()/dp() once a vertex has been touched

LL f[N][2];   // dp() results; [0] and [1] are the two states per vertex
LL ans = 0;   // sum of the per-component results accumulated by dfs()
// Records the directed edge u -> v by pushing a new slot onto the front of
// u's adjacency list. Slots are taken from 1 upwards so that 0 can serve as
// the end-of-list marker.
void add(int u, int v) {
    const int slot = ++idx;
    e[slot] = v;
    ne[slot] = h[u];
    h[u] = slot;
}
// Fills f[u][0] and f[u][1] for u and everything reachable from it through
// the adjacency lists, marking each visited vertex in vis[].
//
//   f[u][1] starts at w[u] and gains f[v][0] for every processed child v;
//   f[u][0] starts at 0    and gains max(f[v][0], f[v][1]).
//
// An edge that leads back to the global `root` is never followed. Instead
// f[root][1] is overwritten with -N at that moment, so the value the caller
// later reads from f[root][1] is -N plus whatever root adds afterwards.
// The recursion goes one level deeper per edge followed.
void dp(int u) {
    vis[u] = true;
    f[u][0] = 0;
    f[u][1] = w[u];

    for (int edge = h[u]; edge != 0; edge = ne[edge]) {
        const int child = e[edge];

        if (child == root) {
            // Closing edge: penalise the "take root" state, do not recurse.
            f[child][1] = -N;
            continue;
        }

        dp(child);
        f[u][0] += max(f[child][0], f[child][1]);
        f[u][1] += f[child][0];
    }
}
void dfs(int u) {
vis[u] = true;
root = u;
while (!vis[fa[root]]) {
root = fa[root];
vis[root] = true;
}
dp(root);
LL tmp = max(f[root][0], f[root][1]);
vis[root] = true;
root = fa[root];
dp(root);
ans += max(tmp, max(f[root][0], f[root][1]));
}
int main() {
scanf("%d", &n);
for (int v = 1, u; v <= n; ++v) {
scanf("%d%d", &w[v], &u);
add(u, v);
fa[v] = u;
}
for (int i = 1; i <= n; ++i) {
if (vis[i]) continue;
dfs(i);
}
printf("%lld\n", ans);
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
using namespace std;

// Shared state for the second solution. Vertices and edge slots are numbered
// from 1; the value 0 means "no edge" in the adjacency lists.
using LL = long long;
using PII = pair<int, int>;

constexpr int N = 1000002;  // size of the per-vertex arrays
constexpr int M = 2 * N;    // size of the edge arrays (add() is called twice per stored edge)

int n;     // vertex count read by main()
int c;     // number of entries used in root[]
int p[N];  // union-find parent links maintained by find()

// Pairs (u, v) that main() did not insert as edges because u and v were
// already in the same set.
PII root[N / 2];

int h[N];   // h[u]: newest edge slot leaving u, 0 when u has no edge
int w[N];   // w[u]: value read for vertex u
int e[M];   // e[i]: vertex that edge slot i points to
int ne[M];  // ne[i]: edge slot that preceded i in the same list, 0 at the end
int idx;    // last edge slot handed out by add()

LL f[N][2];  // dp() results; [0] and [1] are the two states per vertex
LL ans;      // running total printed by main()
// Returns the representative of the set containing x and makes every vertex
// on the path from x to that representative point at it directly (path
// compression).
//
// Written as two loops instead of a recursion: sets are merged without any
// rank/size heuristic, so a parent chain can be as long as the number of
// vertices, and one stack frame per link could exhaust the call stack.
// The resulting p[] is the same as with recursive compression.
int find(int x) {
    // Pass 1: locate the representative.
    int rep = x;
    while (p[rep] != rep) {
        rep = p[rep];
    }

    // Pass 2: re-point every vertex on the path at the representative.
    while (p[x] != rep) {
        const int parent = p[x];
        p[x] = rep;
        x = parent;
    }
    return rep;
}
// Pushes the directed edge u -> v onto the front of u's adjacency list.
// The slot counter is advanced first, so slot 0 is never used and can act as
// the end-of-list marker.
void add(int u, int v) {
    ++idx;
    e[idx] = v;
    ne[idx] = h[u];
    h[u] = idx;
}
// Computes f[u][0] and f[u][1] for u and for every vertex reachable from u
// without stepping straight back to `fa`, then returns f[u][0].
//
//   f[u][0] starts at 0    and gains max(f[v][0], f[v][1]) per child v;
//   f[u][1] starts at w[u] and gains f[v][0] per child v.
//
// `fa` is the vertex the call came from (callers pass 0 for the top-level
// call). The recursion goes one level deeper per edge followed.
LL dp(int u, int fa) {
    f[u][0] = 0;
    f[u][1] = w[u];

    for (int edge = h[u]; edge != 0; edge = ne[edge]) {
        const int child = e[edge];
        if (child != fa) {
            dp(child, u);
            f[u][0] += max(f[child][0], f[child][1]);
            f[u][1] += f[child][0];
        }
    }

    return f[u][0];
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) p[i] = i;
for (int u = 1, v; u <= n; ++u) {
scanf("%d%d", &w[u], &v);
if (find(u) != find(v)) {
p[find(u)] = find(v);
add(u, v), add(v, u);
} else {
root[c++] = {u, v};
}
}
for (int i = 0; i < c; ++i) {
auto [u, v] = root[i];
ans += max(dp(u, 0), dp(v, 0));
}
printf("%lld\n", ans);
return 0;
}