Given an integer array nums, your task is to return a new array result where result[i] is the number of smaller elements to the right of nums[i] in the original array. That is, result[i] should be the count of elements that are strictly less than nums[i] and appear after index i. The elements of nums are integers (positive, negative, or zero).
Example:
Input: nums = [5,2,6,1]
Output: [2,1,1,0]
At first glance, the problem invites a brute-force approach: for each element, scan all elements to its right and count how many are smaller. This is intuitive and easy to code, but its nested loops perform O(n²) comparisons, which is too slow for large arrays.
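For reference, the brute-force approach can be sketched in Python as below. The function name count_smaller_brute is hypothetical; the sketch is useful mainly as a correctness baseline for faster versions.

```python
from typing import List


def count_smaller_brute(nums: List[int]) -> List[int]:
    """O(n^2) reference: for each i, count later elements strictly less than nums[i]."""
    n = len(nums)
    result = []
    for i in range(n):
        # Scan every element to the right of index i.
        count = sum(1 for j in range(i + 1, n) if nums[j] < nums[i])
        result.append(count)
    return result


print(count_smaller_brute([5, 2, 6, 1]))  # [2, 1, 1, 0]
```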
The challenge is to optimize this process. Instead of comparing each element to all those after it, we can try to keep track of the numbers we've already seen (from right to left) in a way that allows us to quickly count how many are less than the current number. This suggests using a data structure that supports fast insertions and fast rank queries, such as a Binary Indexed Tree (Fenwick Tree), a Segment Tree, or a Binary Search Tree.
By thinking about the problem in reverse (processing the array from the end to the start), we can build up the set of "seen" numbers and efficiently answer the question for each index as we go.
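As a minimal sketch of the right-to-left idea, the Python version below keeps the "seen" numbers in a sorted list and uses the standard bisect module to count how many are smaller. This sorted list is a swapped-in stand-in for the trees named above, not the approach this article ultimately uses: bisect_left is O(log n), but inserting into a Python list is O(n), so the worst case is still quadratic. The function name is hypothetical.

```python
import bisect
from typing import List


def count_smaller_sorted_list(nums: List[int]) -> List[int]:
    seen: List[int] = []  # values to the right of the current index, kept sorted
    result = []
    for num in reversed(nums):
        # bisect_left gives the number of seen values strictly less than num.
        result.append(bisect.bisect_left(seen, num))
        # insort keeps `seen` sorted; the list insertion itself is O(n).
        bisect.insort(seen, num)
    result.reverse()
    return result


print(count_smaller_sorted_list([5, 2, 6, 1]))  # [2, 1, 1, 0]
```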
We'll use a Binary Indexed Tree (Fenwick Tree) to efficiently count the number of smaller elements for each index as we process the array from right to left.
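Because the values in nums can be negative or very large, we first map them to a smaller range (1 to N, where N is the number of distinct values) using their sorted order. This is called coordinate compression, and it lets each value be used directly as an index into the Binary Indexed Tree. We then scan nums from right to left: for each element, we query the tree for how many already-seen values have a smaller rank, record that count, and then insert the element's rank into the tree.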
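A minimal sketch of coordinate compression in Python, assuming a helper named compress (hypothetical, for illustration). It assigns 1-based ranks so that index 0 stays unused, which fits the 1-indexed Binary Indexed Tree.

```python
from typing import Dict, List


def compress(nums: List[int]) -> Dict[int, int]:
    """Map each distinct value to its 1-based rank in sorted order."""
    return {value: rank for rank, value in enumerate(sorted(set(nums)), start=1)}


print(compress([5, 2, 6, 1]))            # {1: 1, 2: 2, 5: 3, 6: 4}
print(compress([-10**9, 0, 10**9, 0]))   # {-1000000000: 1, 0: 2, 1000000000: 3}
```

However large or negative the inputs are, the ranks stay within 1..N, so the tree needs only about N + 1 slots.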
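This approach is efficient because both update and query operations in the Binary Indexed Tree are O(log N), and coordinate compression keeps the tree size at N + 1 or so, where N ≤ n. With one query and one update per element, the whole algorithm runs in O(n log n) time and O(n) extra space.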
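To see why each operation is O(log N), the sketch below lists which tree indices an update and a prefix query visit. Each step adds or removes the lowest set bit (i & -i), so the number of steps is bounded by roughly log2(N) + 1. The helper names update_path and query_path are hypothetical, and the tree is assumed to cover positions 1..n.

```python
from typing import List


def update_path(i: int, n: int) -> List[int]:
    """Indices a point update at position i touches in a Fenwick tree over 1..n."""
    path = []
    while i <= n:
        path.append(i)
        i += i & -i  # move to the next node whose range also covers i
    return path


def query_path(i: int) -> List[int]:
    """Indices a prefix-sum query over 1..i reads."""
    path = []
    while i > 0:
        path.append(i)
        i -= i & -i  # strip the lowest set bit
    return path


print(update_path(1, 8))  # [1, 2, 4, 8]
print(query_path(7))      # [7, 6, 4]
```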
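Let's walk through the example nums = [5,2,6,1] step by step. The sorted distinct values are [1,2,5,6], so value 1 has rank 1, 2 has rank 2, 5 has rank 3, and 6 has rank 4. Scanning from right to left:
i = 3 (nums[3] = 1, rank 1): query(0) = 0, so result[3] = 0. Insert rank 1.
i = 2 (nums[2] = 6, rank 4): query(3) = 1 (only the 1 has been seen), so result[2] = 1. Insert rank 4.
i = 1 (nums[1] = 2, rank 2): query(1) = 1 (the 1), so result[1] = 1. Insert rank 2.
i = 0 (nums[0] = 5, rank 3): query(2) = 2 (the 1 and the 2), so result[0] = 2. Insert rank 3.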
Final result: [2,1,1,0]
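The walkthrough can be checked mechanically with a small tracing sketch in Python. The function name trace_count_smaller and its (i, nums[i], rank, count) tuple format are hypothetical, chosen only for illustration; the Fenwick tree logic mirrors the BIT class in the implementations.

```python
from typing import List, Tuple


def trace_count_smaller(nums: List[int]) -> List[Tuple[int, int, int, int]]:
    """Return (i, nums[i], rank, smaller_count) for each step, right to left."""
    rank = {v: r for r, v in enumerate(sorted(set(nums)), start=1)}
    tree = [0] * (len(rank) + 1)  # 1-indexed Fenwick tree; slot 0 unused

    def update(i: int) -> None:
        while i < len(tree):
            tree[i] += 1
            i += i & -i

    def query(i: int) -> int:
        s = 0
        while i > 0:
            s += tree[i]
            i -= i & -i
        return s

    steps = []
    for i in range(len(nums) - 1, -1, -1):
        r = rank[nums[i]]
        # Count already-seen values whose rank is in 1..r-1 (strictly smaller).
        steps.append((i, nums[i], r, query(r - 1)))
        update(r)
    return steps


for step in trace_count_smaller([5, 2, 6, 1]):
    print(step)
# (3, 1, 1, 0)
# (2, 6, 4, 1)
# (1, 2, 2, 1)
# (0, 5, 3, 2)
```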
The key insight is to process the array from right to left, maintaining a data structure that can quickly count how many smaller numbers we've seen so far. By using coordinate compression and a Binary Indexed Tree, we answer the question for each element in O(log n) time, or O(n log n) overall. This turns a naive quadratic solution into one that scales to large inputs.
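One way to gain confidence that the fast version really matches the naive one is a randomized cross-check. The sketch below, with hypothetical function names, compares a compact Fenwick-tree implementation against the brute force on small random arrays that include negatives and duplicates; the fixed seed only makes the run repeatable.

```python
import random
from typing import List


def count_smaller_brute(nums: List[int]) -> List[int]:
    n = len(nums)
    return [sum(1 for j in range(i + 1, n) if nums[j] < nums[i]) for i in range(n)]


def count_smaller_bit(nums: List[int]) -> List[int]:
    rank = {v: r for r, v in enumerate(sorted(set(nums)), start=1)}
    tree = [0] * (len(rank) + 1)  # 1-indexed Fenwick tree of counts
    result = []
    for num in reversed(nums):
        r = rank[num]
        # Prefix query over ranks 1..r-1: smaller values already seen.
        i, count = r - 1, 0
        while i > 0:
            count += tree[i]
            i -= i & -i
        result.append(count)
        # Point update: record one occurrence of rank r.
        i = r
        while i < len(tree):
            tree[i] += 1
            i += i & -i
    return result[::-1]


rng = random.Random(0)
for _ in range(200):
    nums = [rng.randint(-5, 5) for _ in range(rng.randint(0, 30))]
    assert count_smaller_bit(nums) == count_smaller_brute(nums)
print("all random tests passed")
```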
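class BIT:
    """Fenwick tree over positions 1..size storing counts."""

    def __init__(self, size):
        self.tree = [0] * (size + 2)

    def update(self, i, delta):
        # Add delta at position i and at every node whose range covers i.
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i

    def query(self, i):
        # Sum of counts at positions 1..i.
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s


class Solution:
    def countSmaller(self, nums):
        # Coordinate compression: map each distinct value to its rank 1..N
        rank = {num: i + 1 for i, num in enumerate(sorted(set(nums)))}
        bit = BIT(len(rank))
        res = []
        for num in reversed(nums):
            idx = rank[num]
            # Count already-seen values (to the right) with a strictly smaller rank
            res.append(bit.query(idx - 1))
            bit.update(idx, 1)
        return res[::-1]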
#include <set>
#include <unordered_map>
#include <vector>
using namespace std;

class BIT {
public:
    vector<int> tree;
    BIT(int n) : tree(n + 2, 0) {}
    // Add delta at position i and at every node whose range covers i.
    void update(int i, int delta) {
        while (i < (int)tree.size()) {
            tree[i] += delta;
            i += i & -i;
        }
    }
    // Sum of counts at positions 1..i.
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += tree[i];
            i -= i & -i;
        }
        return s;
    }
};

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
        // Coordinate compression: std::set iterates in ascending order.
        set<int> s(nums.begin(), nums.end());
        unordered_map<int, int> rank;
        int idx = 1;
        for (int num : s) {
            rank[num] = idx++;
        }
        BIT bit(rank.size());
        vector<int> res(nums.size());
        for (int i = (int)nums.size() - 1; i >= 0; --i) {
            int r = rank[nums[i]];
            res[i] = bit.query(r - 1);  // smaller values already seen
            bit.update(r, 1);
        }
        return res;
    }
};
import java.util.*;

class BIT {
    int[] tree;

    public BIT(int n) {
        tree = new int[n + 2];
    }

    // Add delta at position i and at every node whose range covers i.
    void update(int i, int delta) {
        while (i < tree.length) {
            tree[i] += delta;
            i += i & -i;
        }
    }

    // Sum of counts at positions 1..i.
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += tree[i];
            i -= i & -i;
        }
        return s;
    }
}

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        // TreeSet iterates in ascending order, so each value gets its rank.
        Set<Integer> set = new TreeSet<>();
        for (int num : nums) set.add(num);
        Map<Integer, Integer> rank = new HashMap<>();
        int idx = 1;
        for (int num : set) rank.put(num, idx++);
        BIT bit = new BIT(rank.size());
        LinkedList<Integer> res = new LinkedList<>();
        for (int i = nums.length - 1; i >= 0; --i) {
            int r = rank.get(nums[i]);
            res.addFirst(bit.query(r - 1));  // smaller values already seen
            bit.update(r, 1);
        }
        return res;
    }
}
class BIT {
    constructor(size) {
        this.tree = Array(size + 2).fill(0);
    }
    update(i, delta) {
        while (i < this.tree.length) {
            this.tree[i] += delta;
            i += i & -i;
        }
    }
    query(i) {
        let s = 0;
        while (i > 0) {
            s += this.tree[i];
            i -= i & -i;
        }
        return s;
    }
}

var countSmaller = function(nums) {
    let sorted = Array.from(new Set(nums)).sort((a, b) => a - b);
    let rank = new Map();
    for (let i = 0; i < sorted.length; ++i) rank.set(sorted[i], i + 1);
    let bit = new BIT(sorted.length);
    let res = [];
    for (let i = nums.length - 1; i >= 0; --i) {
        let idx = rank.get(nums[i]);
        res.push(bit.query(idx - 1));
        bit.update(idx, 1);
    }
    return res.reverse();
};