
1660. Correct a Binary Tree - Leetcode Solution

Code Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

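from collections import deque


class Solution:
    def correctBinaryTree(self, root: 'TreeNode') -> 'TreeNode':
        # BFS over (node, parent) pairs; enqueuing the right child before the
        # left child makes every level come out of the queue right to left.
        queue = deque([(root, None)])
        seen = set()  # every node dequeued so far
        while queue:
            node, parent = queue.popleft()
            # A genuine right child is one level deeper and cannot have been
            # dequeued yet, so a right pointer into `seen` is the defect.
            if node.right in seen:
                # Detach the invalid node, and with it everything beneath it.
                if parent.left is node:
                    parent.left = None
                else:
                    parent.right = None
                return root
            seen.add(node)
            if node.right:
                queue.append((node.right, node))
            if node.left:
                queue.append((node.left, node))
        return root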
      
// Definition for a binary tree node.
// struct TreeNode {
//     int val;
//     TreeNode *left;
//     TreeNode *right;
//     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
// };

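#include <queue>
#include <unordered_set>
#include <utility>

class Solution {
public:
    TreeNode* correctBinaryTree(TreeNode* root) {
        // BFS over (node, parent) pairs; pushing the right child before the
        // left child makes every level come out of the queue right to left.
        std::queue<std::pair<TreeNode*, TreeNode*>> q;
        std::unordered_set<TreeNode*> seen;  // every node dequeued so far
        q.push({root, nullptr});
        while (!q.empty()) {
            auto [node, parent] = q.front();
            q.pop();
            // A genuine right child is one level deeper and cannot have been
            // dequeued yet, so a right pointer into `seen` is the defect.
            if (node->right && seen.count(node->right)) {
                // Detach the invalid node, and with it everything beneath it.
                if (parent->left == node) parent->left = nullptr;
                else parent->right = nullptr;
                return root;
            }
            seen.insert(node);
            if (node->right) q.push({node->right, node});
            if (node->left) q.push({node->left, node});
        }
        return root;
    }
};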
      
// Definition for a binary tree node.
// public class TreeNode {
//     int val;
//     TreeNode left;
//     TreeNode right;
//     TreeNode(int x) { val = x; }
// }

import java.util.*;

class Solution {
    public TreeNode correctBinaryTree(TreeNode root) {
        // Each entry is {node, parent}; offering the right child before the
        // left child makes every level come out of the queue right to left.
        Deque<TreeNode[]> queue = new ArrayDeque<>();
        Set<TreeNode> seen = new HashSet<>();  // every node dequeued so far
        queue.offer(new TreeNode[] {root, null});
        while (!queue.isEmpty()) {
            TreeNode[] entry = queue.poll();
            TreeNode node = entry[0], parent = entry[1];
            // A genuine right child is one level deeper and cannot have been
            // dequeued yet, so a right pointer into `seen` is the defect.
            if (node.right != null && seen.contains(node.right)) {
                // Detach the invalid node, and with it everything beneath it.
                if (parent.left == node) parent.left = null;
                else parent.right = null;
                return root;
            }
            seen.add(node);
            if (node.right != null) queue.offer(new TreeNode[] {node.right, node});
            if (node.left != null) queue.offer(new TreeNode[] {node.left, node});
        }
        return root;
    }
}
      
/**
 * // Definition for a binary tree node.
 * function TreeNode(val) {
 *     this.val = val;
 *     this.left = this.right = null;
 * }
 */

/**
 * @param {TreeNode} root
 * @return {TreeNode}
 */
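var correctBinaryTree = function(root) {
    // Each entry is [node, parent]; pushing the right child before the left
    // child makes every level come out of the queue right to left.
    const queue = [[root, null]];
    const seen = new Set();  // every node dequeued so far
    // Walk the array with a head index instead of shift() to keep dequeues O(1).
    for (let head = 0; head < queue.length; head++) {
        const [node, parent] = queue[head];
        // A genuine right child is one level deeper and cannot have been
        // dequeued yet, so a right pointer into `seen` is the defect.
        if (node.right && seen.has(node.right)) {
            // Detach the invalid node, and with it everything beneath it.
            if (parent.left === node) parent.left = null;
            else parent.right = null;
            return root;
        }
        seen.add(node);
        if (node.right) queue.push([node.right, node]);
        if (node.left) queue.push([node.left, node]);
    }
    return root;
};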
      

Problem Description

You are given the root of a binary tree with a single defect. Exactly one node, the invalid node, has a right child pointer that incorrectly points to another node at the same depth, somewhere to its right. Return the root of the tree after removing the invalid node and every node underneath it. The node it incorrectly points to is not removed, because it still belongs to its real parent.

  • There is exactly one such invalid node in the tree.
  • All node values are unique.
  • You must remove the invalid node together with its genuine descendants, not just the bad pointer.
  • The problem guarantees a unique solution.
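To make the defect concrete, here is a minimal Python sketch. It assumes a plain `TreeNode` class with `val`, `left`, and `right` attributes. The helper `node_depths` is hypothetical and not part of the problem's interface. The sketch builds the example tree used in the walkthrough, adds the sideways pointer from node 4 to node 5, and confirms that both ends of that pointer sit at the same depth.

```python
from collections import deque
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def node_depths(root: TreeNode) -> Dict[TreeNode, int]:
    """Record each node's depth the first time BFS reaches it."""
    depths = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            # BFS reaches every node first along its shortest path, which is its
            # true depth; the sideways pointer only adds a longer path.
            if child is not None and child not in depths:
                depths[child] = depths[node] + 1
                queue.append(child)
    return depths


# Example tree: 1 -> (2, 3), 2.right = 4, 3.right = 5, 5.left = 6.
nodes = {v: TreeNode(v) for v in range(1, 7)}
nodes[1].left, nodes[1].right = nodes[2], nodes[3]
nodes[2].right = nodes[4]
nodes[3].right = nodes[5]
nodes[5].left = nodes[6]

# Introduce the defect: node 4's right pointer jumps sideways to node 5.
nodes[4].right = nodes[5]

depths = node_depths(nodes[1])
print(depths[nodes[4]], depths[nodes[5]])  # 2 2
```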

Thought Process

To solve this problem, we first need to recognize the key property: only one node's right pointer is invalid, and it points to a node on the same level but to its right. A brute-force approach could compare every node's right pointer against every other node on the same level. That works, but it can cost O(N²) comparisons when a single level holds many nodes.

Instead, consider a level-order traversal (BFS) that scans each level from right to left. By the time it reaches the invalid node, the node its right pointer targets (same level, further right) has already been dequeued. A genuine right child, by contrast, sits one level deeper and has not been dequeued yet. So if we record every dequeued node in a set, the first right pointer that points into that set identifies the invalid node.

This approach is more efficient and leverages the specific structure of the problem, turning a potentially complex check into a straightforward traversal with a set to record visited nodes.
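The ordering argument can be checked directly. The following Python sketch assumes a minimal `TreeNode` class; `right_to_left_order` is a hypothetical helper. It enqueues each node's right child before its left child, so every level comes out right to left. On the walkthrough tree (before the defect is added), node 5 is dequeued before node 4. That is why node 4's sideways pointer would land on an already-visited node.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def right_to_left_order(root: TreeNode) -> List[int]:
    """Return node values in BFS order, scanning each level right to left."""
    order = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node.val)
        # Right child first, so the next level is also scanned right to left.
        if node.right:
            queue.append(node.right)
        if node.left:
            queue.append(node.left)
    return order


# The valid walkthrough tree, before node 4's pointer is corrupted.
root = TreeNode(1, TreeNode(2, right=TreeNode(4)),
                TreeNode(3, right=TreeNode(5, left=TreeNode(6))))
print(right_to_left_order(root))  # [1, 3, 2, 5, 4, 6]
```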

Solution Approach

We use a Breadth-First Search (BFS) traversal that visits each level from right to left. Here are the steps:

  1. Initialize a queue with the pair (root, no parent). Storing each node's parent lets us detach the invalid node later.
  2. Maintain a set of seen nodes: every node dequeued so far.
  3. Repeatedly dequeue a (node, parent) pair:
    • If the node's right pointer targets a node already in the seen set, this node is the invalid one.
    • Remove it by setting whichever of the parent's left or right pointers refers to it to null (None in Python). This drops the invalid node and its whole subtree.
    • Return the root, as the tree is now corrected.
  4. Otherwise, add the node to the seen set.
  5. Enqueue the node's right child before its left child, each paired with the node as its parent, so the next level is also visited right to left.

This approach is justified by the right-to-left order. The invalid right pointer always targets a node already dequeued on the same level, while every genuine right child sits one level deeper and has not been dequeued yet. The first right pointer that lands in the seen set is therefore the one to remove.
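An equivalent formulation, shown here as a hedged Python sketch rather than the reference solution, checks each node's right pointer against the set of all nodes on its own level. Because the whole level is collected before any check, the scan order stops mattering. The names `correct_binary_tree` and `TreeNode` are assumptions for illustration. Both versions run in O(N) time; this one trades the ordering trick for building one set per level.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def correct_binary_tree(root: TreeNode) -> TreeNode:
    """Detach the invalid node by checking right pointers against the current level."""
    # Each entry pairs a node with its parent (None for the root).
    level: List[Tuple[TreeNode, Optional[TreeNode]]] = [(root, None)]
    while level:
        level_nodes = {node for node, _ in level}
        for node, parent in level:
            # A genuine right child is one level deeper, so a right pointer
            # into this level's own set can only be the defect.
            if node.right in level_nodes:
                # The root is alone on its level, so parent is never None here.
                if parent.left is node:
                    parent.left = None
                else:
                    parent.right = None
                return root
        # The defect is not on this level, so every child pointer is genuine.
        level = [(child, node) for node, _ in level
                 for child in (node.left, node.right) if child is not None]
    return root
```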

Example Walkthrough

Consider the following tree:

        1
      /   \
     2     3
      \     \
       4     5
            /
           6
  

Suppose node 4's right pointer incorrectly points to node 5 (which is at the same level and to its right). The right-to-left BFS proceeds as follows:

  1. Level 1: Dequeue node 1; its right child 3 is not in seen. Seen set = {1}. Enqueue 3, then 2.
  2. Level 2: Dequeue node 3 (right child 5, not in seen), then node 2 (right child 4, not in seen). Seen set = {1,2,3}. The queue now holds 5 (parent 3), then 4 (parent 2).
  3. Level 3: Dequeue node 5. It has no right child, so it joins the seen set = {1,2,3,5}, and its left child 6 is enqueued.
  4. Still on level 3, dequeue node 4. Its right pointer targets node 5, which is already in seen, so node 4 is the invalid node:
    • Node 4 is the right child of its parent, node 2, so we set node 2's right pointer to null and return the root.

Node 5 is not removed: it still hangs off its real parent, node 3. In the corrected tree, node 1 has children 2 and 3, node 2 has no children, and node 3 keeps the chain 3 → 5 → 6.

The algorithm stops as soon as it finds the invalid node, so it makes at most one pass over the tree.
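To replay the walkthrough mechanically, here is a small Python sketch. It assumes a minimal `TreeNode` class; `correct_with_trace` and its `log` argument are hypothetical additions for illustration. It performs the same right-to-left BFS with parent tracking and records one line per dequeued node.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def correct_with_trace(root: TreeNode, log: List[str]) -> TreeNode:
    """Right-to-left BFS fix that appends one line per dequeued node to `log`."""
    queue = deque([(root, None)])
    seen = set()  # every node dequeued so far
    while queue:
        node, parent = queue.popleft()
        if node.right in seen:
            log.append(f"dequeue {node.val}: right -> {node.right.val} "
                       f"already seen, cut from {parent.val}")
            # Detach the invalid node (and its subtree) from its parent.
            if parent.left is node:
                parent.left = None
            else:
                parent.right = None
            return root
        log.append(f"dequeue {node.val}")
        seen.add(node)
        # Right child first keeps each level in right-to-left order.
        if node.right:
            queue.append((node.right, node))
        if node.left:
            queue.append((node.left, node))
    return root


n = {v: TreeNode(v) for v in range(1, 7)}
n[1].left, n[1].right = n[2], n[3]
n[2].right, n[3].right, n[5].left = n[4], n[5], n[6]
n[4].right = n[5]  # the defect: node 4 points sideways to node 5

log: List[str] = []
correct_with_trace(n[1], log)
print("\n".join(log))
# dequeue 1
# dequeue 3
# dequeue 2
# dequeue 5
# dequeue 4: right -> 5 already seen, cut from 2
```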

Time and Space Complexity

Brute-force approach: If we tried to check every node's right pointer against every other node at the same level, the time complexity could be O(N²), where N is the number of nodes.
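For comparison, the brute-force check described above can be sketched in Python as follows. It assumes a minimal `TreeNode` class; `find_invalid_brute_force` is a hypothetical name. It only locates the invalid node. Detaching that node would still require knowing its parent.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def find_invalid_brute_force(root: TreeNode) -> Optional[TreeNode]:
    """Return the node whose right pointer targets a node on its own level, or None."""
    level: List[TreeNode] = [root]
    while level:
        for node in level:
            # Compare against every node on the level: O(w) work per node and
            # O(w^2) per level of width w, which is O(N^2) when a single level
            # holds a constant fraction of the N nodes.
            if any(node.right is other for other in level):
                return node
        # No defect on this level, so every child pointer is genuine.
        level = [child for node in level
                 for child in (node.left, node.right) if child is not None]
    return None
```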

Optimized BFS approach: Each node is dequeued at most once, and lookups in the seen hash set take O(1) time on average. Thus, the time complexity is O(N).

The space complexity is also O(N). The seen set can grow to hold nearly every node, and the queue can hold an entire level, which is about N/2 nodes in a complete binary tree.

Summary

The key insight is that, when each level is scanned from right to left, the invalid right pointer always targets a node already dequeued on the same level, whereas a genuine right child has not been dequeued yet. By combining BFS with a set of seen nodes and each node's parent, we can detect and detach the invalid subtree in O(N) time. The approach relies on the problem's guarantee that exactly one such node exists.