# Python code
import sys
sys.setrecursionlimit(10000000)

n = int(input())

# Adjacency list in "forward star" form, filled by add():
#   h[v]  -> slot of the most recently added edge leaving v (-1 = none)
#   e[i]  -> target vertex of the edge stored in slot i
#   ne[i] -> slot of the next edge leaving the same vertex (-1 = end)
h = [-1] * (n + 1)
e = [0] * (2 * n + 2)
ne = [0] * (2 * n + 2)
# cost[v]: price of selecting node v; filled in while reading the node lines.
cost = [0] * (n + 1)
# Next free edge slot in e / ne.
idx = 0
def add(a, b):
    """Append the directed edge a -> b to the forward-star adjacency list.

    Works on the module-level arrays ``h``, ``e``, ``ne`` and the counter
    ``idx``. The edge is stored in slot ``idx`` and linked in front of
    ``a``'s existing edge list, then ``idx`` is advanced by one.
    """
    global idx
    e[idx] = b       # target of the new edge
    ne[idx] = h[a]   # chain to a's previous first edge (-1 ends the list)
    h[a] = idx       # the new edge becomes a's first edge
    idx += 1
def dfs(cur, pre):
    """Tree DP: minimum total cost so every node is selected or has a selected neighbour.

    The tree is rooted at ``cur``; ``pre`` is the neighbour to treat as the
    parent and skip (pass -1 for the real root). Returns ``(r0, r1, r2)``
    for ``cur``:

    * r0 -- cur not selected, covered by its selected parent;
    * r1 -- cur not selected, covered by at least one selected child
      (``inf`` when cur has no children);
    * r2 -- cur selected.

    Reads the module-level adjacency arrays ``h``/``e``/``ne`` and ``cost``.
    The traversal is iterative, so tree height is not limited by the
    interpreter's recursion limit.
    """
    inf = float('inf')

    # Pass 1: collect (node, parent) pairs in pre-order with an explicit stack.
    order = []
    stack = [(cur, pre)]
    while stack:
        node, parent = stack.pop()
        order.append((node, parent))
        edge = h[node]
        while edge != -1:
            son = e[edge]
            if son != parent:
                stack.append((son, node))
            edge = ne[edge]

    # Pass 2: walk the pre-order backwards, so each child's states are known
    # before its parent combines them.
    states = {}
    for node, parent in reversed(order):
        r0, r1, r2, min_off = 0, 0, cost[node], inf
        edge = h[node]
        while edge != -1:
            son = e[edge]
            if son != parent:
                s0, s1, s2 = states[son]
                best = min(s1, s2)  # son covered without relying on node
                r0 += best
                r1 += best
                r2 += min(s0, best)  # node is selected, so son may rely on it
                # Extra cost of forcing this son to be selected so it covers node.
                min_off = min(min_off, s2 - best)
            edge = ne[edge]
        # r1 needs at least one selected child; stays inf for a leaf.
        states[node] = (r0, r1 + min_off, r2)
    return states[cur]
# Read n node lines: "node cost <ignored field> child1 child2 ...".
# The third field is skipped (presumably the child count -- confirm against
# the input spec); every remaining number is treated as a child.
for _ in range(n):
    a, c, _unused, *children = map(int, input().split())
    cost[a] = c
    # Store each tree edge in both directions; dfs() skips the parent via `pre`.
    for b in children:
        add(a, b)
        add(b, a)

# Root at node 1 (assumes node labels are 1..n). The root has no parent,
# so only the "covered by a child" (r1) and "selected" (r2) states apply.
r0, r1, r2 = dfs(1, -1)
print(min(r1, r2))