Given the root of a binary tree, you need to split the tree into two subtrees by removing exactly one edge. After the split, the product of the sums of the values of the nodes in each subtree is calculated.
Your task is to find the maximum possible product of these two subtree sums and return the answer modulo 10^9 + 7. The product must be maximized before the modulo is applied, not after.
Constraints:
- The number of nodes in the tree is in the range [2, 5 * 10^4].
- Each node's value is in the range [1, 10^4].
The challenge is to maximize the product of the sums of two resulting subtrees after splitting the tree by removing one edge. The brute-force approach would be to try every possible way to split the tree, compute the sums of the two resulting subtrees, and keep track of the maximum product. However, this would be inefficient for large trees.
Instead, we can optimize by recognizing that:
- The sum of all node values in the tree (totalSum) is fixed.
- If a subtree has sum (subSum), then removing the edge above its root would split the tree into two parts: the subtree itself and the rest of the tree, whose sum is totalSum - subSum.
- The product for that split is therefore subSum * (totalSum - subSum).
This leads us to a two-pass traversal solution: first, calculate the total sum; then, for each subtree, calculate the product as above.
To solve this problem efficiently, we use a depth-first search (DFS) traversal to compute subtree sums and maximize the product:
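1. Traverse the tree once to compute the sum of all node values, totalSum.
2. Traverse the tree again, computing the sum of every subtree (subSum). For each subtree other than the whole tree, evaluate the candidate product subSum * (totalSum - subSum) and keep the largest one seen.
3. The product can be very large, so the answer is reported modulo 10^9 + 7: first find the maximum over the exact (un-reduced) products, and only then reduce that maximum modulo 10^9 + 7.
This approach ensures that each node is visited only a constant number of times, resulting in an efficient solution.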
Let's consider a simple binary tree:
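    1
   / \
  2   3
The total sum is 1 + 2 + 3 = 6. Removing the edge to node 2 leaves parts with sums 2 and 6 - 2 = 4, giving a product of 8. Removing the edge to node 3 leaves parts with sums 3 and 6 - 3 = 3, giving a product of 9. The maximum product is therefore 9.
The key insight is to realize that for every possible split, the product can be calculated efficiently if we know all subtree sums. By performing two simple DFS traversals, we can solve the problem in linear time. This approach avoids redundant work and leverages properties of tree structure, making the solution both efficient and elegant.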
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Maximum product of the two part sums obtained by cutting one tree edge."""

    # The annotation is quoted (a forward reference) because TreeNode is only
    # present as a commented-out definition in this file.
    def maxProduct(self, root: "Optional[TreeNode]") -> int:
        """Return the best split product of the tree, modulo 10**9 + 7.

        For every non-root node, cutting the edge above it yields two parts
        whose sums are ``s`` (that node's subtree) and ``total - s``.  The
        maximum of ``s * (total - s)`` is found on exact integers and only
        then reduced modulo 10**9 + 7.

        The traversal is iterative, so the depth of the tree is not bounded
        by Python's recursion limit.

        Args:
            root: Root node exposing ``val``, ``left`` and ``right``; may be None.

        Returns:
            The maximum product modulo 10**9 + 7, or 0 when there is no edge
            to cut (empty tree or a single node).
        """
        MOD = 10**9 + 7
        if root is None:
            return 0

        # Pre-order walk with an explicit stack.  sums[i] starts as the value
        # of the i-th visited node and parent[i] is the visit index of its
        # parent (-1 for the root).
        sums = []
        parent = []
        stack = [(root, -1)]
        while stack:
            node, parent_idx = stack.pop()
            idx = len(sums)
            sums.append(node.val)
            parent.append(parent_idx)
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, idx))

        # A node is always visited after its parent, so walking the indices
        # backwards folds each completed subtree sum into its parent.
        for idx in range(len(sums) - 1, 0, -1):
            sums[parent[idx]] += sums[idx]

        total_sum = sums[0]
        # Index 0 is the whole tree, which does not correspond to a cut edge.
        best = max((s * (total_sum - s) for s in sums[1:]), default=0)
        return best % MOD
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
// Finds the maximum product of the two part sums obtained by removing one
// edge of a binary tree, reported modulo 1e9 + 7.
class Solution {
public:
    // Sum of all node values of the tree most recently given to maxProduct().
    long long totalSum = 0;
    // Subtree sums of the most recent tree in post-order; when non-empty, the
    // last entry is the sum of the whole tree.
    std::vector<long long> subtreeSums;
    int MOD = 1e9 + 7;

    // Returns the sum of all node values in the subtree rooted at `node`
    // (0 for a null node).  Not used by maxProduct() any more; kept because
    // it is a public member.
    long long getTotalSum(TreeNode* node) {
        if (!node) return 0;
        return node->val + getTotalSum(node->left) + getTotalSum(node->right);
    }

    // Returns the sum of the subtree rooted at `node` and appends the sum of
    // every subtree inside it to `subtreeSums` in post-order.
    long long getSubtreeSums(TreeNode* node) {
        if (!node) return 0;
        long long left = getSubtreeSums(node->left);
        long long right = getSubtreeSums(node->right);
        long long currSum = node->val + left + right;
        subtreeSums.push_back(currSum);
        return currSum;
    }

    // Returns max over all non-root subtrees of s * (totalSum - s), reduced
    // modulo MOD after the maximum has been taken.  Returns 0 when the tree
    // has no edge to cut (null root or a single node).
    int maxProduct(TreeNode* root) {
        // Drop sums left over from a previous call on this instance.
        subtreeSums.clear();
        // One post-order pass yields every subtree sum and the total.
        totalSum = getSubtreeSums(root);

        long long best = 0;
        // Exclude the last sum (whole tree).  `i + 1 < size()` avoids the
        // unsigned wrap-around of `size() - 1` on an empty vector.
        for (std::size_t i = 0; i + 1 < subtreeSums.size(); ++i) {
            long long s = subtreeSums[i];
            // 64-bit arithmetic: the product can reach about 6.25e16.
            best = std::max(best, s * (totalSum - s));
        }
        return static_cast<int>(best % MOD);
    }
};
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
/**
 * Finds the maximum product of the two part sums obtained by removing one
 * edge of a binary tree, reported modulo 1_000_000_007.
 */
class Solution {
    private static final int MOD = 1_000_000_007;

    /**
     * Computes the sum of the subtree rooted at {@code node} and records the
     * sum of every subtree inside it in post-order.
     *
     * @param node root of the subtree to sum; may be {@code null}
     * @param sums receives one entry per visited node; the entry for
     *             {@code node} itself is appended last
     * @return the sum of the subtree, or 0 for a {@code null} node
     */
    private long collectSubtreeSums(TreeNode node, List<Long> sums) {
        if (node == null) return 0;
        long left = collectSubtreeSums(node.left, sums);
        long right = collectSubtreeSums(node.right, sums);
        long currSum = node.val + left + right;
        sums.add(currSum);
        return currSum;
    }

    /**
     * Returns the largest value of {@code s * (total - s)} over all non-root
     * subtree sums {@code s}, reduced modulo 1_000_000_007 after the maximum
     * has been taken.
     *
     * @param root root of the tree; may be {@code null}
     * @return the maximum product modulo 1_000_000_007, or 0 when there is no
     *         edge to cut (null root or a single node)
     */
    public int maxProduct(TreeNode root) {
        // The list is local so that repeated calls on one instance cannot see
        // sums left over from an earlier tree.
        List<Long> sums = new ArrayList<>();
        long total = collectSubtreeSums(root, sums);

        long best = 0;
        // The last entry is the whole tree, which is not a valid split.
        for (int i = 0; i < sums.size() - 1; i++) {
            long s = sums.get(i);
            best = Math.max(best, s * (total - s));
        }
        return (int) (best % MOD);
    }
}
/**
* Definition for a binary tree node.
* function TreeNode(val, left, right) {
* this.val = (val===undefined ? 0 : val)
* this.left = (left===undefined ? null : left)
* this.right = (right===undefined ? null : right)
* }
*/
/**
* @param {TreeNode} root
* @return {number}
*/
var maxProduct = function(root) {
    // Returns the maximum of s * (total - s) over all non-root subtree sums s,
    // reduced modulo 1e9 + 7 after the maximum has been chosen.  Returns 0
    // when there is no edge to cut (null root or a single node).
    // Assumes node values are positive integers, as the constraints state.
    const MOD = 1e9 + 7;
    if (!root) return 0;

    // Exact (a * b) % m by double-and-add.  A direct a * b can exceed
    // Number.MAX_SAFE_INTEGER and be rounded before the modulo is taken.
    function mulMod(a, b, m) {
        let result = 0;
        a %= m;
        while (b > 0) {
            if (b % 2 === 1) result = (result + a) % m;
            a = (a * 2) % m;
            b = Math.floor(b / 2);
        }
        return result;
    }

    // Pre-order walk with an explicit stack so that deep trees cannot
    // overflow the call stack.  sums[i] starts as the value of the i-th
    // visited node; parent[i] is the visit index of its parent (-1 for root).
    const sums = [];
    const parent = [];
    const stack = [[root, -1]];
    while (stack.length > 0) {
        const [node, parentIdx] = stack.pop();
        const idx = sums.length;
        sums.push(node.val);
        parent.push(parentIdx);
        if (node.left) stack.push([node.left, idx]);
        if (node.right) stack.push([node.right, idx]);
    }

    // Every node is visited after its parent, so walking the indices
    // backwards folds each finished subtree sum into its parent.
    for (let i = sums.length - 1; i > 0; i--) {
        sums[parent[i]] += sums[i];
    }

    const totalSum = sums[0];

    // s * (total - s) = (total^2 - (total - 2s)^2) / 4, so the best split is
    // the one minimising |total - 2s|.  That comparison stays exact, unlike
    // comparing rounded products.
    let bestSum = -1;
    let bestGap = Infinity;
    for (let i = 1; i < sums.length; i++) {
        const gap = Math.abs(totalSum - 2 * sums[i]);
        if (gap < bestGap) {
            bestGap = gap;
            bestSum = sums[i];
        }
    }
    if (bestSum < 0) return 0; // single node: nothing to split

    return mulMod(bestSum, totalSum - bestSum, MOD);
};