LeetCode 834. Sum of Distances in Trees
原题链接
困难
作者:
JasonSun
,
2020-01-05 12:16:44
,
所有人可见
,
阅读 672
Algorithm (DP on trees)
- Denote the input tree by $T$. Without loss of generality, pick the root of $T$ to be $r$. We also let $T(x)$ denote the (sub)tree rooted at $x$, so the whole tree is denoted $T(r)$. Let $f(T(x),v)$ denote the sum of the distances from node $v$ to all nodes of the (sub)tree $T(x)$.
- Then $f$ has the following explicit form: $$f(T(r),v)=\begin{cases}
\sum_{c\in\texttt{children}(r)}\left[f(T(c),c)+\left\lVert T(c)\right\rVert _{0}\right] & \mathrm{if\ }v=r\\\\
f(T(r),\texttt{parent}(v))-\left[f(T(v),v)+\left\lVert T(v)\right\rVert _{0}\right]+f(T(v),v)+\left(N-\left\lVert T(v)\right\rVert _{0}\right) & \text{o.w.}
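\end{cases},$$ where $\left\lVert \cdot\right\rVert _{0}$ denotes the size (number of nodes) of a tree and $N$ is the number of nodes of $T$. The second case re-roots from $\texttt{parent}(v)$ to $v$. The subtree $T(v)$ contributes $f(T(v),v)+\left\lVert T(v)\right\rVert _{0}$ to the parent's sum but only $f(T(v),v)$ to $v$'s sum. Meanwhile, each of the remaining $N-\left\lVert T(v)\right\rVert _{0}$ nodes is one step farther from $v$ than from its parent. Since the $f(T(v),v)$ terms cancel, one could simplify the expression and write $$f(T(r),v)=\begin{cases}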
\sum_{c\in\texttt{children}(r)}\left[f(T(c),c)+\left\lVert T(c)\right\rVert _{0}\right] & \mathrm{if\ }v=r\\\\
f(T(r),\texttt{parent}(v))+N-2\left\lVert T(v)\right\rVert _{0} & \text{o.w.}
\end{cases}.$$
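- Both $f$ and $\left\lVert \cdot\right\rVert _{0}$ can be evaluated with standard dynamic programming on the tree. First compute the subtree sizes and $f(T(c),c)$ bottom-up (children before parents), then compute $f(T(r),v)$ top-down from the root.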
Time Complexity
- $O(N)$ for evaluating $f$ and $\left\lVert \cdot\right\rVert _{0}$.
Memory
- $O(N)$.
Code
/// Fixed-point adapter that lets a (generic) lambda call itself recursively.
///
/// Invoking a `recursive<F>` forwards the call to the stored callable `f`,
/// prepending `std::ref(*this)` as the first argument. Calling that first
/// argument from inside `f` re-enters this adapter, i.e. recurses.
/// NOTE: the lambdas that use this give an explicit return type
/// (`-> int`, `-> void`), so the recursive call never depends on a return
/// type that is still being deduced.
template <class F>
struct recursive {
    F f;  // the wrapped callable; public so aggregate initialization works

    /// Call through a const adapter: `f` is reached via a const path, so its
    /// call operator must be const.
    template <class... Args>
    decltype(auto) operator()(Args&&... args) const {
        return f(std::ref(*this), std::forward<Args>(args)...);
    }

    /// Call through a non-const adapter: permits `f` to be a `mutable` lambda
    /// (for example one that carries its own memo table in its captures).
    template <class... Args>
    decltype(auto) operator()(Args&&... args) {
        return f(std::ref(*this), std::forward<Args>(args)...);
    }
};

// Deduction guide: `recursive{callable}` deduces F as the callable's type.
template <class Callable>
recursive(Callable) -> recursive<Callable>;

/// Factory: `rec(callable)` moves `callable` into a `recursive` adapter.
auto const rec = [](auto callable) {
    return recursive<decltype(callable)>{std::move(callable)};
};
class Solution {
public:
    /// For an undirected tree with N nodes labelled 0..N-1 and edge list `x`
    /// (each entry {a, b}), returns ans where ans[i] is the sum of the
    /// distances from node i to every other node.
    ///
    /// Rerooting DP with the tree rooted at node 0:
    ///   sub_size[v] = number of nodes in the subtree rooted at v
    ///   sub_dist[v] = sum of distances from v to the nodes of its subtree
    ///   ans[0]      = sub_dist[0]
    ///   ans[v]      = ans[parent(v)] + N - 2 * sub_size[v]
    ///
    /// Every pass is iterative (BFS order), so path-shaped trees cannot
    /// exhaust the call stack. All tables are sized by N instead of a fixed
    /// capacity. Runs in O(N) time and memory.
    /// Returns an empty vector when N <= 0. Nodes not reachable from node 0
    /// (input that is not a connected tree) get 0.
    vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& x) {
        if (N <= 0) return {};

        // Undirected adjacency list; edges are taken by reference, not copied.
        vector<vector<int>> graph(N);
        for (const auto& e : x) {
            graph[e[0]].push_back(e[1]);
            graph[e[1]].push_back(e[0]);
        }

        // BFS from the root: in `order`, every parent precedes all its children.
        vector<int> parent(N, -1);
        vector<bool> visited(N, false);
        vector<int> order;
        order.reserve(N);
        visited[0] = true;
        order.push_back(0);
        for (size_t i = 0; i < order.size(); ++i) {  // index-based: `order` grows while scanning
            const int u = order[i];
            for (const int v : graph[u]) {
                if (!visited[v]) {
                    visited[v] = true;
                    parent[v] = u;
                    order.push_back(v);
                }
            }
        }

        // Bottom-up pass (reverse BFS order): a node is final before it is
        // folded into its parent.
        vector<int> sub_size(N, 1);
        vector<int> sub_dist(N, 0);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            const int p = parent[v];
            if (p >= 0) {
                sub_size[p] += sub_size[v];
                // Every node of v's subtree is one step farther from p than from v.
                sub_dist[p] += sub_dist[v] + sub_size[v];
            }
        }

        // Top-down pass (BFS order): re-root from parent to child.
        vector<int> ans(N, 0);
        ans[0] = sub_dist[0];
        for (size_t i = 1; i < order.size(); ++i) {
            const int v = order[i];
            // sub_size[v] nodes get one step closer, N - sub_size[v] one step farther.
            ans[v] = ans[parent[v]] + N - 2 * sub_size[v];
        }
        return ans;
    }
};