# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    def countPairs(self, root: "TreeNode", distance: int) -> int:
        """Count unordered pairs of distinct leaves at most ``distance`` edges apart.

        Every pair is counted exactly once, at the lowest common ancestor of its
        two leaves. The tree is walked bottom-up with an explicit stack instead
        of recursion, so deep (e.g. chain-shaped) trees cannot hit Python's
        recursion limit.

        Args:
            root: Root of the tree; ``None`` is treated as an empty tree. Only
                the ``left`` and ``right`` attributes of each node are read.
            distance: Maximum allowed path length between two leaves, in edges.

        Returns:
            The number of good leaf pairs.
        """
        # Two distinct leaves are always at least two edges apart.
        if root is None or distance < 2:
            return 0

        # A "profile" maps d -> number of leaves in a subtree that are d edges
        # away from the subtree root's parent. Only d <= distance is kept, so a
        # profile never has more than `distance` entries, however many leaves
        # the subtree contains.
        empty = {}  # profile of a missing child; shared and never mutated
        result = 0
        profiles = []  # finished subtrees' profiles, in post-order
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node is None:
                profiles.append(empty)
            elif node.left is None and node.right is None:
                profiles.append({1: 1})  # a leaf is one edge from its parent
            elif not children_done:
                # Revisit this node after both children's profiles are pushed;
                # the left child is popped (and so finished) first.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
            else:
                right = profiles.pop()
                left = profiles.pop()
                # Every (left leaf, right leaf) pair meets at `node`, so its
                # path length is left_d + right_d.
                for left_d, left_count in left.items():
                    for right_d, right_count in right.items():
                        if left_d + right_d <= distance:
                            result += left_count * right_count
                # Move one edge up, dropping leaves that would be too far away
                # to pair with anything higher in the tree.
                merged = {}
                for profile in (left, right):
                    for d, count in profile.items():
                        if d < distance:
                            merged[d + 1] = merged.get(d + 1, 0) + count
                profiles.append(merged)
        return result
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
    // Running total of good leaf pairs for the current countPairs() call.
    // countPairs() resets it first, so one Solution object can be reused.
    int result = 0;

    // Post-order helper. Returns, for every leaf in the subtree rooted at
    // `node`, its distance in edges to `node`'s parent, keeping only values
    // <= distance. Pairs whose paths meet at `node` are added to `result`.
    vector<int> dfs(TreeNode* node, int distance) {
        if (!node) return {};
        // A leaf is one edge away from its parent.
        if (!node->left && !node->right) return {1};
        vector<int> left = dfs(node->left, distance);
        vector<int> right = dfs(node->right, distance);
        // Every (left leaf, right leaf) pair meets at `node`; its path length is l + r.
        for (int l : left) {
            for (int r : right) {
                if (l + r <= distance) result++;
            }
        }
        // Move one edge up; leaves already farther than `distance` can never
        // form a good pair higher in the tree, so they are dropped.
        vector<int> res;
        res.reserve(left.size() + right.size());
        for (int n : left) if (n + 1 <= distance) res.push_back(n + 1);
        for (int n : right) if (n + 1 <= distance) res.push_back(n + 1);
        return res;
    }

    // Returns the number of unordered pairs of distinct leaves whose shortest
    // path is at most `distance` edges (a null root yields 0).
    int countPairs(TreeNode* root, int distance) {
        // Without this reset, a second call on the same object would add to
        // the previous call's total.
        result = 0;
        dfs(root, distance);
        return result;
    }
};
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
    // Running total of good leaf pairs for the current countPairs call.
    // countPairs resets it first, so one Solution instance can be reused.
    int result = 0;

    /**
     * Counts unordered pairs of distinct leaves whose shortest path is at most
     * {@code distance} edges. Each pair is counted once, at the lowest common
     * ancestor of its two leaves.
     *
     * @param root     root of the tree ({@code null} yields 0)
     * @param distance maximum allowed path length in edges
     * @return number of good leaf pairs
     */
    public int countPairs(TreeNode root, int distance) {
        // Without this reset, a second call on the same instance would add to
        // the previous call's total.
        result = 0;
        dfs(root, distance);
        return result;
    }

    /**
     * Post-order helper. Returns, for every leaf under {@code node}, its
     * distance in edges to {@code node}'s parent, keeping only values
     * {@code <= distance}. Pairs whose paths meet at {@code node} are added to
     * {@link #result}.
     */
    private List<Integer> dfs(TreeNode node, int distance) {
        if (node == null) return new ArrayList<>();
        if (node.left == null && node.right == null) {
            // A leaf is one edge away from its parent.
            List<Integer> leaves = new ArrayList<>();
            leaves.add(1);
            return leaves;
        }
        List<Integer> left = dfs(node.left, distance);
        List<Integer> right = dfs(node.right, distance);
        // Every (left leaf, right leaf) pair meets at node; its path length is l + r.
        for (int l : left) {
            for (int r : right) {
                if (l + r <= distance) result++;
            }
        }
        // Move one edge up; leaves already farther than distance can never
        // form a good pair higher in the tree, so they are dropped.
        List<Integer> res = new ArrayList<>(left.size() + right.size());
        for (int n : left) if (n + 1 <= distance) res.add(n + 1);
        for (int n : right) if (n + 1 <= distance) res.add(n + 1);
        return res;
    }
}
/**
* Definition for a binary tree node.
* function TreeNode(val, left, right) {
* this.val = (val===undefined ? 0 : val)
* this.left = (left===undefined ? null : left)
* this.right = (right===undefined ? null : right)
* }
*/
/**
 * Counts unordered pairs of distinct leaves whose shortest path is at most
 * `distance` edges. Each pair is counted once, where the two leaves' paths
 * first meet.
 *
 * @param {TreeNode} root
 * @param {number} distance
 * @return {number}
 */
var countPairs = function(root, distance) {
    let goodPairs = 0;

    // Post-order walk. For every leaf under `node`, returns its edge distance
    // to `node`'s parent, dropping anything beyond `distance`.
    const leafDepths = (node) => {
        if (!node) return [];
        const isLeaf = !node.left && !node.right;
        if (isLeaf) return [1];

        const fromLeft = leafDepths(node.left);
        const fromRight = leafDepths(node.right);

        // Each (left leaf, right leaf) pair meets at `node`; its path length is a + b.
        fromLeft.forEach((a) => {
            fromRight.forEach((b) => {
                if (a + b <= distance) goodPairs++;
            });
        });

        // One edge further up; too-distant leaves can never pair higher up.
        return [...fromLeft, ...fromRight]
            .map((d) => d + 1)
            .filter((d) => d <= distance);
    };

    leafDepths(root);
    return goodPairs;
};
You are given the `root` of a binary tree and an integer `distance`. A pair of leaf nodes is considered a "good leaf node pair" if the shortest path between them (in number of edges) is less than or equal to `distance`.
Your task is to count the number of good leaf node pairs in the tree. Each pair must consist of distinct leaves, and a pair is unordered: (A, B) and (B, A) are the same pair and are counted only once.
Constraints:
- distance ≤ 10

The first instinct might be to find all leaf nodes and compute the distance between every possible pair. However, this brute-force approach would be very slow for large trees, because the number of leaf pairs grows quadratically with the number of leaves.
Since the problem is about pairs of leaf nodes with a path length constraint, we can leverage the tree's structure. If we process the tree from the bottom up (post-order), we can, for each subtree, know how far each leaf is from the current node. This allows us to efficiently combine information from the left and right subtrees to count valid pairs without explicitly generating all possible pairs across the entire tree.
By focusing on distances from leaves to their ancestors, we can avoid redundant work and keep our solution efficient.
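We'll use a recursive post-order traversal (DFS) to solve the problem efficiently. For each node, the DFS returns a list with the distance from every leaf in its subtree to the node's parent. Here are the step-by-step details:
1. If the node is null, return an empty list.
2. If the node is a leaf, return `[1]`: the leaf is one edge away from its parent.
3. Recursively compute the lists for the left and right children.
4. For every pair (l, r), with l taken from the left list and r from the right list, the two leaves meet at the current node and are l + r edges apart. If l + r ≤ distance, increment the result counter.
5. Return every distance from both lists incremented by 1, keeping only those that are ≤ distance, since larger distances can't form valid pairs higher up.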
This approach ensures that each pair is counted exactly once and avoids unnecessary recomputation. By only tracking distances up to `distance`, we keep space and time usage manageable.
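Let's look at a simple example:
```
      1
     / \
    2   3
   /   / \
  4   5   6
```
Suppose distance = 3. The leaf nodes are 4, 5, and 6.
- Node 2 receives `[1]` from leaf 4 and nothing from its right side, so it counts no pairs and returns `[2]`.
- Node 3 receives `[1]` from leaf 5 and `[1]` from leaf 6. Since 1 + 1 = 2 ≤ 3, the pair (5, 6) is counted. Node 3 returns `[2, 2]`.
- Node 1 receives `[2]` from the left and `[2, 2]` from the right. Since 2 + 2 = 4 > 3, no further pairs are counted, so the answer is 1.

This shows how each pair is counted only at its lowest common ancestor, and why post-order traversal is effective.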
Brute-force approach: if we collect all leaves and check every pair, the time complexity is O(L²), where L is the number of leaves. For large trees, this can be very slow.
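Optimized approach (post-order DFS): every node is visited once. At each internal node we compare the distances from the left subtree with those from the right subtree. Only distances up to `distance` are kept, so the work per node is bounded by a constant that depends only on `distance` (and distance ≤ 10). The total time is therefore linear in the number of nodes, and the recursion uses stack space proportional to the tree height.

This problem is efficiently solved by leveraging the recursive structure of binary trees and focusing on distances from leaves to their ancestors. By using post-order traversal and only tracking relevant distances, we avoid the inefficiency of brute-force pairwise comparisons. The key insight is that each valid pair is counted exactly once at its lowest common ancestor, making the solution both elegant and efficient.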