Given the root of a binary tree, the task is to find the maximum path sum. A path is a sequence of nodes from some starting node to any other node, moving along parent–child edges, in which each node appears at most once. The path must contain at least one node and does not need to pass through the root. Return an integer representing the maximum path sum among all possible paths in the tree.
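At first glance, you might consider enumerating every possible path in the tree and computing the sum of each, but this is far too slow for large trees. Instead, let's think about how paths are constructed. Every path has exactly one highest node (its "peak", the node closest to the root), and the best path peaking at a given node has sum

$$\text{left\_max} + \text{node.val} + \text{right\_max},$$

where $\text{left\_max}$ and $\text{right\_max}$ are the best downward path sums starting at the left and right child, respectively (taken as 0 when they would be negative, since we can simply not extend the path into that child). This insight leads to a recursive, post-order traversal in which each node calculates the best sum it can contribute upward to its parent, and updates a global maximum whenever a better path is found.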
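The algorithm:

1. Define a helper `max_gain(node)` that returns the best downward path sum starting at `node`.
2. If the node is `null`, return 0 (no gain).
3. Recursively compute `left_gain` and `right_gain` for the two children, replacing negative gains with 0.
4. The best path that peaks at this node has sum `node.val + left_gain + right_gain`; use it to update the global maximum.
5. Return `node.val + max(left_gain, right_gain)` to the parent, because a path that continues upward can extend into only one child.
6. Track the best sum in a global variable (`max_sum`) initialized to negative infinity. After the traversal, `max_sum` holds the answer.

We use post-order traversal because the left and right gains must be computed before the current node can be processed. The global variable ensures that the maximum is updated regardless of where in the tree the optimal path occurs.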
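Consider the following tree:

```
    -10
    /  \
   9    20
       /  \
      15   7
```

The path 15 → 20 → 7 peaks at node 20 and has sum 15 + 20 + 7 = 42. The best path through the root, 9 → −10 → 20 → 15, sums to only 34, so the answer is 42.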
The key insight is that the maximum path sum can be found by considering each node as a potential "peak" of a path, and recursively computing the best gains from its subtrees. By using post-order traversal, we efficiently propagate the best contributions upward, while updating a global maximum whenever a better path is found. This approach is elegant, efficient, and leverages the recursive structure of trees naturally.
# Definition for a binary tree node.
class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def maxPathSum(self, root: TreeNode) -> int:
        """Return the maximum sum over all non-empty paths in the tree.

        A path may start and end at any nodes and need not pass through the
        root. For every node, the best downward path sum starting at it (its
        "gain") is computed children-first; negative child gains are
        discarded. The traversal is iterative, so very deep (e.g. skewed)
        trees do not hit Python's recursion limit.

        Args:
            root: Root of the tree, or None.

        Returns:
            The maximum path sum. If ``root`` is None there is no path and
            ``float('-inf')`` is returned.

        The result is also stored in ``self.max_sum``.
        """
        self.max_sum = float('-inf')
        if root is None:
            return self.max_sum

        # Pre-order listing: every parent appears before its children, so
        # walking the list in reverse visits children before parents.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)

        # id(node) -> best downward path sum starting at node. Node ids are
        # unique here because every node stays referenced by `order`.
        gain = {}
        best = float('-inf')
        for node in reversed(order):
            # Missing children contribute 0; negative gains are discarded.
            left_gain = max(gain[id(node.left)], 0) if node.left is not None else 0
            right_gain = max(gain[id(node.right)], 0) if node.right is not None else 0
            # Best path that peaks at this node.
            best = max(best, node.val + left_gain + right_gain)
            # A path continuing to the parent can extend into only one child.
            gain[id(node)] = node.val + max(left_gain, right_gain)

        self.max_sum = best
        return best
#include <algorithm>
#include <climits>

// Definition for a binary tree node.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *l, TreeNode *r) : val(x), left(l), right(r) {}
};

// Computes the maximum sum over all non-empty paths in a binary tree.
class Solution {
public:
    // Best path sum seen so far during the current maxPathSum call.
    int max_sum = INT_MIN;

    // Returns the best downward path sum starting at `node` (0 for nullptr)
    // and updates max_sum with the best path that peaks at `node`.
    int maxGain(TreeNode* node) {
        if (!node) return 0;
        // Negative subtree gains are discarded.
        int left_gain = std::max(maxGain(node->left), 0);
        int right_gain = std::max(maxGain(node->right), 0);
        int current_sum = node->val + left_gain + right_gain;
        max_sum = std::max(max_sum, current_sum);
        // A path continuing to the parent can extend into only one child.
        return node->val + std::max(left_gain, right_gain);
    }

    // Returns the maximum path sum of the tree rooted at `root`
    // (INT_MIN if root is nullptr).
    int maxPathSum(TreeNode* root) {
        // Reset so the same Solution object gives correct results when reused
        // on another tree.
        max_sum = INT_MIN;
        maxGain(root);
        return max_sum;
    }
};
// Definition for a binary tree node.
public class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode() {}

    TreeNode(int x) { val = x; }

    TreeNode(int x, TreeNode left, TreeNode right) {
        this.val = x;
        this.left = left;
        this.right = right;
    }
}

/** Computes the maximum sum over all non-empty paths in a binary tree. */
class Solution {
    /** Best path sum seen so far during the current maxPathSum call. */
    private int maxSum = Integer.MIN_VALUE;

    /**
     * Returns the maximum path sum of the tree rooted at {@code root}.
     *
     * @param root root of the tree, may be null
     * @return the maximum path sum, or Integer.MIN_VALUE if root is null
     */
    public int maxPathSum(TreeNode root) {
        // Reset so the same Solution instance gives correct results when
        // reused on another tree.
        maxSum = Integer.MIN_VALUE;
        maxGain(root);
        return maxSum;
    }

    /**
     * Returns the best downward path sum starting at {@code node} (0 for null)
     * and updates maxSum with the best path that peaks at {@code node}.
     */
    private int maxGain(TreeNode node) {
        if (node == null) return 0;
        // Negative subtree gains are discarded.
        int leftGain = Math.max(maxGain(node.left), 0);
        int rightGain = Math.max(maxGain(node.right), 0);
        int currentSum = node.val + leftGain + rightGain;
        maxSum = Math.max(maxSum, currentSum);
        // A path continuing to the parent can extend into only one child.
        return node.val + Math.max(leftGain, rightGain);
    }
}
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */
/**
 * Returns the maximum sum over all non-empty paths in a binary tree.
 *
 * A missing child may be either null or undefined. The traversal is
 * iterative (children processed before parents via a reversed pre-order
 * list), so very deep, skewed trees do not overflow the call stack.
 *
 * @param {TreeNode|null} root - Root of the tree.
 * @return {number} The maximum path sum, or -Infinity if root is null/undefined.
 */
var maxPathSum = function(root) {
    let maxSum = -Infinity;
    if (root == null) return maxSum;

    // Pre-order listing: every parent appears before its children.
    const order = [];
    const stack = [root];
    while (stack.length > 0) {
        const node = stack.pop();
        order.push(node);
        if (node.left != null) stack.push(node.left);
        if (node.right != null) stack.push(node.right);
    }

    // Best downward path sum starting at each node.
    const gain = new Map();
    for (let i = order.length - 1; i >= 0; i--) {
        const node = order[i];
        // Missing children contribute 0; negative gains are discarded.
        const leftGain = node.left != null ? Math.max(gain.get(node.left), 0) : 0;
        const rightGain = node.right != null ? Math.max(gain.get(node.right), 0) : 0;
        // Best path that peaks at this node.
        maxSum = Math.max(maxSum, node.val + leftGain + rightGain);
        // A path continuing to the parent can extend into only one child.
        gain.set(node, node.val + Math.max(leftGain, rightGain));
    }
    return maxSum;
};