315. Count of Smaller Numbers After Self - Leetcode Solution

Problem Description

Given an integer array nums, your task is to return a new array result where result[i] is the number of smaller elements to the right of nums[i] in the original array.

  • Each element in result should represent the count of elements that are strictly less than nums[i] and appear after index i.
  • You may assume all elements in nums are integers (positive, negative, or zero).
  • There is only one valid output for each input.
  • Elements at or before index i are never counted, regardless of their value.

Example:
Input: nums = [5,2,6,1]
Output: [2,1,1,0]

Thought Process

At first glance, the problem invites a brute-force approach: for each element, scan all elements to its right and count how many are smaller. This is intuitive and easy to code, but it is inefficient for large arrays because it requires nested loops.
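As a baseline, the brute-force idea can be sketched as follows. This is a minimal illustration, not part of the final solution, and the function name `count_smaller_brute_force` is chosen here for the example.

```python
from typing import List


def count_smaller_brute_force(nums: List[int]) -> List[int]:
    """Quadratic baseline: for each index, scan everything to its right."""
    result = []
    for i, current in enumerate(nums):
        # Count elements after index i that are strictly smaller than nums[i].
        smaller = sum(1 for later in nums[i + 1:] if later < current)
        result.append(smaller)
    return result


print(count_smaller_brute_force([5, 2, 6, 1]))  # [2, 1, 1, 0]
```

Because the comparison is strict, equal values to the right are not counted: `[2, 2, 1]` gives `[1, 1, 0]`.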

The challenge is to optimize this process. Instead of comparing each element to all those after it, we can try to keep track of the numbers we've already seen (from right to left) in a way that allows us to quickly count how many are less than the current number. This suggests using a data structure that supports fast insertions and fast rank queries, such as a Binary Indexed Tree (Fenwick Tree), a Segment Tree, or a Binary Search Tree.

By thinking about the problem in reverse (processing the array from the end to the start), we can build up the set of "seen" numbers and efficiently answer the question for each index as we go.
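To make the right-to-left idea concrete before bringing in a tree, here is a sketch that keeps the "seen" numbers in a plain sorted list and uses Python's `bisect` module for the rank query. It is a stand-in for illustration only: the list insertion is O(n), so this version is still quadratic in the worst case, which is the cost the tree-based structures avoid. The function name is an assumption made for this example.

```python
from bisect import bisect_left, insort
from typing import List


def count_smaller_sorted_list(nums: List[int]) -> List[int]:
    seen: List[int] = []  # values to the right of the current index, kept sorted
    result = [0] * len(nums)
    for i in range(len(nums) - 1, -1, -1):
        # bisect_left returns how many seen values are strictly less than nums[i].
        result[i] = bisect_left(seen, nums[i])
        # Insert nums[i] so later (leftward) queries can count it.
        insort(seen, nums[i])
    return result


print(count_smaller_sorted_list([5, 2, 6, 1]))  # [2, 1, 1, 0]
```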

Solution Approach

We'll use a Binary Indexed Tree (Fenwick Tree) to efficiently count the number of smaller elements for each index as we process the array from right to left.

  1. Coordinate Compression: Since the numbers can be large or negative, we first map each distinct number in nums to its rank in sorted order, giving values from 1 to k, where k is the number of distinct values. This is called coordinate compression; it lets us use the ranks directly as Binary Indexed Tree indices.
  2. Initialize BIT: We create a Binary Indexed Tree (array) large enough to cover all unique elements.
  3. Process in Reverse: Start from the last element and move to the first. For each element:
    • Query the BIT for the count of numbers less than the current number (i.e., all numbers with compressed index less than the current one).
    • Store this count in the result array at the current index.
    • Update the BIT to include the current number (so it is counted for subsequent queries).
  4. Return Result: After processing all elements, return the result array.

This approach is efficient because each update and query on the Binary Indexed Tree takes O(log n) time, and coordinate compression bounds the tree size by the number of distinct values (at most n).
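The coordinate compression step can be sketched in isolation. The helper name `compress` is hypothetical and used only for this example; the mapping it builds has the same shape as the `rank` dictionary in the full solution.

```python
from typing import Dict, List


def compress(nums: List[int]) -> Dict[int, int]:
    """Map each distinct value to its 1-based rank in sorted order."""
    return {value: rank for rank, value in enumerate(sorted(set(nums)), start=1)}


ranks = compress([-1_000_000_000, 7, 7, 42])
print(ranks)  # {-1000000000: 1, 7: 2, 42: 3}
```

Duplicates share a rank, and the relative order of values is preserved, so "strictly smaller value" and "strictly smaller rank" mean the same thing.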

Example Walkthrough

Let's walk through the example nums = [5,2,6,1] step by step:

  1. Coordinate Compression:
    Unique sorted numbers: [1,2,5,6]
    Mapping: 1→1, 2→2, 5→3, 6→4
  2. Initialize BIT: Size = 4 (number of unique numbers)
  3. Process from right to left:
    • i=3 (nums[3]=1, rank 1):
      Query BIT for ranks below 1 (prefix sum up to index 0): 0
      Update BIT: add 1 at index 1.
    • i=2 (nums[2]=6, rank 4):
      Query BIT for ranks below 4 (prefix sum up to index 3): 1 (only the value 1 is to the right and smaller)
      Update BIT: add 1 at index 4.
    • i=1 (nums[1]=2, rank 2):
      Query BIT for ranks below 2 (prefix sum up to index 1): 1 (only the value 1 is to the right and smaller)
      Update BIT: add 1 at index 2.
    • i=0 (nums[0]=5, rank 3):
      Query BIT for ranks below 3 (prefix sum up to index 2): 2 (the values 1 and 2 are to the right and smaller)
      Update BIT: add 1 at index 3.

    Final result: [2,1,1,0]
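The walkthrough can be reproduced with a small tracing script. This is a sketch that inlines the tree operations so each step can be printed; the function name `trace_count_smaller` and the printed format are choices made for this example.

```python
from typing import List


def trace_count_smaller(nums: List[int]) -> List[int]:
    rank = {v: r for r, v in enumerate(sorted(set(nums)), start=1)}
    tree = [0] * (len(rank) + 1)  # 1-based Fenwick tree; index 0 is unused
    result = [0] * len(nums)
    for i in range(len(nums) - 1, -1, -1):
        r = rank[nums[i]]
        # Prefix sum over ranks 1..r-1: seen values that are strictly smaller.
        j, count = r - 1, 0
        while j > 0:
            count += tree[j]
            j -= j & -j
        result[i] = count
        # Record the current value at rank r for later queries.
        j = r
        while j < len(tree):
            tree[j] += 1
            j += j & -j
        print(f"i={i} value={nums[i]} rank={r} smaller_to_right={count}")
    return result


print(trace_count_smaller([5, 2, 6, 1]))
```

For the example input this prints the counts 0, 1, 1, 2 for i=3 down to i=0, followed by the final list `[2, 1, 1, 0]`.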

Time and Space Complexity

  • Brute-force approach: For each element, compare with all elements to its right.
    Time complexity: O(n^2)
    Space complexity: O(n) for the output array.
  • Optimized approach (using BIT):
    • Coordinate compression: O(n log n) due to sorting
    • Each BIT update/query: O(log n)
    • Total: O(n log n) time
    • Space complexity: O(n) for the BIT and output array
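One way to gain confidence that the O(n log n) approach returns the same answers as the quadratic baseline is a randomized cross-check. The sketch below is self-contained; the names `brute_force`, `with_bit`, and `cross_check`, and the chosen value and length ranges, are assumptions made for this example.

```python
import random
from typing import List


def brute_force(nums: List[int]) -> List[int]:
    # O(n^2): count strictly smaller elements to the right of each index.
    return [sum(later < current for later in nums[i + 1:])
            for i, current in enumerate(nums)]


def with_bit(nums: List[int]) -> List[int]:
    # O(n log n): coordinate compression plus a 1-based Fenwick tree.
    rank = {v: r for r, v in enumerate(sorted(set(nums)), start=1)}
    tree = [0] * (len(rank) + 1)
    result = [0] * len(nums)
    for i in range(len(nums) - 1, -1, -1):
        j = rank[nums[i]] - 1
        while j > 0:  # prefix sum over ranks below the current one
            result[i] += tree[j]
            j -= j & -j
        j = rank[nums[i]]
        while j < len(tree):  # record the current value
            tree[j] += 1
            j += j & -j
    return result


def cross_check(trials: int = 200, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(-50, 50) for _ in range(rng.randint(0, 30))]
        if brute_force(nums) != with_bit(nums):
            return False
    return True


print(cross_check())  # True
```

The small value range is deliberate: it forces many duplicates, which exercises the strict "less than" handling.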

Summary

The key insight is to process the array from right to left, maintaining a data structure that allows us to quickly count how many smaller numbers we've seen so far. By using coordinate compression and a Binary Indexed Tree, we can efficiently answer the question for each element in O(log n) time, leading to a fast and scalable solution. This approach elegantly transforms a naive quadratic problem into one that can handle large inputs with ease.

Code Implementation

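Python:

class BIT:
    """Binary Indexed Tree (Fenwick Tree) over 1-based indices."""
    def __init__(self, size):
        self.tree = [0] * (size + 2)
    def update(self, i, delta):
        # Add delta at index i, then climb to each node whose range covers i.
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i  # i & -i isolates the lowest set bit
    def query(self, i):
        # Prefix sum over indices 1..i.
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s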

class Solution:
    def countSmaller(self, nums):
        # Coordinate compression
        rank = {num: i+1 for i, num in enumerate(sorted(set(nums)))}
        bit = BIT(len(rank))
        res = []
        for num in reversed(nums):
            idx = rank[num]
            res.append(bit.query(idx - 1))
            bit.update(idx, 1)
        return res[::-1]
      
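C++:

#include <set>
#include <unordered_map>
#include <vector>
using namespace std;

class BIT {
public:
    vector<int> tree;
    BIT(int n) : tree(n + 2, 0) {}
    // Add delta at 1-based index i.
    void update(int i, int delta) {
        while (i < static_cast<int>(tree.size())) {
            tree[i] += delta;
            i += i & -i;
        }
    }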
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += tree[i];
            i -= i & -i;
        }
        return s;
    }
};

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
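        // Coordinate compression: std::set iterates in ascending order.
        set<int> s(nums.begin(), nums.end());
        unordered_map<int, int> rank;
        int idx = 1;
        for (int num : s) {
            rank[num] = idx++;
        }
        BIT bit(rank.size());
        vector<int> res(nums.size());
        // Cast before subtracting so an empty input yields -1, not an unsigned wraparound.
        for (int i = static_cast<int>(nums.size()) - 1; i >= 0; --i) {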
            int r = rank[nums[i]];
            res[i] = bit.query(r - 1);
            bit.update(r, 1);
        }
        return res;
    }
};
      
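Java:

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

class BIT {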
    int[] tree;
    public BIT(int n) {
        tree = new int[n + 2];
    }
    void update(int i, int delta) {
        while (i < tree.length) {
            tree[i] += delta;
            i += i & -i;
        }
    }
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += tree[i];
            i -= i & -i;
        }
        return s;
    }
}

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        Set<Integer> set = new TreeSet<>();
        for (int num : nums) set.add(num);
        Map<Integer, Integer> rank = new HashMap<>();
        int idx = 1;
        for (int num : set) rank.put(num, idx++);
        BIT bit = new BIT(rank.size());
        LinkedList<Integer> res = new LinkedList<>();
        for (int i = nums.length - 1; i >= 0; --i) {
            int r = rank.get(nums[i]);
            res.addFirst(bit.query(r - 1));
            bit.update(r, 1);
        }
        return res;
    }
}
      
JavaScript:

class BIT {
    constructor(size) {
        this.tree = Array(size + 2).fill(0);
    }
    update(i, delta) {
        while (i < this.tree.length) {
            this.tree[i] += delta;
            i += i & -i;
        }
    }
    query(i) {
        let s = 0;
        while (i > 0) {
            s += this.tree[i];
            i -= i & -i;
        }
        return s;
    }
}

var countSmaller = function(nums) {
    let sorted = Array.from(new Set(nums)).sort((a,b) => a-b);
    let rank = new Map();
    for (let i = 0; i < sorted.length; ++i) rank.set(sorted[i], i+1);
    let bit = new BIT(sorted.length);
    let res = [];
    for (let i = nums.length - 1; i >= 0; --i) {
        let idx = rank.get(nums[i]);
        res.push(bit.query(idx - 1));
        bit.update(idx, 1);
    }
    return res.reverse();
};