
1530. Number of Good Leaf Nodes Pairs - Leetcode Solution

Code Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def countPairs(self, root: TreeNode, distance: int) -> int:
        self.result = 0

        def dfs(node):
            if not node:
                return []
            if not node.left and not node.right:
                return [1]
            left_distances = dfs(node.left)
            right_distances = dfs(node.right)
            # Count pairs
            for l in left_distances:
                for r in right_distances:
                    if l + r <= distance:
                        self.result += 1
            # Return distances incremented by 1
            return [n + 1 for n in left_distances + right_distances if n + 1 <= distance]
        
        dfs(root)
        return self.result
      
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int result = 0;
    vector<int> dfs(TreeNode* node, int distance) {
        if (!node) return {};
        if (!node->left && !node->right) return {1};
        vector<int> left = dfs(node->left, distance);
        vector<int> right = dfs(node->right, distance);
        for (int l : left) {
            for (int r : right) {
                if (l + r <= distance) result++;
            }
        }
        vector<int> res;
        for (int n : left) if (n + 1 <= distance) res.push_back(n + 1);
        for (int n : right) if (n + 1 <= distance) res.push_back(n + 1);
        return res;
    }
    int countPairs(TreeNode* root, int distance) {
        result = 0;  // reset in case the same Solution object is reused
        dfs(root, distance);
        return result;
    }
};
      
import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    int result = 0;
    public int countPairs(TreeNode root, int distance) {
        result = 0; // reset in case the same Solution object is reused
        dfs(root, distance);
        return result;
    }
    private List<Integer> dfs(TreeNode node, int distance) {
        if (node == null) return new ArrayList<>();
        if (node.left == null && node.right == null) {
            List<Integer> leaves = new ArrayList<>();
            leaves.add(1);
            return leaves;
        }
        List<Integer> left = dfs(node.left, distance);
        List<Integer> right = dfs(node.right, distance);
        for (int l : left) {
            for (int r : right) {
                if (l + r <= distance) result++;
            }
        }
        List<Integer> res = new ArrayList<>();
        for (int n : left) if (n + 1 <= distance) res.add(n + 1);
        for (int n : right) if (n + 1 <= distance) res.add(n + 1);
        return res;
    }
}
      
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */
/**
 * @param {TreeNode} root
 * @param {number} distance
 * @return {number}
 */
var countPairs = function(root, distance) {
    let result = 0;
    function dfs(node) {
        if (!node) return [];
        if (!node.left && !node.right) return [1];
        let left = dfs(node.left);
        let right = dfs(node.right);
        for (let l of left) {
            for (let r of right) {
                if (l + r <= distance) result++;
            }
        }
        let res = [];
        for (let n of left.concat(right)) {
            if (n + 1 <= distance) res.push(n + 1);
        }
        return res;
    }
    dfs(root);
    return result;
};
      

Problem Description

You are given the root of a binary tree and an integer distance. A pair of leaf nodes is considered a "good leaf node pair" if the shortest path between them (in number of edges) is less than or equal to distance.

Your task is to count the number of good leaf node pairs in the tree. Pairs are unordered and must consist of two distinct leaves: (A, B) and (B, A) are the same pair and are counted once.

Constraints:

  • Each node has 0, 1, or 2 children.
  • The input describes a single valid binary tree.
  • Do not count pairs where a leaf is paired with itself.
  • 1 ≤ distance ≤ 10
  • The number of nodes in the tree is between 1 and 2 * 10^4.

Thought Process

The first instinct might be to find all leaf nodes and compute the distance between every pair. With L leaves there are L(L - 1)/2 pairs, and computing each path length takes additional work, so this brute force becomes slow on large trees.
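For reference, a brute-force baseline might look like the sketch below. It is not part of the original solution. It treats the tree as an undirected graph, runs a breadth-first search from every leaf that stops after `distance` edges, and halves the total because each good pair is found once from each end. The function name `count_pairs_brute_force` is hypothetical, and the local `TreeNode` class is assumed to match the commented-out definition in the Python solution. Each BFS can touch up to O(N) nodes, so this version is mainly useful for cross-checking the faster solution on small trees.

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    # Assumed to match the commented-out LeetCode definition above.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_brute_force(root: Optional[TreeNode], distance: int) -> int:
    """Hypothetical baseline: BFS from every leaf over the tree as an undirected graph."""
    if root is None:
        return 0
    # Build an undirected adjacency map (nodes hash by identity) and collect leaves.
    neighbors: Dict[TreeNode, List[TreeNode]] = {root: []}
    leaves: List[TreeNode] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node.left is None and node.right is None:
            leaves.append(node)
        for child in (node.left, node.right):
            if child is not None:
                neighbors[node].append(child)
                neighbors[child] = [node]  # edge back to the parent
                stack.append(child)

    leaf_set = set(leaves)
    found = 0
    for start in leaves:
        # BFS outward from this leaf, never going past `distance` edges.
        seen = {start}
        queue = deque([(start, 0)])
        while queue:
            node, d = queue.popleft()
            if node is not start and node in leaf_set:
                found += 1
            if d == distance:
                continue
            for nxt in neighbors[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append((nxt, d + 1))
    # Every good pair was found once from each of its two leaves.
    return found // 2
```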

Since the problem is about pairs of leaf nodes with a path length constraint, we can leverage the tree's structure. If we process the tree from the bottom up (post-order), we can, for each subtree, know how far each leaf is from the current node. This allows us to efficiently combine information from the left and right subtrees to count valid pairs without explicitly generating all possible pairs across the entire tree.

By focusing on distances from leaves to their ancestors, we can avoid redundant work and keep our solution efficient.

Solution Approach

We'll use a recursive post-order traversal (DFS) to solve the problem efficiently. Here are the step-by-step details:

  1. Base Case: For each leaf node, return a list (or array) containing only 1. This is the distance from the leaf to its parent, which is the node that receives the list.
  2. Recursive Step: For each non-leaf node, recursively collect the distances from all leaves in its left and right subtrees.
  3. Count Valid Pairs: For the current node, consider all pairs formed by one leaf from the left and one from the right. If the sum of their distances is less than or equal to distance, increment the result counter.
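  4. Propagate Upward: For the parent, return a list of all leaf distances from this node, each incremented by 1 (the parent is one edge higher). Drop any distance greater than distance, since such a leaf can never be part of a good pair higher up. (A distance equal to distance could also be dropped, because any partner leaf adds at least one more edge, but keeping it is harmless.)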
  5. Final Result: The answer is the accumulated count after traversing the whole tree.
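The recursive solutions rely on the call stack, and a sufficiently skewed tree can be deeper than CPython's default recursion limit of 1000 frames. The sketch below is a hypothetical variation, not taken from the original, that performs the same five steps with an explicit stack. Each node is pushed once to schedule its children and once more to be combined after both children have finished. The names `count_pairs_iterative` and `up` are illustrative.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_iterative(root: Optional[TreeNode], distance: int) -> int:
    """Post-order DFS driven by an explicit stack instead of recursion (a sketch)."""
    if root is None:
        return 0
    result = 0
    # up[node] = distances from node's parent to the surviving leaves under node.
    up: Dict[TreeNode, List[int]] = {}
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Schedule the combine step, then the children (which are popped first).
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
            continue
        if node.left is None and node.right is None:
            up[node] = [1]  # Step 1: a leaf is one edge from its parent.
            continue
        # Step 2: collect (and release) the children's lists; a missing child gives [].
        left = up.pop(node.left, [])
        right = up.pop(node.right, [])
        # Step 3: count left/right pairs whose path through this node is short enough.
        for l in left:
            for r in right:
                if l + r <= distance:
                    result += 1
        # Step 4: shift one edge up and prune distances that are too large.
        up[node] = [n + 1 for n in left + right if n + 1 <= distance]
    return result  # Step 5: the accumulated count.
```

Each child's list is popped from `up` as soon as its parent consumes it, so the dictionary only holds lists for nodes whose parent has not yet been combined.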

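The post-order approach counts each pair exactly once. Two distinct leaves have a unique lowest common ancestor, and that ancestor is the only node where one leaf appears in the left list and the other in the right list. At that node the path between them passes through the node, so its length is exactly l + r. Pruning distances larger than distance keeps the lists short, which keeps time and space usage manageable.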
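To see where each pair gets counted, the sketch below changes the list-based DFS to carry leaf labels with their distances and to record the node at which every good pair is found. It is an illustrative variation. It assumes node values are unique so they can serve as labels, and `good_pairs_with_lca` is a hypothetical name.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def good_pairs_with_lca(root: Optional[TreeNode], distance: int) -> List[Tuple[int, int, int]]:
    """Return (left_leaf, right_leaf, counting_node) for every good pair (labels = node values)."""
    pairs: List[Tuple[int, int, int]] = []

    def dfs(node: Optional[TreeNode]) -> List[Tuple[int, int]]:
        # Each entry is (leaf value, edges from node's parent to that leaf).
        if node is None:
            return []
        if node.left is None and node.right is None:
            return [(node.val, 1)]
        left = dfs(node.left)
        right = dfs(node.right)
        for a, da in left:
            for b, db in right:
                if da + db <= distance:
                    # Recorded here, at the node where the two leaves' paths split.
                    pairs.append((a, b, node.val))
        return [(leaf, d + 1) for leaf, d in left + right if d + 1 <= distance]

    dfs(root)
    return pairs


full = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
for a, b, lca in good_pairs_with_lca(full, 4):
    print(f"leaves {a} and {b} counted at node {lca}")
```

On this complete tree with distance 4, the sibling pairs (4, 5) and (6, 7) are recorded at nodes 2 and 3, and the four cross pairs are recorded only at the root. No pair appears twice.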

Example Walkthrough

Let's look at a simple example:

        1
       / \
      2   3
     /   / \
    4   5   6
  

Suppose distance = 3. The leaf nodes are 4, 5, and 6.

  • Node 2 receives [1] from leaf 4 and nothing from its empty right side, so it counts no pairs. It returns [2], because leaf 4 is 2 edges from node 1.
  • Node 3 receives [1] from leaf 5 and [1] from leaf 6. The pair (5, 6) has path length 1 + 1 = 2 ≤ 3, so it is counted. Node 3 returns [2, 2].
  • At the root (1), the left list is [2] and the right list is [2, 2]. The pairs (4, 5) and (4, 6) both have path length 2 + 2 = 4 > 3, so neither is counted.
  • So the answer is 1.

This shows that each pair is counted only at its lowest common ancestor, which is why a post-order traversal works well here.
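To check the walkthrough numbers, the sketch below mirrors the list-based Python solution as a standalone function and records the list that each node hands to its parent. It assumes node values are unique so they can label the trace, and `count_pairs_traced` is a hypothetical name.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_traced(root: Optional[TreeNode], distance: int) -> Tuple[int, Dict[int, List[int]]]:
    """List-based post-order DFS that also records what each node returns upward."""
    result = 0
    returned: Dict[int, List[int]] = {}

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal result
        if node is None:
            return []
        if node.left is None and node.right is None:
            returned[node.val] = [1]  # one edge from the leaf to its parent
            return [1]
        left = dfs(node.left)
        right = dfs(node.right)
        for l in left:
            for r in right:
                if l + r <= distance:
                    result += 1
        up = [n + 1 for n in left + right if n + 1 <= distance]
        returned[node.val] = up
        return up

    dfs(root)
    return result, returned


# The example tree from the walkthrough, with distance = 3.
root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
count, trace = count_pairs_traced(root, 3)
print(count)     # 1
print(trace[3])  # [2, 2]
```

The trace shows node 2 returning [2] and node 3 returning [2, 2], which matches the lists used at the root in the walkthrough.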

Time and Space Complexity

Brute-force approach: collecting all leaves and checking every pair examines O(L^2) pairs, where L is the number of leaves, before accounting for the cost of computing each path length. For large trees this is slow.

Optimized approach (post-order DFS):

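  • Each node is visited once, so the traversal itself is O(N), where N is the number of nodes.
  • The lists in the solutions above store one entry per leaf, not one entry per distance. After pruning, a node's list contains only leaves within distance edges of its parent, so its length is bounded by 2^D (where D is distance), not by D. The double loop at a node costs O(|left| × |right|), which in the worst case is far more than D^2.
  • Because that bound depends only on D, and D ≤ 10, the running time is still linear in N for this problem, with O(N · 4^D) as a loose worst-case bound. Storing a count per distance (an array of length D + 1) instead of one entry per leaf tightens the combine step to O(D^2) and the total to O(N · D^2).
  • Space complexity is O(H) for the recursion stack, where H is the tree height (O(N) for a skewed tree), plus the lists held by active calls, so the total is O(N) in the worst case.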
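The per-node cost depends on how the leaf information is stored. The sketch below is a swapped-in representation, not the original solution. Each node returns an array `counts`, where `counts[d]` is the number of leaves exactly `d` edges from that node's parent. Combining two such arrays takes O(D^2) regardless of how many leaves they summarize. The name `count_pairs_bucketed` is hypothetical.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_bucketed(root: Optional[TreeNode], distance: int) -> int:
    """Post-order DFS over per-distance leaf counts instead of per-leaf lists (a sketch)."""
    result = 0

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal result
        # counts[d] = number of leaves exactly d edges from this node's parent.
        counts = [0] * (distance + 1)
        if node is None:
            return counts
        if node.left is None and node.right is None:
            counts[1] = 1
            return counts
        left = dfs(node.left)
        right = dfs(node.right)
        # Every left leaf at distance i pairs with every right leaf at distance j
        # whenever i + j <= distance, so multiply the bucket sizes.
        for i in range(1, distance + 1):
            if left[i] == 0:
                continue
            for j in range(1, distance - i + 1):
                result += left[i] * right[j]
        # Shift one edge up; anything pushed past `distance` is dropped.
        for d in range(1, distance):
            counts[d + 1] = left[d] + right[d]
        return counts

    dfs(root)
    return result
```

Multiplying bucket sizes (`left[i] * right[j]`) counts every left/right leaf combination at those two distances in a single step. This removes the dependence on the number of leaves that the list-based loops have.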

Summary

This problem is efficiently solved by leveraging the recursive structure of binary trees and focusing on distances from leaves to their ancestors. By using post-order traversal and only tracking relevant distances, we avoid the inefficiency of brute-force pairwise comparisons. The key insight is that each valid pair is counted exactly once at their lowest common ancestor, making the solution both elegant and efficient.