Given an array of integers `nums` and two integers `low` and `high`, your task is to count the number of distinct pairs (i, j) (with i < j) such that the bitwise XOR of `nums[i]` and `nums[j]` falls within the inclusive range [low, high]. (Pairs (i, j) and (j, i) are considered the same; only i < j is counted.)

Constraints:
1 <= nums.length <= 2*10^4
0 <= nums[i] <= 2*10^4
0 <= low <= high <= 2*10^4
At first glance, the problem seems to require checking every possible pair in the array, calculating their XOR, and counting those that fall between `low` and `high`. This brute-force approach would involve nested loops, leading to O(n^2) time complexity, which is not efficient for large arrays.
To optimize, we must avoid redundant calculations and leverage data structures that allow us to quickly count how many previous numbers can form a valid XOR pair with the current number. This leads us to consider prefix structures, like a Trie, to represent the numbers in binary and efficiently count the number of valid pairs for each number.
The idea is to process the array from left to right, and for each number, count how many of the previously seen numbers would produce an XOR in the desired range. We then add the current number to the structure for future queries.
To efficiently count the number of valid pairs, we use a bitwise Trie (prefix tree):
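- Each Trie node stores how many inserted numbers pass through it (`count`).
- For each number, we count how many previously inserted numbers give an XOR less than or equal to `high`, and how many give an XOR less than `low`. To get the range, we count ≤ high and < low, then subtract. In other words, we query the Trie twice: once with `high`, once with `low - 1`.
- Iterate through `nums`:
  - For each `num`, count how many previous numbers have `num ^ prev` in [low, high] by querying the Trie.
  - Insert `num` into the Trie for future queries.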
This approach reduces the time complexity to O(n * log M), where M is the maximum value in `nums`, because the Trie depth is proportional to the number of bits.
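Example: nums = [1, 4, 2, 7], low = 2, high = 6

- Process 1: the Trie is empty, so there are no pairs yet. Insert 1.
- Process 4: check 4 ^ prev in [2, 6]. 4 ^ 1 = 5, which is in [2, 6] ⇒ 1 valid pair. Insert 4.
- Process 2: 2 ^ 1 = 3 (in range), 2 ^ 4 = 6 (in range) ⇒ 2 valid pairs. Insert 2.
- Process 7: 7 ^ 1 = 6 (in range), 7 ^ 4 = 3 (in range), 7 ^ 2 = 5 (in range) ⇒ 3 valid pairs. Insert 7.

Total: 1 + 2 + 3 = 6 valid pairs.

Time complexity:
- Brute force: O(n^2).
- Trie-based: O(n * log M), with M as the largest number in `nums`. Each number is processed once per bit, and the number of bits is around 15 for the constraints.

The optimized approach is efficient and scalable for the input constraints.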
In this problem, the naive approach of checking all pairs is infeasible for large arrays. By representing previous numbers in a binary Trie, we can efficiently count how many pairs form a valid XOR in the given range for each element. The key insight is that the Trie structure allows fast prefix-based queries, reducing the time complexity from O(n^2) to O(n log M). This makes the solution both elegant and practical for large input sizes.
class TrieNode:
    """Node of a binary trie.

    ``children`` maps a bit value (0 or 1) to the child node for that bit;
    a key is absent until the corresponding child is created.
    ``count`` is the number of inserted values whose bit path passes
    through this node.
    """

    def __init__(self):
        self.children, self.count = {}, 0
class Solution:
    """Counts index pairs whose XOR falls in an inclusive range, using a binary trie."""

    def countPairs(self, nums, low, high):
        """Return the number of pairs i < j with low <= nums[i] ^ nums[j] <= high.

        Each value is queried against the trie of earlier values and only then
        inserted, so every pair (i, j) with i < j is counted exactly once.

        Args:
            nums: non-negative integers; assumed to fit in BITS (15) bits.
            low: inclusive lower bound of the XOR range (>= 0).
            high: inclusive upper bound of the XOR range; high + 1 is assumed
                to fit in BITS bits.

        Returns:
            The number of qualifying pairs.
        """
        # 15 bits cover 0..32767, enough for nums[i] <= 2*10^4 and high + 1.
        BITS = 15

        def insert(num):
            """Add num to the trie, bumping the pass-through count of each node."""
            node = root
            for k in range(BITS - 1, -1, -1):
                bit = (num >> k) & 1
                if bit not in node.children:
                    node.children[bit] = TrieNode()
                node = node.children[bit]
                node.count += 1

        def count(num, limit):
            """Return how many stored values v satisfy (num ^ v) < limit (strict)."""
            if limit <= 0:
                # No XOR result is negative; also avoids reading the bits of -1,
                # which are all ones under Python's infinite two's complement.
                return 0
            node = root
            res = 0
            for k in range(BITS - 1, -1, -1):
                if node is None:
                    break
                bit_num = (num >> k) & 1
                if (limit >> k) & 1:
                    # The child equal to bit_num gives XOR bit 0 where limit has 1,
                    # so that whole subtree is below limit.
                    same = node.children.get(bit_num)
                    if same is not None:
                        res += same.count
                    # Keep tracking the prefix equal to limit (XOR bit 1).
                    node = node.children.get(bit_num ^ 1)
                else:
                    # XOR bit must be 0 to stay on limit's prefix.
                    node = node.children.get(bit_num)
            # The path equal to limit itself is never added: strictly less-than.
            return res

        root = TrieNode()
        ans = 0
        for num in nums:
            # #{low <= x <= high} = #{x < high + 1} - #{x < low}
            ans += count(num, high + 1) - count(num, low)
            insert(num)
        return ans
/// Node of a binary trie over fixed-width non-negative integers.
/// Each node owns its children, so destroying the root frees the whole trie.
class TrieNode {
public:
    /// children[b] is the subtree for next bit b (0 or 1); empty when absent.
    std::unique_ptr<TrieNode> children[2];
    /// Number of inserted values whose bit path passes through this node.
    int count = 0;
};

/// Counts index pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.
class Solution {
public:
    /// Bits examined per value, most significant first. 15 bits cover
    /// 0..32767, which includes nums[i] <= 2*10^4 and the query bound high + 1.
    static constexpr int kBits = 15;

    /// Inserts num into the trie rooted at root (root must be non-null),
    /// incrementing the pass-through count of every node on its path.
    void insert(TrieNode* root, int num) {
        TrieNode* node = root;
        for (int k = kBits - 1; k >= 0; --k) {
            const int bit = (num >> k) & 1;
            if (!node->children[bit]) node->children[bit] = std::make_unique<TrieNode>();
            node = node->children[bit].get();
            ++node->count;
        }
    }

    /// Returns how many values v stored under root satisfy (num ^ v) < limit.
    /// The comparison is strict: values whose XOR equals limit are not counted.
    int count(TrieNode* root, int num, int limit) {
        if (limit <= 0) return 0;  // no XOR result is negative
        TrieNode* node = root;
        int res = 0;
        for (int k = kBits - 1; k >= 0 && node != nullptr; --k) {
            const int bit_num = (num >> k) & 1;
            if ((limit >> k) & 1) {
                // The child equal to bit_num gives XOR bit 0 where limit has 1,
                // so that whole subtree is below limit.
                if (const TrieNode* same = node->children[bit_num].get()) res += same->count;
                // Keep tracking the prefix equal to limit (XOR bit 1).
                node = node->children[bit_num ^ 1].get();
            } else {
                // XOR bit must be 0 to stay on limit's prefix.
                node = node->children[bit_num].get();
            }
        }
        return res;
    }

    /// Returns the number of pairs i < j with low <= nums[i] ^ nums[j] <= high.
    /// Assumes 0 <= low and that nums[i] and high + 1 fit in kBits bits.
    int countPairs(std::vector<int>& nums, int low, int high) {
        auto root = std::make_unique<TrieNode>();  // owns and frees the whole trie
        int ans = 0;
        for (int num : nums) {
            // Only earlier elements are in the trie, so each pair is counted once.
            // #{low <= x <= high} = #{x < high + 1} - #{x < low}
            ans += count(root.get(), num, high + 1) - count(root.get(), num, low);
            insert(root.get(), num);
        }
        return ans;
    }
};
class TrieNode {
    /** Child links indexed by the next bit (0 or 1); null when absent. */
    TrieNode[] children;
    /** Number of inserted values whose bit path passes through this node. */
    int count;

    TrieNode() {
        children = new TrieNode[2];
        count = 0;
    }
}
/** Counts index pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high. */
class Solution {
    /**
     * Bits examined per value, most significant first. 15 bits cover 0..32767,
     * which includes nums[i] <= 2*10^4 and the query bound high + 1.
     */
    private static final int BITS = 15;

    /** Inserts {@code num} under {@code root}, bumping the count of every node on its path. */
    void insert(TrieNode root, int num) {
        TrieNode node = root;
        for (int k = BITS - 1; k >= 0; --k) {
            int bit = (num >> k) & 1;
            if (node.children[bit] == null)
                node.children[bit] = new TrieNode();
            node = node.children[bit];
            node.count++;
        }
    }

    /**
     * Returns how many values v stored under {@code root} satisfy {@code (num ^ v) < limit}.
     * The comparison is strict: values whose XOR equals {@code limit} are not counted.
     */
    int count(TrieNode root, int num, int limit) {
        // No XOR result is negative; also avoids -1, whose bits are all ones under >>.
        if (limit <= 0) return 0;
        TrieNode node = root;
        int res = 0;
        for (int k = BITS - 1; k >= 0 && node != null; --k) {
            int bitNum = (num >> k) & 1;
            if (((limit >> k) & 1) == 1) {
                // The child equal to bitNum gives XOR bit 0 where limit has 1,
                // so that whole subtree is below limit.
                TrieNode same = node.children[bitNum];
                if (same != null) res += same.count;
                // Keep tracking the prefix equal to limit (XOR bit 1).
                node = node.children[bitNum ^ 1];
            } else {
                // XOR bit must be 0 to stay on limit's prefix.
                node = node.children[bitNum];
            }
        }
        return res;
    }

    /**
     * Returns the number of pairs i < j with low <= nums[i] ^ nums[j] <= high.
     * Assumes 0 <= low and that nums[i] and high + 1 fit in {@link #BITS} bits.
     */
    public int countPairs(int[] nums, int low, int high) {
        TrieNode root = new TrieNode();
        int ans = 0;
        for (int num : nums) {
            // Only earlier elements are in the trie, so each pair is counted once.
            // #{low <= x <= high} = #{x < high + 1} - #{x < low}
            ans += count(root, num, high + 1) - count(root, num, low);
            insert(root, num);
        }
        return ans;
    }
}
/**
 * Node of a binary trie.
 * `children` maps a bit value (0 or 1) to the child node; keys are absent
 * until created. `count` is the number of inserted values passing through.
 */
class TrieNode {
    constructor() {
        Object.assign(this, { children: {}, count: 0 });
    }
}
/**
 * Counts index pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.
 * Each value is queried against a binary trie of earlier values and then
 * inserted, so every pair is counted exactly once.
 * Assumes 0 <= low and that nums[i] and high + 1 fit in 15 bits.
 *
 * @param {number[]} nums non-negative integers
 * @param {number} low inclusive lower bound of the XOR range
 * @param {number} high inclusive upper bound of the XOR range
 * @return {number} number of qualifying pairs
 */
var countPairs = function(nums, low, high) {
    // 15 bits cover 0..32767, enough for nums[i] <= 2*10^4 and high + 1.
    const BITS = 15;
    const root = new TrieNode();

    // Adds num to the trie, bumping the pass-through count of each node.
    function insert(num) {
        let node = root;
        for (let k = BITS - 1; k >= 0; --k) {
            const bit = (num >> k) & 1;
            if (!(bit in node.children)) node.children[bit] = new TrieNode();
            node = node.children[bit];
            node.count += 1;
        }
    }

    // Returns how many stored values v satisfy (num ^ v) < limit (strict).
    function count(num, limit) {
        // No XOR result is negative; also avoids -1, whose bits are all ones under >>.
        if (limit <= 0) return 0;
        let node = root;
        let res = 0;
        for (let k = BITS - 1; k >= 0 && node; --k) {
            const bitNum = (num >> k) & 1;
            if ((limit >> k) & 1) {
                // The child equal to bitNum gives XOR bit 0 where limit has 1,
                // so that whole subtree is below limit.
                const same = node.children[bitNum];
                if (same) res += same.count;
                // Keep tracking the prefix equal to limit (XOR bit 1).
                node = node.children[bitNum ^ 1];
            } else {
                // XOR bit must be 0 to stay on limit's prefix.
                node = node.children[bitNum];
            }
        }
        return res;
    }

    let ans = 0;
    for (const num of nums) {
        // #{low <= x <= high} = #{x < high + 1} - #{x < low}
        ans += count(num, high + 1) - count(num, low);
        insert(num);
    }
    return ans;
};