
1803. Count Pairs With XOR in a Range - Leetcode Solution

Problem Description

Given an array of integers nums and two integers low and high, count the number of index pairs (i, j) with 0 <= i < j < nums.length such that the bitwise XOR nums[i] ^ nums[j] lies in the inclusive range [low, high].

  • Each pair uses two different indices (no element is paired with itself), although the two values at those indices may be equal.
  • Order does not matter: (i, j) and (j, i) are the same pair, so only i < j is counted.
  • Return the total number of valid pairs.

Constraints:

  • 1 <= nums.length <= 2*10^4
  • 0 <= nums[i] <= 2*10^4
  • 0 <= low <= high <= 2*10^4

Thought Process

At first glance, the problem seems to require checking every pair in the array, computing its XOR, and counting those that fall between low and high. This brute-force approach uses two nested loops, giving O(n^2) time. With n up to 2*10^4, that is roughly 2*10^8 XOR checks in the worst case, which does not scale well.
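For reference, here is the brute-force baseline written out in full. It is a straightforward Python sketch. The function name `count_pairs_brute_force` is illustrative and not part of the problem's interface. It also works as a correctness oracle when testing a faster solution on small inputs.

```python
from typing import List


def count_pairs_brute_force(nums: List[int], low: int, high: int) -> int:
    """Check every index pair i < j directly: O(n^2) time, O(1) extra space."""
    total = 0
    n = len(nums)
    for i in range(n):
        for j in range(i + 1, n):
            if low <= (nums[i] ^ nums[j]) <= high:
                total += 1
    return total


print(count_pairs_brute_force([1, 4, 2, 7], 2, 6))  # 6
```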

To optimize, we must avoid redundant calculations and leverage data structures that allow us to quickly count how many previous numbers can form a valid XOR pair with the current number. This leads us to consider prefix structures, like a Trie, to represent the numbers in binary and efficiently count the number of valid pairs for each number.

The idea is to process the array from left to right, and for each number, count how many of the previously seen numbers would produce an XOR in the desired range. We then add the current number to the structure for future queries.

Solution Approach

To efficiently count the number of valid pairs, we use a bitwise Trie (prefix tree):

  1. Trie Structure:
    • Each node represents one bit (0 or 1) at a given position. A path from the root spells out a number's binary representation, starting from the most significant bit.
    • Each node keeps track of how many numbers have passed through it (count).
  2. Counting Valid XORs:
    • The Trie query counts how many previous numbers prev satisfy num ^ prev < limit (strictly less). The value num ^ prev lies in [low, high] exactly when it is < high + 1 but not < low. So the number of valid partners is count(high + 1) - count(low).
    • For each number, we therefore query the Trie twice: once with limit high + 1 and once with limit low.
  3. Processing Steps:
    • Initialize an empty Trie.
    • Iterate through nums:
      • For each num, count how many previous numbers have num ^ prev in [low, high] by querying the Trie.
      • Add num to the Trie for future queries.
    • Sum up the counts for all numbers.
  4. Why Trie?
    • Because each number can be represented in binary, a Trie allows us to efficiently count how many numbers have a certain bit pattern up to a limit, enabling fast range queries for XOR values.

This approach reduces the time complexity to O(n * log M), where M is the maximum value in nums (since the Trie depth is proportional to the number of bits).
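Here is a minimal sketch of this structure, assuming 15-bit values (enough for the stated constraints). It stores the Trie in two flat lists rather than in node objects. The query is named `count_less` to make its half-open meaning explicit: it returns how many stored values v satisfy num ^ v < limit. The names `BinaryTrie`, `count_less` and `count_pairs_trie` are illustrative. The range [low, high] is then obtained as count_less(num, high + 1) - count_less(num, low).

```python
from typing import List

BITS = 15  # assumption: every value, and high + 1, is below 2**15


class BinaryTrie:
    """Bitwise trie over BITS-bit integers, stored in flat lists.

    Node 0 is the root. No edge ever points back to the root,
    so a child index of 0 doubles as "no child".
    """

    def __init__(self) -> None:
        self.children: List[List[int]] = [[0, 0]]
        self.counts: List[int] = [0]  # how many stored values pass through each node

    def insert(self, num: int) -> None:
        node = 0
        for k in range(BITS - 1, -1, -1):
            bit = (num >> k) & 1
            if self.children[node][bit] == 0:
                self.children.append([0, 0])
                self.counts.append(0)
                self.children[node][bit] = len(self.children) - 1
            node = self.children[node][bit]
            self.counts[node] += 1

    def count_less(self, num: int, limit: int) -> int:
        """Return how many stored values v satisfy (num ^ v) < limit."""
        node, res = 0, 0
        for k in range(BITS - 1, -1, -1):
            bit_num = (num >> k) & 1
            if (limit >> k) & 1:
                # XOR bit 0 here (v's bit == bit_num) is already below limit.
                below = self.children[node][bit_num]
                if below:
                    res += self.counts[below]
                # XOR bit 1 keeps v tied with limit, so keep descending.
                node = self.children[node][bit_num ^ 1]
            else:
                # Only XOR bit 0 can stay tied with limit.
                node = self.children[node][bit_num]
            if node == 0:  # no stored value is still tied with limit
                break
        return res


def count_pairs_trie(nums: List[int], low: int, high: int) -> int:
    trie = BinaryTrie()
    total = 0
    for num in nums:
        # low <= num ^ v <= high  <=>  num ^ v < high + 1  and not  num ^ v < low
        total += trie.count_less(num, high + 1) - trie.count_less(num, low)
        trie.insert(num)  # insert after querying so a number never pairs with itself
    return total


print(count_pairs_trie([1, 4, 2, 7], 2, 6))  # 6
```

Values that tie with limit on every bit (XOR exactly equal to limit) are never added. This is what makes the query strictly "less than", and why the upper limit is passed as high + 1.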

Example Walkthrough

Example: nums = [1, 4, 2, 7], low = 2, high = 6

  1. Initialize Trie as empty.
  2. Process 1:
    • No previous numbers, so no pairs.
    • Add 1 to Trie.
  3. Process 4:
    • Check 1: 4 ^ 1 = 5 (in range).
    • 1 valid pair.
    • Add 4 to Trie.
  4. Process 2:
    • Check 1: 2 ^ 1 = 3 (in range).
    • Check 4: 2 ^ 4 = 6 (in range).
    • 2 valid pairs.
    • Add 2 to Trie.
  5. Process 7:
    • Check 1: 7 ^ 1 = 6 (in range).
    • Check 4: 7 ^ 4 = 3 (in range).
    • Check 2: 7 ^ 2 = 5 (in range).
    • 3 valid pairs.
    • Add 7 to Trie.
  6. Total pairs: 0 (for 1) + 1 (for 4) + 2 (for 2) + 3 (for 7) = 6
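The same walkthrough can be expressed through the two prefix counts the Trie answers. These are the number of previous values giving an XOR below high + 1 = 7, and the number giving an XOR below low = 2. To keep the sketch short, a linear scan over the previous values stands in for the Trie query. It returns the same numbers, just more slowly. The function name `trace_walkthrough` is illustrative. In this example no XOR falls below 2, so the second count is always 0.

```python
from typing import List, Tuple


def trace_walkthrough(nums: List[int], low: int, high: int) -> List[Tuple[int, int, int]]:
    """Return (num, #prev with XOR < high + 1, #prev with XOR < low) for each step.

    A linear scan over the previous values stands in for the Trie query here;
    it yields the same counts, only in O(n) per query instead of O(log M).
    """
    prev: List[int] = []
    rows = []
    for num in nums:
        below_high = sum(1 for v in prev if (num ^ v) < high + 1)
        below_low = sum(1 for v in prev if (num ^ v) < low)
        rows.append((num, below_high, below_low))
        prev.append(num)  # add after querying, so a number never pairs with itself
    return rows


rows = trace_walkthrough([1, 4, 2, 7], low=2, high=6)
for num, below_high, below_low in rows:
    print(num, below_high, below_low, below_high - below_low)
print("total:", sum(h - l for _, h, l in rows))
# Output:
# 1 0 0 0
# 4 1 0 1
# 2 2 0 2
# 7 3 0 3
# total: 6
```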

Time and Space Complexity

  • Brute force: O(n^2) time, O(1) extra space (two nested loops; does not scale to large n).
  • Optimized Trie approach:
    • Time: O(n * log M), where M is the largest value in nums. Each insert and each of the two queries walks one Trie level per bit, and 15 bits suffice because 2*10^4 < 2^15 = 32768.
    • Space: O(n * log M), since each inserted number adds at most log M new nodes (at most 15 here).

The optimized approach is efficient and scalable for the input constraints.
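To put numbers on these bounds, the short calculation below plugs in the worst case allowed by the constraints (n = 2*10^4 and values up to 2*10^4). It assumes one step per bit for each insert and for each of the two queries. This is a rough cost model, not a measured benchmark.

```python
MAX_N = 2 * 10**4       # upper bound on nums.length (from the constraints)
MAX_VALUE = 2 * 10**4   # upper bound on nums[i], low and high

# The largest query limit is high + 1, so it must fit in the chosen width too.
bits = (MAX_VALUE + 1).bit_length()

pair_checks = MAX_N * (MAX_N - 1) // 2   # brute force: one XOR per pair i < j
trie_steps = MAX_N * 3 * bits            # per number: 2 queries + 1 insert, one level per bit
max_trie_nodes = 1 + MAX_N * bits        # root + at most `bits` new nodes per insert

print(bits)            # 15
print(pair_checks)     # 199990000
print(trie_steps)      # 900000
print(max_trie_nodes)  # 300001
```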

Summary

In this problem, the naive approach of checking all pairs is infeasible for large arrays. By representing previous numbers in a binary Trie, we can efficiently count how many pairs form a valid XOR in the given range for each element. The key insight is that the Trie structure allows us to perform fast prefix-based queries, reducing the time complexity from O(n^2) to O(n log M). This makes the solution both elegant and practical for large input sizes.

Code Implementation

class TrieNode:
    def __init__(self):
        self.children = {}  # bit (0 or 1) -> TrieNode
        self.count = 0      # how many inserted numbers pass through this node

class Solution:
    def countPairs(self, nums, low, high):
        # 15 bits (positions 14..0) cover every value up to 2*10^4, and high + 1.
        def insert(num):
            node = root
            for k in range(14, -1, -1):
                bit = (num >> k) & 1
                if bit not in node.children:
                    node.children[bit] = TrieNode()
                node = node.children[bit]
                node.count += 1

        def count(num, limit):
            # Number of inserted values v with (num ^ v) strictly less than limit.
            node = root
            res = 0
            for k in range(14, -1, -1):
                if node is None:
                    break
                bit_num = (num >> k) & 1
                bit_limit = (limit >> k) & 1
                if bit_limit == 1:
                    # Children whose XOR bit is 0 (same bit as num) are already
                    # below limit, so add their whole subtree count.
                    if bit_num in node.children:
                        res += node.children[bit_num].count
                    # Continue with XOR bit 1, still tied with limit.
                    node = node.children.get(bit_num ^ 1)
                else:
                    # XOR bit must be 0 to stay tied with limit.
                    node = node.children.get(bit_num)
            return res

        root = TrieNode()
        ans = 0
        for num in nums:
            # low <= XOR <= high  <=>  XOR < high + 1  and not  XOR < low
            ans += count(num, high + 1) - count(num, low)
            insert(num)
        return ans
      
#include <vector>
using std::vector;

class TrieNode {
public:
    TrieNode* children[2];
    int count;  // how many inserted numbers pass through this node
    TrieNode() {
        children[0] = children[1] = nullptr;
        count = 0;
    }
};

class Solution {
public:
    void insert(TrieNode* root, int num) {
        TrieNode* node = root;
        for (int k = 14; k >= 0; --k) {
            int bit = (num >> k) & 1;
            if (!node->children[bit]) node->children[bit] = new TrieNode();
            node = node->children[bit];
            node->count++;
        }
    }

    // Number of inserted values v with (num ^ v) strictly less than limit.
    int count(TrieNode* root, int num, int limit) {
        TrieNode* node = root;
        int res = 0;
        for (int k = 14; k >= 0; --k) {
            if (!node) break;
            int bit_num = (num >> k) & 1;
            int bit_limit = (limit >> k) & 1;
            if (bit_limit == 1) {
                // XOR bit 0 here is already below limit: count the whole subtree.
                if (node->children[bit_num])
                    res += node->children[bit_num]->count;
                node = node->children[bit_num ^ 1];
            } else {
                node = node->children[bit_num];
            }
        }
        return res;
    }

    int countPairs(vector<int>& nums, int low, int high) {
        TrieNode* root = new TrieNode();
        int ans = 0;
        for (int num : nums) {
            // XOR in [low, high] = (XOR < high + 1) - (XOR < low)
            ans += count(root, num, high + 1) - count(root, num, low);
            insert(root, num);
        }
        return ans;
    }
};
      
class TrieNode {
    TrieNode[] children = new TrieNode[2];
    int count = 0;
}

class Solution {
    void insert(TrieNode root, int num) {
        TrieNode node = root;
        for (int k = 14; k >= 0; --k) {
            int bit = (num >> k) & 1;
            if (node.children[bit] == null)
                node.children[bit] = new TrieNode();
            node = node.children[bit];
            node.count++;
        }
    }

    // Number of inserted values v with (num ^ v) strictly less than limit.
    int count(TrieNode root, int num, int limit) {
        TrieNode node = root;
        int res = 0;
        for (int k = 14; k >= 0; --k) {
            if (node == null) break;
            int bit_num = (num >> k) & 1;
            int bit_limit = (limit >> k) & 1;
            if (bit_limit == 1) {
                // XOR bit 0 here is already below limit: count the whole subtree.
                if (node.children[bit_num] != null)
                    res += node.children[bit_num].count;
                node = node.children[bit_num ^ 1];
            } else {
                node = node.children[bit_num];
            }
        }
        return res;
    }

    public int countPairs(int[] nums, int low, int high) {
        TrieNode root = new TrieNode();
        int ans = 0;
        for (int num : nums) {
            // XOR in [low, high] = (XOR < high + 1) - (XOR < low)
            ans += count(root, num, high + 1) - count(root, num, low);
            insert(root, num);
        }
        return ans;
    }
}
      
class TrieNode {
    constructor() {
        this.children = {};
        this.count = 0;
    }
}

var countPairs = function(nums, low, high) {
    let root = new TrieNode();

    function insert(num) {
        let node = root;
        for (let k = 14; k >= 0; --k) {
            let bit = (num >> k) & 1;
            if (!(bit in node.children)) node.children[bit] = new TrieNode();
            node = node.children[bit];
            node.count += 1;
        }
    }

    // Number of inserted values v with (num ^ v) strictly less than limit.
    function count(num, limit) {
        let node = root;
        let res = 0;
        for (let k = 14; k >= 0; --k) {
            if (!node) break;
            let bit_num = (num >> k) & 1;
            let bit_limit = (limit >> k) & 1;
            if (bit_limit === 1) {
                // XOR bit 0 here is already below limit: count the whole subtree.
                if (bit_num in node.children)
                    res += node.children[bit_num].count;
                node = node.children[bit_num ^ 1];
            } else {
                node = node.children[bit_num];
            }
        }
        return res;
    }

    let ans = 0;
    for (let num of nums) {
        // XOR in [low, high] = (XOR < high + 1) - (XOR < low)
        ans += count(num, high + 1) - count(num, low);
        insert(num);
    }
    return ans;
};