Given the root of a binary tree, the task is to prune the tree so that every subtree (of the given tree) not containing a node with the value 1 is removed. In other words, remove all subtrees that do not have a 1 in them. A subtree of a node is the node plus all its descendants.
The function should return the pruned tree's root. If the entire tree should be pruned (i.e., no node with value 1 exists), then return null.
Constraints:
Each node's value is either 0 or 1.
To solve this problem, a first idea might be to traverse the tree and, for each node, check whether its subtree contains a 1, removing that subtree if it does not. However, this means we need to know about all of a node's descendants before deciding whether to keep or remove it.
This naturally leads us to a post-order traversal (left, right, node): we process the children before the parent, so we can decide whether to prune a node based on what is left of its children after they have been pruned. If both children are pruned (i.e., are null) and the node's value is 0, then this node should also be pruned.
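To see why post-order fits, here is a minimal Python sketch that records when each node of the example tree used later in this article (root 1, left child 0 with two 0 children, right child 1) is finished. The `postorder` helper and its path labels such as `root.L` are illustrative names, not part of the original solution; the point is that every node is recorded only after both of its children.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def postorder(node: Optional[TreeNode], out: List[str], label: str = "root") -> None:
    """Append 'path=value' for each node, only after its left and right subtrees are done."""
    if node is None:
        return
    postorder(node.left, out, label + ".L")   # finish the left subtree first
    postorder(node.right, out, label + ".R")  # then the right subtree
    out.append(f"{label}={node.val}")         # the node itself comes last


tree = TreeNode(1, TreeNode(0, TreeNode(0), TreeNode(0)), TreeNode(1))
order: List[str] = []
postorder(tree, order)
print(order)
```

Running it prints `['root.L.L=0', 'root.L.R=0', 'root.L=0', 'root.R=1', 'root=1']`. All three 0 nodes on the left are finished before the root, which is exactly the moment the pruning decision for each of them needs to be made.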
Instead of brute-forcing by re-checking every subtree repeatedly, we can optimize by using recursion to process and prune as we return up the call stack.
We use a recursive, post-order traversal to prune the tree:
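1. If the current node is null, return null (nothing to prune).
2. Recursively prune node.left and node.right, and store the results back into node.left and node.right.
3. If both node.left and node.right are now null and node.val == 0, return null to prune this node.
4. Otherwise, return the node (it either has the value 1 or has a descendant with the value 1).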
This approach ensures each node is visited only once, and each decision is made based on the subtree information gathered recursively.
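The same single pass can also be phrased with a helper that reports whether a subtree contains a 1 and detaches child subtrees that do not. This is a sketch of an alternative formulation, not the article's code; `contains_one` and `prune_tree` are hypothetical names. It makes the "subtree information gathered recursively" explicit as a boolean.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def contains_one(node: Optional[TreeNode]) -> bool:
    """Prune node's children in place; return True if node's subtree still contains a 1."""
    if node is None:
        return False
    left_has_one = contains_one(node.left)    # post-order: left subtree first
    right_has_one = contains_one(node.right)  # then right subtree
    # Detach any child subtree that turned out to contain no 1.
    if not left_has_one:
        node.left = None
    if not right_has_one:
        node.right = None
    return node.val == 1 or left_has_one or right_has_one


def prune_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # The root itself is dropped only if the whole tree has no 1.
    return root if contains_one(root) else None
```

Both formulations visit each node once. The boolean version separates "does this subtree matter?" from "detach it", while the article's version encodes the same answer as null versus non-null.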
Input:
    1
   / \
  0   1
 / \
0   0
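Processing happens in post-order, starting from the leftmost leaf (value 0):
1. The two leaves under the root's left child (values 0 and 0) have no descendants and are 0, so they are pruned (set to null).
2. The root's left child (value 0) now has no children and is 0, so it is also pruned.
3. The root's right child (value 1) is kept because it contains a 1.
4. The root (value 1) is kept because its own value is 1.
Output: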
1
 \
  1
All subtrees not containing 1 are removed.
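To check the example end to end, here is a small Python harness. It assumes a hypothetical `to_level_order` helper that serializes a tree breadth-first, using `None` for missing children and trimming trailing `None`s (similar in spirit to how LeetCode prints trees); `prune_tree` mirrors the article's `pruneTree` as a plain function.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Post-order pruning, same logic as the article's pruneTree."""
    if root is None:
        return None
    root.left = prune_tree(root.left)
    root.right = prune_tree(root.right)
    if root.val == 0 and root.left is None and root.right is None:
        return None
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Breadth-first serialization; None marks a missing child, trailing Nones are trimmed."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


tree = TreeNode(1, TreeNode(0, TreeNode(0), TreeNode(0)), TreeNode(1))
print(to_level_order(tree))              # [1, 0, 1, 0, 0]
print(to_level_order(prune_tree(tree)))  # [1, None, 1]
```

With this encoding, the input above is `[1, 0, 1, 0, 0]` and the pruned output is `[1, None, 1]`: the root keeps only its right child.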
Brute-force approach: If we tried to check every subtree for a 1 independently, we could end up visiting the same nodes many times, leading to O(N^2) time in the worst case (for example, a long chain of 0s with a single 1 at the bottom).
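A sketch of that brute-force idea, for comparison only: it decides top-down, rescanning a node's whole subtree with a separate search before keeping it. `BruteForcePruner`, its `visits` counter, and the `left_chain` test tree are hypothetical names used to make the repeated work visible.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class BruteForcePruner:
    """Top-down pruning that rescans each subtree from scratch (illustrative only)."""

    def __init__(self) -> None:
        self.visits = 0  # nodes touched by contains_one, to expose repeated work

    def contains_one(self, node: Optional[TreeNode]) -> bool:
        if node is None:
            return False
        self.visits += 1
        # Short-circuits as soon as a 1 is found.
        return (node.val == 1 or self.contains_one(node.left)
                or self.contains_one(node.right))

    def prune(self, node: Optional[TreeNode]) -> Optional[TreeNode]:
        # Full scan of this subtree before deciding whether to keep the node.
        if node is None or not self.contains_one(node):
            return None
        node.left = self.prune(node.left)
        node.right = self.prune(node.right)
        return node


def left_chain(n: int) -> TreeNode:
    """n nodes in a left-leaning chain, all 0 except the deepest one, which is 1."""
    node = TreeNode(1)
    for _ in range(n - 1):
        node = TreeNode(0, node)
    return node


pruner = BruteForcePruner()
pruner.prune(left_chain(10))
print(pruner.visits)  # 55
```

On a chain of N nodes whose only 1 is at the bottom, each scan walks all the way down, so the visits add up to N + (N-1) + ... + 1 = N(N+1)/2, which is 55 for N = 10.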
Optimized recursive approach: Each node is visited only once, and the work done at each node is constant. Therefore, the time complexity is O(N), where N is the number of nodes in the tree.
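One way to confirm the O(N) claim is to instrument the post-order pruning with counters. This is a sketch with hypothetical names (`CountingPruner`, `left_chain`); the pruning logic itself matches the article's `pruneTree`. Every real node is visited exactly once, and the extra calls land on the N + 1 null child links, giving 2N + 1 calls in total.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class CountingPruner:
    """Post-order pruning (same logic as pruneTree) instrumented with counters."""

    def __init__(self) -> None:
        self.calls = 0        # every invocation, including those on None children
        self.node_visits = 0  # invocations on actual nodes

    def prune(self, node: Optional[TreeNode]) -> Optional[TreeNode]:
        self.calls += 1
        if node is None:
            return None
        self.node_visits += 1
        node.left = self.prune(node.left)
        node.right = self.prune(node.right)
        if node.val == 0 and node.left is None and node.right is None:
            return None
        return node


def left_chain(n: int) -> TreeNode:
    """n nodes in a left-leaning chain, all 0 except the deepest one, which is 1."""
    node = TreeNode(1)
    for _ in range(n - 1):
        node = TreeNode(0, node)
    return node


counter = CountingPruner()
counter.prune(left_chain(10))
print(counter.node_visits, counter.calls)  # 10 21
```

On the same 10-node chain, a top-down rescan touches nodes 55 times, while this single pass touches each node once.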
The space complexity is O(H), where H is the height of the tree, due to the recursion stack. In the worst case (completely unbalanced tree), this could be O(N); for a balanced tree, it's O(log N).
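If the recursion depth is a concern, the same post-order pruning can be driven by an explicit stack. In CPython the default recursion limit is usually 1000 (see `sys.getrecursionlimit()`), so a sufficiently deep, skewed tree could make the recursive version raise `RecursionError`. The sketch below is an alternative, not the article's method; `prune_tree_iterative` and `is_zero_leaf` are hypothetical names. It still needs O(H) extra memory, but that memory lives on the heap instead of the call stack.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal binary tree node with a 0/1 value."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_zero_leaf(node: Optional[TreeNode]) -> bool:
    """True for a 0-valued node with no remaining children (i.e. a prunable node)."""
    return node is not None and node.val == 0 and node.left is None and node.right is None


def prune_tree_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    if root is None:
        return None
    # Entries are (node, children_done); each node is popped twice.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # First visit: re-push the node, then its children on top so they finish first.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # Second visit: both subtrees are already pruned, so a child that is now a
            # 0-valued leaf has no 1 anywhere below it and can be detached.
            if is_zero_leaf(node.left):
                node.left = None
            if is_zero_leaf(node.right):
                node.right = None
    # The parent-less root is handled last.
    return None if is_zero_leaf(root) else root
```

Because a parent always finishes after its children, it can detach a child simply by checking whether that child has become a 0-valued leaf, which avoids storing parent pointers on the stack.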
The Binary Tree Pruning problem is elegantly solved using a recursive, post-order traversal. By pruning the left and right subtrees before deciding whether to keep a node, we ensure that only subtrees containing a 1 remain. This approach is efficient, visiting each node only once, and illustrates the power of recursion for tree-based problems.
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        # Empty subtree: nothing to prune.
        if not root:
            return None
        # Post-order: prune both children first.
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)
        # A 0-valued node with no surviving children contains no 1.
        if root.val == 0 and not root.left and not root.right:
            return None
        return root
// Definition for a binary tree node.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class Solution {
public:
    TreeNode* pruneTree(TreeNode* root) {
        if (!root) return nullptr;
        root->left = pruneTree(root->left);
        root->right = pruneTree(root->right);
        if (root->val == 0 && !root->left && !root->right) {
            return nullptr;
        }
        return root;
    }
};
// Definition for a binary tree node.
// (Not declared public, so it can share a source file with Solution.)
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class Solution {
    public TreeNode pruneTree(TreeNode root) {
        if (root == null) return null;
        root.left = pruneTree(root.left);
        root.right = pruneTree(root.right);
        if (root.val == 0 && root.left == null && root.right == null) {
            return null;
        }
        return root;
    }
}
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */
/**
 * @param {TreeNode} root
 * @return {TreeNode}
 */
var pruneTree = function(root) {
    if (!root) return null;
    root.left = pruneTree(root.left);
    root.right = pruneTree(root.right);
    if (root.val === 0 && !root.left && !root.right) {
        return null;
    }
    return root;
};