题目描述
给定一个二叉树,找到最长的路径,这个路径中的每个节点具有相同值。这条路径可以经过也可以不经过根节点。
注意:两个节点之间的路径长度由它们之间的边数表示。
样例
输入:
5
/ \
4 5
/ \ \
1 1 5
输出:
2
输入:
1
/ \
4 5
/ \ \
4 4 5
输出:
2
限制
- 给定的二叉树不超过
10000
个结点。树的高度不超过1000
。
算法
(深度优先遍历) $O(n)$
- 递归的返回值:当前节点权值的最长路径长度。
- 每次递归左右儿子结点。如果当前节点的值等于左儿子的权值,也等于右儿子的权值,则可以用左右儿子的返回值加 2 更新答案,并返回两个返回值中的最大值加 1。
- 其他情况可以同理分析。
时间复杂度
- 每个节点访问一次,故总时间复杂度为 $O(n)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储递归的系统栈
C++ 代码
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) :
* val(x), left(left), right(right) {}
* };
*/
class Solution {
private:
int ans;
int solve(TreeNode *rt) {
int v = rt->val;
if (rt->left && rt->right) {
int l = solve(rt->left);
int r = solve(rt->right);
if (v == rt->left->val && v == rt->right->val) {
ans = max(ans, l + r + 2);
return max(l, r) + 1;
}
if (v == rt->left->val) {
ans = max(ans, r);
ans = max(ans, l + 1);
return l + 1;
}
if (v == rt->right->val) {
ans = max(ans, l);
ans = max(ans, r + 1);
return r + 1;
}
ans = max(ans, l);
ans = max(ans, r);
return 0;
}
if (rt->left) {
int l = solve(rt->left);
if (v == rt->left->val) {
ans = max(ans, l + 1);
return l + 1;
}
ans = max(ans, l);
return 0;
}
if (rt->right) {
int r = solve(rt->right);
if (v == rt->right->val) {
ans = max(ans, r + 1);
return r + 1;
}
ans = max(ans, r);
return 0;
}
return 0;
}
public:
int longestUnivaluePath(TreeNode* root) {
if (!root) return 0;
ans = 0;
solve(root);
return ans;
}
};
C++ 代码(另一种简单实现)
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) :
* val(x), left(left), right(right) {}
* };
*/
class Solution {
private:
int ans;
int solve(TreeNode *rt) {
if (!rt) return 0;
int l = solve(rt->left);
int r = solve(rt->right);
int lv = 0, rv = 0;
if (rt->left && rt->left->val == rt->val)
lv = l + 1;
if (rt->right && rt->right->val == rt->val)
rv = r + 1;
ans = max(ans, lv + rv);
return max(lv, rv);
}
public:
int longestUnivaluePath(TreeNode* root) {
ans = 0;
solve(root);
return ans;
}
};