1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

namespace {

using LL = long long;

// One directed half of an undirected tree edge.
struct Edge {
    int to;
    LL weight;
};

// Given a tree on nodes 1..n (adj[0] and c[0] unused) where node i carries
// weight c[i], returns  min over r of  sum_i c[i] * dist(r, i),
// i.e. the cheapest total weighted distance to a single meeting node.
//
// Runs in O(n) using an explicit BFS order instead of recursion, so very deep
// (path-shaped) trees cannot overflow the call stack.
// NOTE(review): assumes every intermediate sum fits in long long — confirm
// against the problem's value bounds.
LL min_total_weighted_distance(const std::vector<std::vector<Edge>>& adj,
                               const std::vector<LL>& c) {
    const int n = static_cast<int>(adj.size()) - 1;
    if (n <= 0) {
        return 0;
    }

    // BFS from node 1: afterwards every parent appears in `order` before its
    // children. parent[1] == 0, which is never a valid node id.
    std::vector<int> parent(n + 1, 0);
    std::vector<LL> up_len(n + 1, 0);  // length of edge (v, parent[v])
    std::vector<int> order;
    order.reserve(n);
    order.push_back(1);
    for (std::size_t k = 0; k < order.size(); ++k) {
        const int u = order[k];
        for (const Edge& e : adj[u]) {
            if (e.to != parent[u]) {
                parent[e.to] = u;
                up_len[e.to] = e.weight;
                order.push_back(e.to);
            }
        }
    }

    // Children before parents:
    //   sz[u]   = total node weight inside u's subtree
    //   cost[u] = sum over x in subtree(u) of c[x] * dist(u, x)
    std::vector<LL> sz(c.begin(), c.end());
    std::vector<LL> cost(n + 1, 0);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        const int p = parent[v];
        if (p == 0) {
            continue;  // the root has nothing above it
        }
        sz[p] += sz[v];
        cost[p] += cost[v] + up_len[v] * sz[v];
    }

    // Parents before children: re-root. Moving the meeting node from p to
    // its child v across an edge of length w brings the sz[v] weight inside
    // v's subtree closer by w and pushes the remaining (total - sz[v]) weight
    // farther by w.
    const LL total = sz[1];
    LL best = cost[1];
    for (std::size_t k = 1; k < order.size(); ++k) {
        const int v = order[k];
        cost[v] = cost[parent[v]] + (total - sz[v] - sz[v]) * up_len[v];
        best = std::min(best, cost[v]);
    }
    return best;
}

}  // namespace

// Input:  n, then n node weights c[1..n], then n-1 edges "u v w" (1-based).
// Output: the minimum total weighted distance to one meeting node.
// Returns a non-zero exit status on malformed input.
int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1 || n < 0) {
        return 1;
    }

    std::vector<LL> c(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%lld", &c[i]) != 1) {
            return 1;
        }
    }

    std::vector<std::vector<Edge>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        LL w = 0;
        if (std::scanf("%d%d%lld", &u, &v, &w) != 3 || u < 1 || u > n ||
            v < 1 || v > n) {
            return 1;
        }
        adj[u].push_back(Edge{v, w});
        adj[v].push_back(Edge{u, w});
    }

    std::printf("%lld\n", min_total_weighted_distance(adj, c));
    return 0;
}
|