Given the root of a binary tree, the task is to prune the tree so that every subtree (of the given tree) that does not contain a node with the value 1 is removed. In other words, remove all subtrees that do not have a 1 in them. A subtree of a node is the node plus all its descendants.
The function should return the root of the pruned tree. If the entire tree should be pruned (i.e., no node with value 1 exists), return null.
Constraints:
Each node's value is either 0 or 1.
To solve this problem, the first idea might be to traverse the tree and, for each node, check whether its subtree contains a 1. If it does not, we remove that subtree. This means we need to know about all of a node's descendants before deciding whether to keep or remove it.
This naturally leads to a post-order traversal (left, right, node): we process the children before the parent, so we can decide whether to prune a node based on what is left of its children after they have been pruned. If both children are pruned (i.e., are null) and the node's value is 0, then this node should also be pruned.
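The post-order idea can be seen in isolation with a small helper that only answers the question "does this subtree contain a 1?". This is a minimal Python sketch: `contains_one` is a hypothetical name, and `TreeNode` simply assumes the usual `val`/`left`/`right` shape. Each call waits for both children's answers before deciding about the node itself, which is exactly the information the pruning step relies on.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def contains_one(node: Optional[TreeNode]) -> bool:
    """Post-order check (hypothetical helper): does this subtree contain a 1?"""
    if node is None:
        return False
    # Visit both children first (left, then right) ...
    left_has = contains_one(node.left)
    right_has = contains_one(node.right)
    # ... then decide for the node itself using the children's answers.
    return node.val == 1 or left_has or right_has


# A 0-rooted subtree whose descendants are all 0 contains no 1.
zeros = TreeNode(0, TreeNode(0), TreeNode(0))
print(contains_one(zeros))                           # False
print(contains_one(TreeNode(0, None, TreeNode(1))))  # True
```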
Instead of brute-forcing by re-checking every subtree repeatedly, we can optimize by using recursion to process and prune as we return up the call stack.
We use a recursive, post-order traversal to prune the tree:
1. If the node is null, return null (nothing to prune).
2. Recursively prune node.left and node.right.
3. If both node.left and node.right are now null and node.val == 0, return null to prune this node.
4. Otherwise, return the node (its subtree contains a 1).
This approach ensures each node is visited only once, and each decision is made based on the subtree information gathered recursively.
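Input:
    1
   / \
  0   1
 / \
0   0
Walking through the pruning in post-order:
- The two leaves under the root's left child (both 0) have no descendants and are 0, so they are pruned (set to null).
- The root's left child (0) now has no children and is 0, so it is also pruned.
- The root's right child (1) is kept because it contains a 1.
- The root (1) is kept.
Output:
  1
   \
    1
All subtrees not containing a 1 are removed.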
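To replay this example mechanically, here is a sketch that runs the same post-order pruning while logging each decision. The `label` field, `prune_traced`, and `level_order` are hypothetical additions for illustration only; `level_order` prints the tree as a level-order list with `None` for missing children, in the style LeetCode uses for tree inputs.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Binary tree node; `label` is a hypothetical extra field used only for tracing."""

    def __init__(self, val=0, left=None, right=None, label=""):
        self.val = val
        self.left = left
        self.right = right
        self.label = label


def prune_traced(node: Optional[TreeNode], log: List[str]) -> Optional[TreeNode]:
    """Post-order pruning that records every keep/prune decision in `log`."""
    if node is None:
        return None
    node.left = prune_traced(node.left, log)
    node.right = prune_traced(node.right, log)
    if node.val == 0 and node.left is None and node.right is None:
        log.append(f"prune {node.label}")
        return None
    log.append(f"keep {node.label}")
    return node


def level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Level-order values, None for missing children, trailing Nones trimmed."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


# The example tree:        1 (A)
#                         /     \
#                      0 (B)    1 (C)
#                     /     \
#                  0 (D)   0 (E)
root = TreeNode(1,
                TreeNode(0, TreeNode(0, label="D"), TreeNode(0, label="E"), label="B"),
                TreeNode(1, label="C"),
                label="A")
log: List[str] = []
pruned = prune_traced(root, log)
print(log)                  # ['prune D', 'prune E', 'prune B', 'keep C', 'keep A']
print(level_order(pruned))  # [1, None, 1]
```

The log order D, E, B, C, A confirms that every child is decided before its parent, and each node appears exactly once.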
Brute-force approach: If we checked every subtree for a 1 independently, we would visit many nodes repeatedly, leading to O(N^2) time in the worst case (for example, a long chain of 0s ending in a single 1).
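The gap between the two approaches can be measured by counting node visits. In this sketch, `BruteForcePruner`, `PostOrderPruner`, and `left_chain` are hypothetical names. The brute-force version rescans a node's whole subtree before deciding on it; on a chain of 0s ending in a single 1, the k-th node from the bottom costs k visits, so N nodes cost N(N+1)/2 visits in total, while the post-order version touches each node once.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class BruteForcePruner:
    """Top-down pruner that rescans each subtree; counts node visits."""

    def __init__(self):
        self.visits = 0

    def has_one(self, node):
        # Scans the subtree below `node`, stopping early only once a 1 is found.
        if node is None:
            return False
        self.visits += 1
        return node.val == 1 or self.has_one(node.left) or self.has_one(node.right)

    def prune(self, node):
        if node is None:
            return None
        # Rescan the whole subtree before deciding on this node.
        if not self.has_one(node):
            return None
        node.left = self.prune(node.left)
        node.right = self.prune(node.right)
        return node


class PostOrderPruner:
    """The single-pass post-order pruner, instrumented the same way."""

    def __init__(self):
        self.visits = 0

    def prune(self, node):
        if node is None:
            return None
        self.visits += 1
        node.left = self.prune(node.left)
        node.right = self.prune(node.right)
        if node.val == 0 and node.left is None and node.right is None:
            return None
        return node


def left_chain(n):
    """Worst-case input: n - 1 zeros in a left-leaning chain above a single 1."""
    node = TreeNode(1)
    for _ in range(n - 1):
        node = TreeNode(0, left=node)
    return node


brute = BruteForcePruner()
brute.prune(left_chain(100))
fast = PostOrderPruner()
fast.prune(left_chain(100))
print(brute.visits)  # 5050, i.e. 100 * 101 / 2
print(fast.visits)   # 100
```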
Optimized recursive approach: Each node is visited only once, and the work done at each node is constant. Therefore, the time complexity is O(N), where N is the number of nodes in the tree.
The space complexity is O(H), where H is the height of the tree, due to the recursion stack. In the worst case (a completely unbalanced tree), this is O(N); for a balanced tree, it is O(log N).
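Because the recursion depth equals the tree height, a very deep, skewed tree can exceed CPython's default recursion limit (1000, as reported by `sys.getrecursionlimit()`). One possible workaround, sketched here as an alternative rather than the approach described above, keeps the same post-order logic but replaces the call stack with an explicit stack. It still needs O(H) extra space, now on the heap. `prune_iterative` and `_is_prunable` are hypothetical names.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def _is_prunable(node: TreeNode) -> bool:
    # Only meaningful once the node's own subtrees have already been pruned.
    return node.val == 0 and node.left is None and node.right is None


def prune_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Post-order pruning with an explicit stack instead of recursion."""
    if root is None:
        return None
    # Entries are (node, children_done). A node is re-pushed with
    # children_done=True *below* its children, so it is popped again only
    # after both child subtrees have been fully processed.
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # Children are already pruned; detach any that became 0-leaves.
            if node.left is not None and _is_prunable(node.left):
                node.left = None
            if node.right is not None and _is_prunable(node.right):
                node.right = None
    return None if _is_prunable(root) else root


# A 5000-node left chain (zeros above a single 1) is far deeper than the
# default recursion limit, but the explicit stack handles it.
node = TreeNode(1)
for _ in range(4999):
    node = TreeNode(0, left=node)
deep = prune_iterative(node)

depth = 0
cur = deep
while cur is not None:
    depth += 1
    cur = cur.left
print(depth)  # 5000: every node lies on the path to the 1, so nothing is pruned

print(prune_iterative(TreeNode(0, TreeNode(0), TreeNode(0))))  # None
```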
The Binary Tree Pruning problem is elegantly solved using a recursive, post-order traversal. By pruning the left and right subtrees before deciding whether to keep a node, we ensure that only subtrees containing a 1 remain. This approach is efficient, visiting each node only once, and illustrates the power of recursion for tree-based problems.
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        # An empty subtree has nothing to prune.
        if not root:
            return None
        # Post-order: prune both children first.
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)
        # A 0-valued node with no surviving children contains no 1.
        if root.val == 0 and not root.left and not root.right:
            return None
        return root
// Definition for a binary tree node.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class Solution {
public:
    TreeNode* pruneTree(TreeNode* root) {
        if (!root) return nullptr;               // empty subtree
        root->left = pruneTree(root->left);      // prune children first
        root->right = pruneTree(root->right);
        // A 0-valued node with no surviving children is removed.
        // Note: pruned nodes are detached here but not deleted.
        if (root->val == 0 && !root->left && !root->right) {
            return nullptr;
        }
        return root;
    }
};
// Definition for a binary tree node.
public class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class Solution {
    public TreeNode pruneTree(TreeNode root) {
        if (root == null) return null;           // empty subtree
        root.left = pruneTree(root.left);        // prune children first
        root.right = pruneTree(root.right);
        // A 0-valued node with no surviving children contains no 1.
        if (root.val == 0 && root.left == null && root.right == null) {
            return null;
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* function TreeNode(val, left, right) {
* this.val = (val===undefined ? 0 : val)
* this.left = (left===undefined ? null : left)
* this.right = (right===undefined ? null : right)
* }
*/
/**
* @param {TreeNode} root
* @return {TreeNode}
*/
var pruneTree = function(root) {
    if (!root) return null;
    root.left = pruneTree(root.left);
    root.right = pruneTree(root.right);
    if (root.val === 0 && !root.left && !root.right) {
        return null;
    }
    return root;
};