The "Merge BSTs to Create Single BST" problem asks you to merge multiple binary search trees (BSTs) into a single, valid BST. You are given a list of BSTs, each with unique node values, and you must combine them so that the resulting tree is also a valid BST and contains all the nodes from the input trees, without duplicating or omitting any node.
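The input is a list named `trees`, where each entry is the root of one BST. If the trees cannot be combined into a single valid BST, the answer is `null`.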
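At first, you might think about brute-forcing all possible ways to attach the trees together, but this quickly becomes infeasible as the number of trees grows. Instead, let's look for a more systematic approach:
- A tree can only be attached at a leaf whose value equals that tree's root value.
- The root of the final tree is the one input root whose value never appears as a leaf in any other tree; if there is not exactly one such root, no merge is possible.
The challenge is to efficiently identify where to "plug in" each tree and to verify that the merged result is still a BST.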
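To solve the problem efficiently, we can use the following step-by-step strategy:
1. Build a hash map from each root value to its tree.
2. Collect the values of every child (leaf) node across all trees.
3. Pick the single tree whose root value is not among those child values; if zero or several trees qualify, return `null`.
4. Traverse from that root while tracking the allowed (min, max) range of each node; whenever a leaf's value matches the root of another tree, attach that tree's children to the leaf.
5. Return `null` if any node falls outside its range or a value repeats; afterwards, confirm that every input tree was absorbed into the result.
This approach leverages hash maps for fast lookup and ensures that the merge process is unambiguous and efficient.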
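Suppose you are given these three BSTs:
- `2` (left: `1`, right: `3`)
- `5` (left: `4`)
- `3` (right: `6`)
The tree rooted at `3` can be attached at the leaf `3` of the first tree, but `2` and `5` are both roots that never appear as a leaf of another tree, so there are two candidate roots and the answer is `null`.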
This problem is elegantly solved by recognizing that the unique root can be found by exclusion, and that each tree can only be attached at a leaf with a matching value. By using hash maps for fast lookup and careful traversal, we ensure all nodes are used once and the BST property is preserved. The process is efficient and scales well, as each node is processed only once.
# Definition for a binary tree node.
class TreeNode:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, val=0, left=None, right=None):
        """Create a node.

        Args:
            val: Value stored in the node (defaults to 0).
            left: Left child node, or None.
            right: Right child node, or None.
        """
        self.val = val
        self.left = left
        self.right = right
class Solution:
    """Merges a list of BSTs into a single BST by grafting trees onto matching leaves."""

    def canMerge(self, trees):
        """Merge all trees into one valid BST, or report that it is impossible.

        A tree is attached by copying its children onto a leaf (of an already
        reachable tree) that carries the same value as the tree's root. The
        input nodes are mutated in place while merging.

        Args:
            trees: List of tree roots; each node exposes ``val``, ``left`` and
                ``right``.

        Returns:
            The root of the merged BST, or None when there is not exactly one
            candidate root, when the merged tree violates the BST ordering or
            repeats a value, or when some input tree could not be attached.
        """
        root_map = {tree.val: tree for tree in trees}

        child_vals = set()
        total_nodes = 0
        for tree in trees:
            if tree.left:
                child_vals.add(tree.left.val)
            if tree.right:
                child_vals.add(tree.right.val)
            # Count before merging: grafting rewires child pointers, so a
            # count taken afterwards would include grafted subtrees twice.
            total_nodes += self.countNodes(tree)

        # The final root is the only tree whose root value is nobody's child.
        roots = [tree for tree in trees if tree.val not in child_vals]
        if len(roots) != 1:
            return None
        root = roots[0]

        used = set()
        # Explicit stack instead of recursion: the merged tree can be a long
        # chain, which would exceed Python's default recursion limit.
        stack = [(root, float('-inf'), float('inf'))]
        while stack:
            node, low, high = stack.pop()
            if node is None:
                continue
            # Strict bounds enforce the BST ordering; a repeated value can
            # never belong to a valid BST.
            if not (low < node.val < high) or node.val in used:
                return None
            used.add(node.val)
            if node.left is None and node.right is None:
                subtree = root_map.get(node.val)
                if subtree is not None and subtree is not node:
                    # Graft the matching tree onto this leaf.
                    node.left = subtree.left
                    node.right = subtree.right
            stack.append((node.right, node.val, high))
            stack.append((node.left, low, node.val))

        # Every graft fuses one root with one leaf, so a complete merge of
        # len(trees) trees contains total_nodes - (len(trees) - 1) nodes.
        if len(used) != total_nodes - (len(trees) - 1):
            return None
        return root

    def countNodes(self, root):
        """Return the number of nodes in the tree rooted at ``root`` (0 for None)."""
        if not root:
            return 0
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
unordered_map rootMap;
unordered_set used;
int nodeCount(TreeNode* root) {
if (!root) return 0;
return 1 + nodeCount(root->left) + nodeCount(root->right);
}
bool merge(TreeNode* node, long minVal, long maxVal, TreeNode* globalRoot) {
if (!node) return true;
if (!(minVal < node->val && node->val < maxVal)) return false;
if (used.count(node->val)) return false;
used.insert(node->val);
if (!node->left && !node->right && rootMap.count(node->val) && node != rootMap[node->val]) {
TreeNode* subtree = rootMap[node->val];
node->left = subtree->left;
node->right = subtree->right;
}
return merge(node->left, minVal, node->val, globalRoot) && merge(node->right, node->val, maxVal, globalRoot);
}
TreeNode* canMerge(vector& trees) {
rootMap.clear();
used.clear();
unordered_set childVals;
int totalNodes = 0;
for (auto tree : trees) {
rootMap[tree->val] = tree;
}
for (auto tree : trees) {
if (tree->left) childVals.insert(tree->left->val);
if (tree->right) childVals.insert(tree->right->val);
}
vector roots;
for (auto tree : trees) {
if (!childVals.count(tree->val)) roots.push_back(tree);
}
if (roots.size() != 1) return nullptr;
TreeNode* root = roots[0];
for (auto tree : trees) totalNodes += nodeCount(tree);
if (!merge(root, LONG_MIN, LONG_MAX, root)) return nullptr;
if (used.size() != totalNodes) return nullptr;
return root;
}
};
// Definition for a binary tree node.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    /** Creates a node with value 0 and no children. */
    TreeNode() {
        this(0, null, null);
    }

    /** Creates a childless node holding {@code val}. */
    TreeNode(int val) {
        this(val, null, null);
    }

    /** Creates a node holding {@code val} with the given children (either may be null). */
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}
/**
 * Merges a list of BSTs into a single BST by grafting each tree onto a leaf
 * (of an already reachable tree) that carries the same value as its root.
 * Input nodes are mutated in place while merging.
 */
class Solution {
    // Root value -> tree with that root; rebuilt on every canMerge() call.
    Map<Integer, TreeNode> rootMap = new HashMap<>();
    // Values already placed in the merged tree during the current call.
    Set<Integer> used = new HashSet<>();

    /** Returns the number of nodes in the tree rooted at {@code root} (0 for null). */
    int countNodes(TreeNode root) {
        if (root == null) return 0;
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    /**
     * Validates the subtree at {@code node} against the open interval
     * (minVal, maxVal), grafting trees from {@code rootMap} onto matching leaves.
     *
     * @param node       subtree to validate; null is trivially valid
     * @param minVal     exclusive lower bound for values in this subtree
     * @param maxVal     exclusive upper bound for values in this subtree
     * @param globalRoot root of the whole merge; only forwarded to recursive calls
     * @return false if a value falls outside its bounds or was already seen
     */
    boolean merge(TreeNode node, long minVal, long maxVal, TreeNode globalRoot) {
        if (node == null) return true;
        if (!(minVal < node.val && node.val < maxVal)) return false;
        // add() reports whether the value was new; a repeat is invalid.
        if (!used.add(node.val)) return false;
        if (node.left == null && node.right == null) {
            TreeNode subtree = rootMap.get(node.val);
            if (subtree != null && subtree != node) {
                // Graft the matching tree onto this leaf.
                node.left = subtree.left;
                node.right = subtree.right;
            }
        }
        return merge(node.left, minVal, node.val, globalRoot)
            && merge(node.right, node.val, maxVal, globalRoot);
    }

    /**
     * Merges all trees into one valid BST.
     *
     * @param trees roots of the input trees
     * @return the merged root, or null when there is not exactly one candidate
     *         root, when the merged tree is not a valid BST, or when some
     *         input tree could not be attached
     */
    public TreeNode canMerge(List<TreeNode> trees) {
        rootMap.clear();
        used.clear();
        Set<Integer> childVals = new HashSet<>();
        int totalNodes = 0;
        for (TreeNode tree : trees) {
            rootMap.put(tree.val, tree);
            if (tree.left != null) childVals.add(tree.left.val);
            if (tree.right != null) childVals.add(tree.right.val);
            // Counted before merging, while the trees are still untouched.
            totalNodes += countNodes(tree);
        }

        // The final root is the only tree whose root value is nobody's child.
        TreeNode root = null;
        for (TreeNode tree : trees) {
            if (childVals.contains(tree.val)) continue;
            if (root != null) return null; // more than one candidate root
            root = tree;
        }
        if (root == null) return null;

        if (!merge(root, Long.MIN_VALUE, Long.MAX_VALUE, root)) return null;

        // Every graft fuses one root with one leaf, so a complete merge of
        // N trees contains totalNodes - (N - 1) nodes.
        int expected = totalNodes - (trees.size() - 1);
        return used.size() == expected ? root : null;
    }
}
/**
* Definition for a binary tree node.
* function TreeNode(val, left, right) {
* this.val = (val===undefined ? 0 : val)
* this.left = (left===undefined ? null : left)
* this.right = (right===undefined ? null : right)
* }
*/
/**
 * Merges all trees into one valid BST by grafting each tree onto a leaf (of
 * an already reachable tree) that carries the same value as its root.
 * Input nodes are mutated in place while merging.
 *
 * Returns null when there is not exactly one candidate root, when the merged
 * tree violates the BST ordering or repeats a value, or when some input tree
 * could not be attached.
 *
 * @param {TreeNode[]} trees
 * @return {TreeNode}
 */
var canMerge = function(trees) {
    const countNodes = (node) =>
        node ? 1 + countNodes(node.left) + countNodes(node.right) : 0;

    const rootMap = new Map();
    const childVals = new Set();
    let totalNodes = 0;
    for (const tree of trees) {
        rootMap.set(tree.val, tree);
        if (tree.left) childVals.add(tree.left.val);
        if (tree.right) childVals.add(tree.right.val);
        // Counted before merging, while the trees are still untouched.
        totalNodes += countNodes(tree);
    }

    // The final root is the only tree whose root value is nobody's child.
    const roots = trees.filter((tree) => !childVals.has(tree.val));
    if (roots.length !== 1) return null;
    const root = roots[0];

    const used = new Set();
    // Explicit stack instead of recursion: the merged tree can be a long
    // chain, which would overflow the call stack.
    const stack = [[root, -Infinity, Infinity]];
    while (stack.length > 0) {
        const [node, low, high] = stack.pop();
        if (!node) continue;
        // Strict bounds enforce the BST ordering; a repeated value can never
        // belong to a valid BST.
        if (!(low < node.val && node.val < high) || used.has(node.val)) {
            return null;
        }
        used.add(node.val);
        if (!node.left && !node.right) {
            const subtree = rootMap.get(node.val);
            if (subtree && subtree !== node) {
                // Graft the matching tree onto this leaf.
                node.left = subtree.left;
                node.right = subtree.right;
            }
        }
        stack.push([node.right, node.val, high]);
        stack.push([node.left, low, node.val]);
    }

    // Every graft fuses one root with one leaf, so a complete merge of
    // N trees contains totalNodes - (N - 1) nodes.
    if (used.size !== totalNodes - (trees.length - 1)) return null;
    return root;
};