1339. Maximum Product of Splitted Binary Tree - Leetcode Solution

Problem Description

Given the root of a binary tree, you need to split the tree into two subtrees by removing exactly one edge. After the split, the product of the sums of the values of the nodes in each subtree is calculated. Your task is to find the maximum possible product of these two subtree sums, and return the answer modulo 10^9 + 7.

  • Each node in the tree has an integer value.
  • You must remove exactly one edge, resulting in two non-empty subtrees.
  • Return the maximum product you can get from any such split.
  • Constraints:
    • The number of nodes in the tree is in the range [2, 5 * 10^4].
    • Each node's value is in the range [1, 10^4].

Thought Process

The challenge is to maximize the product of the sums of the two subtrees that result from removing one edge. The brute-force approach would be to try every edge, compute the sums of the two resulting subtrees from scratch, and keep track of the maximum product. However, recomputing both sums for each of the N - 1 edges takes O(N^2) time overall, which is too slow for trees with up to 5 * 10^4 nodes.

Instead, we can optimize by recognizing that:

  • The sum of all nodes in the tree (let's call it totalSum) is fixed.
  • For every subtree, if we know its sum (subSum), then removing the edge above its root would split the tree into two parts: the subtree itself and the rest of the tree, whose sum is totalSum - subSum.
  • The product for this split is subSum * (totalSum - subSum).
  • Therefore, if we can compute the sum of every possible subtree, we can easily compute the product for each possible split, and take the maximum.

This leads us to a two-pass traversal solution: first, calculate the total sum; then, for each subtree, calculate the product as above.
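As a sanity check on the observation above, here is a minimal sketch that physically cuts each edge, sums both pieces from scratch, and compares the result with subSum * (totalSum - subSum). The helper names (tree_sum, brute_force_products, formula_products) and the sample tree are illustrative assumptions, not part of the original solution, and both helpers are deliberately naive (quadratic) because they exist only for verification.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node):
    """Sum of all values in the tree rooted at node."""
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def brute_force_products(root):
    """Cut each parent-child edge in turn and sum both pieces from scratch."""
    products = []

    def visit(parent):
        if parent is None:
            return
        for side in ("left", "right"):
            child = getattr(parent, side)
            if child is None:
                continue
            setattr(parent, side, None)   # cut the edge
            products.append(tree_sum(child) * tree_sum(root))
            setattr(parent, side, child)  # restore the edge
            visit(child)

    visit(root)
    return products


def formula_products(root):
    """Same edges, same order, but using subSum * (totalSum - subSum)."""
    total = tree_sum(root)
    products = []

    def visit(parent):
        if parent is None:
            return
        for child in (parent.left, parent.right):
            if child is not None:
                sub = tree_sum(child)
                products.append(sub * (total - sub))
                visit(child)

    visit(root)
    return products


# Hypothetical sample tree: 1 -> (2 -> (4, 5), 3 -> (6))
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
print(brute_force_products(root))  # [110, 68, 80, 108, 90]
print(formula_products(root))      # [110, 68, 80, 108, 90]
```

Both lists agree edge by edge, which is why a single pass that collects subtree sums is enough to evaluate every split.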

Solution Approach

To solve this problem efficiently, we use a depth-first search (DFS) traversal to compute subtree sums and maximize the product:

  1. Compute the total sum:
    • Perform a DFS to add up all node values in the tree. This gives us totalSum.
  2. Compute all subtree sums and maximize the product:
    • Do another DFS traversal. At each node, calculate the sum of the subtree rooted at that node.
    • For each subtree (except the entire tree), compute the product: subSum * (totalSum - subSum).
    • Keep track of the maximum product found.
  3. Return the answer modulo 10^9 + 7:
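    • The exact product can reach roughly 6.25 * 10^16, so compute it in a 64-bit (or arbitrary-precision) integer, take the maximum over the unreduced products, and apply the modulo only to the final result.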

This approach visits each node exactly twice, once per traversal, so the total running time is O(N).
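One subtlety in step 3 is the order of operations: the maximum has to be chosen among the exact products, and only then reduced. The short sketch below uses two made-up product values (assumptions chosen purely to illustrate the effect, not taken from a real tree) to show how reducing first can pick the wrong candidate.

```python
MOD = 10**9 + 7

# Hypothetical exact products for two different splits.
products = [1_000_000_008, 999_999_999]

# Correct order: pick the largest exact product, then reduce it.
max_then_mod = max(products) % MOD

# Incorrect order: reducing first makes the larger product look tiny.
mod_then_max = max(p % MOD for p in products)

print(max_then_mod)  # 1
print(mod_then_max)  # 999999999
```

The two results differ, so comparisons should always be made on unreduced values.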

Example Walkthrough

Let's consider a simple binary tree:

        1
       / \
      2   3
  
  • Step 1: Compute the total sum: 1 + 2 + 3 = 6.
  • Step 2: Compute subtree sums:
    • Subtree rooted at 2: sum is 2.
    • Subtree rooted at 3: sum is 3.
    • Subtree rooted at 1 (whole tree): sum is 6.
  • Step 3: Consider removing edges:
    • Remove edge between 1 and 2: subtrees are [2] and [1,3]. Product = 2 * (6-2) = 2 * 4 = 8.
    • Remove edge between 1 and 3: subtrees are [3] and [1,2]. Product = 3 * (6-3) = 3 * 3 = 9.
  • Step 4: The maximum product is 9.
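The walkthrough above can be reproduced with a few lines of Python. This is a minimal sketch: the TreeNode class mirrors the usual LeetCode definition, and the helper name subtree_sums is an assumption introduced here for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sums(node, out):
    """Post-order traversal: append every subtree sum to out, root's sum last."""
    if node is None:
        return 0
    s = node.val + subtree_sums(node.left, out) + subtree_sums(node.right, out)
    out.append(s)
    return s


root = TreeNode(1, TreeNode(2), TreeNode(3))
sums = []
total = subtree_sums(root, sums)

# Skip the last entry (the whole tree): there is no edge above the root to cut.
products = [s * (total - s) for s in sums[:-1]]

print(sums)           # [2, 3, 6]
print(products)       # [8, 9]
print(max(products))  # 9
```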

Time and Space Complexity

  • Brute-force approach:
    • Would involve recalculating subtree sums for every possible split: O(N^2) time.
  • Optimized approach (DFS):
    • Each node is visited twice: once for total sum, once for subtree sums. So, O(N) time.
    • Space is O(H) for the recursion stack (H is the height of the tree). Storing every subtree sum in a list, as the implementations below do, adds O(N); this can be avoided by updating a running maximum inside the second DFS.
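To illustrate the O(H) extra-space variant, here is a minimal sketch that keeps only a running maximum instead of a list of sums. The function name max_product_running is an assumption made for this example. It relies on plain recursion, so a very deep (skewed) tree could exceed Python's default recursion limit unless that limit is raised.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_running(root: Optional[TreeNode]) -> int:
    MOD = 10**9 + 7

    def tree_sum(node):
        if node is None:
            return 0
        return node.val + tree_sum(node.left) + tree_sum(node.right)

    total = tree_sum(root)  # first pass: total sum
    best = 0

    def dfs(node):
        nonlocal best
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        # At the root s == total, so the product is 0 and never wins;
        # no special case is needed to exclude the whole tree.
        best = max(best, s * (total - s))
        return s

    dfs(root)  # second pass: update the running maximum
    return best % MOD


small = TreeNode(1, TreeNode(2), TreeNode(3))
print(max_product_running(small))  # 9
```

The only state beyond the recursion stack is the pair (total, best), which is what brings the extra space down from O(N) to O(H).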

Summary

The key insight is to realize that for every possible split, the product can be calculated efficiently if we know all subtree sums. By performing two simple DFS traversals, we can solve the problem in linear time. This approach avoids redundant work and leverages properties of tree structure, making the solution both efficient and elegant.

Code Implementation

Python

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional


class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        MOD = 10**9 + 7
        total_sum = 0
        subtree_sums = []

        def get_total_sum(node):
            if not node:
                return 0
            return node.val + get_total_sum(node.left) + get_total_sum(node.right)

        def get_subtree_sums(node):
            if not node:
                return 0
            left = get_subtree_sums(node.left)
            right = get_subtree_sums(node.right)
            curr_sum = node.val + left + right
            subtree_sums.append(curr_sum)
            return curr_sum

        total_sum = get_total_sum(root)
        get_subtree_sums(root)
        max_product = 0
        for s in subtree_sums[:-1]:  # exclude the sum of the whole tree
            max_product = max(max_product, s * (total_sum - s))
        return max_product % MOD

C++

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    // long long is at least 64 bits on every platform; products can reach about 6.25e16.
    long long totalSum = 0;
    vector<long long> subtreeSums;
    int MOD = 1e9 + 7;

    // First pass: sum of every node value in the tree.
    long long getTotalSum(TreeNode* node) {
        if (!node) return 0;
        return node->val + getTotalSum(node->left) + getTotalSum(node->right);
    }

    // Second pass (post-order): record each subtree sum; the root's sum is recorded last.
    long long getSubtreeSums(TreeNode* node) {
        if (!node) return 0;
        long long left = getSubtreeSums(node->left);
        long long right = getSubtreeSums(node->right);
        long long currSum = node->val + left + right;
        subtreeSums.push_back(currSum);
        return currSum;
    }

    int maxProduct(TreeNode* root) {
        totalSum = getTotalSum(root);
        getSubtreeSums(root);
        long long maxProduct = 0;
        // Exclude the last sum (whole tree)
        for (size_t i = 0; i + 1 < subtreeSums.size(); ++i) {
            long long s = subtreeSums[i];
            maxProduct = max(maxProduct, s * (totalSum - s));
        }
        return static_cast<int>(maxProduct % MOD);
    }
};

Java

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
import java.util.ArrayList;
import java.util.List;

class Solution {
    private long totalSum = 0;
    private List<Long> subtreeSums = new ArrayList<>();
    private final int MOD = 1_000_000_007;

    private long getTotalSum(TreeNode node) {
        if (node == null) return 0;
        return node.val + getTotalSum(node.left) + getTotalSum(node.right);
    }

    private long getSubtreeSums(TreeNode node) {
        if (node == null) return 0;
        long left = getSubtreeSums(node.left);
        long right = getSubtreeSums(node.right);
        long currSum = node.val + left + right;
        subtreeSums.add(currSum);
        return currSum;
    }

    public int maxProduct(TreeNode root) {
        totalSum = getTotalSum(root);
        getSubtreeSums(root);
        long maxProduct = 0;
        for (int i = 0; i < subtreeSums.size() - 1; i++) {
            long s = subtreeSums.get(i);
            maxProduct = Math.max(maxProduct, s * (totalSum - s));
        }
        return (int)(maxProduct % MOD);
    }
}

JavaScript

/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */

/**
 * @param {TreeNode} root
 * @return {number}
 */
var maxProduct = function(root) {
    const MOD = 1e9 + 7;
    let totalSum = 0;
    let subtreeSums = [];

    function getTotalSum(node) {
        if (!node) return 0;
        return node.val + getTotalSum(node.left) + getTotalSum(node.right);
    }

    function getSubtreeSums(node) {
        if (!node) return 0;
        let left = getSubtreeSums(node.left);
        let right = getSubtreeSums(node.right);
        let currSum = node.val + left + right;
        subtreeSums.push(currSum);
        return currSum;
    }

    totalSum = getTotalSum(root);
    getSubtreeSums(root);
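    // Products can reach about 6.25e16, beyond Number.MAX_SAFE_INTEGER,
    // so multiply and compare in BigInt to avoid losing precision.
    const total = BigInt(totalSum);
    let maxProduct = 0n;
    for (let i = 0; i < subtreeSums.length - 1; i++) {
        const s = BigInt(subtreeSums[i]);
        const product = s * (total - s);
        if (product > maxProduct) maxProduct = product;
    }
    return Number(maxProduct % BigInt(MOD));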
};