378. Kth Smallest Element in a Sorted Matrix - Leetcode Solution

Code Implementation

import heapq

class Solution:
    def kthSmallest(self, matrix, k):
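        n = len(matrix)
        # Seed the heap with the first element of each row. Only the first
        # min(k, n) rows can hold the kth smallest value.
        min_heap = []
        for r in range(min(k, n)):
            heapq.heappush(min_heap, (matrix[r][0], r, 0))
        # Pop the k - 1 smallest values, replacing each with its right neighbor.
        for _ in range(k - 1):
            _, r, c = heapq.heappop(min_heap)
            if c + 1 < n:
                heapq.heappush(min_heap, (matrix[r][c + 1], r, c + 1))
        # The heap's minimum is now the kth smallest value.
        return min_heap[0][0]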
      
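#include <algorithm>  // std::min
#include <queue>
#include <utility>    // std::pair
#include <vector>
using namespace std;  // structured bindings below require C++17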

class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        auto cmp = [&matrix](const pair<int,int>& a, const pair<int,int>& b) {
            return matrix[a.first][a.second] > matrix[b.first][b.second];
        };
        priority_queue<pair<int,int>, vector<pair<int,int>>, decltype(cmp)> minHeap(cmp);
        for (int r = 0; r < min(k, n); ++r) {
            minHeap.emplace(r, 0);
        }
        int val = 0;
        for (int i = 0; i < k; ++i) {
            auto [r, c] = minHeap.top(); minHeap.pop();
            val = matrix[r][c];
            if (c + 1 < n) {
                minHeap.emplace(r, c + 1);
            }
        }
        return val;
    }
};
      
import java.util.*;

class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        PriorityQueue<int[]> minHeap = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        for (int r = 0; r < Math.min(k, n); ++r) {
            minHeap.offer(new int[]{matrix[r][0], r, 0});
        }
        for (int i = 0; i < k - 1; ++i) {
            int[] top = minHeap.poll();
            int val = top[0], r = top[1], c = top[2];
            if (c + 1 < n) {
                minHeap.offer(new int[]{matrix[r][c + 1], r, c + 1});
            }
        }
        return minHeap.peek()[0];
    }
}
      
class MinHeap {
    constructor() {
        this.heap = [];
    }
    push(val) {
        this.heap.push(val);
        this._bubbleUp();
    }
    pop() {
        if (this.heap.length <= 1) return this.heap.pop(); // empty or single-element heap
        const top = this.heap[0];
        this.heap[0] = this.heap.pop();
        this._bubbleDown();
        return top;
    }
    peek() {
        return this.heap[0];
    }
    _bubbleUp() {
        let idx = this.heap.length - 1;
        while (idx > 0) {
            let parent = Math.floor((idx - 1) / 2);
            if (this.heap[parent][0] <= this.heap[idx][0]) break;
            [this.heap[parent], this.heap[idx]] = [this.heap[idx], this.heap[parent]];
            idx = parent;
        }
    }
    _bubbleDown() {
        let idx = 0;
        let length = this.heap.length;
        while (true) {
            let left = 2 * idx + 1, right = 2 * idx + 2, smallest = idx;
            if (left < length && this.heap[left][0] < this.heap[smallest][0]) smallest = left;
            if (right < length && this.heap[right][0] < this.heap[smallest][0]) smallest = right;
            if (smallest === idx) break;
            [this.heap[smallest], this.heap[idx]] = [this.heap[idx], this.heap[smallest]];
            idx = smallest;
        }
    }
}

var kthSmallest = function(matrix, k) {
    let n = matrix.length;
    let heap = new MinHeap();
    for (let r = 0; r < Math.min(k, n); ++r) {
        heap.push([matrix[r][0], r, 0]);
    }
    for (let i = 0; i < k - 1; ++i) {
        let [val, r, c] = heap.pop();
        if (c + 1 < n) {
            heap.push([matrix[r][c + 1], r, c + 1]);
        }
    }
    return heap.peek()[0];
};
      

Problem Description

Given an n × n matrix where every row and every column is sorted in ascending order, find the kth smallest element in the matrix. The matrix may contain duplicate values, and each occurrence counts as a separate element in the ordering.

  • You must return the exact kth smallest element (not its index or position).
  • Each matrix entry counts exactly once, so equal values occupy consecutive positions in the ordering.
  • It is guaranteed that k is always valid: 1 ≤ k ≤ n^2.
  • The matrix is always square and each row and each column is sorted.
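The constraints above can be checked mechanically. The sketch below is a hypothetical helper (`is_valid_input` is a name chosen here, not part of the LeetCode signature) that assumes "sorted in ascending order" means non-decreasing, since duplicates are allowed.

```python
from typing import List


def is_valid_input(matrix: List[List[int]], k: int) -> bool:
    """Hypothetical helper: check the preconditions stated in the problem."""
    n = len(matrix)
    if n == 0 or any(len(row) != n for row in matrix):
        return False  # must be a non-empty square matrix
    for i in range(n):
        for j in range(n):
            # Each row is non-decreasing from left to right.
            if j + 1 < n and matrix[i][j] > matrix[i][j + 1]:
                return False
            # Each column is non-decreasing from top to bottom.
            if i + 1 < n and matrix[i][j] > matrix[i + 1][j]:
                return False
    return 1 <= k <= n * n


print(is_valid_input([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # True
print(is_valid_input([[1, 5], [3, 2]], 1))  # False: row [3, 2] and column [5, 2] are not sorted
```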

Example:
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13

Thought Process

At first glance, the most direct approach is to flatten the entire matrix into a single list, sort it, and pick the element at position k (counting from 1). This is correct for any matrix, sorted or not, because sorting puts all n^2 values in order; it simply ignores the row and column ordering we are given.
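A minimal sketch of this brute-force baseline (the function name `kth_smallest_brute` is chosen here for illustration). The usage lines also show that the two 13s in the example occupy separate positions:

```python
from typing import List


def kth_smallest_brute(matrix: List[List[int]], k: int) -> int:
    """Brute-force baseline: flatten, sort, and index (k is 1-based)."""
    flat = [value for row in matrix for value in row]  # O(n^2) extra space
    flat.sort()                                        # O(n^2 log n) time
    return flat[k - 1]


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_brute(example, 7))  # 13
print(kth_smallest_brute(example, 8))  # 13 (duplicates count separately)
print(kth_smallest_brute(example, 9))  # 15
```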

However, this brute-force method is not efficient for large matrices because:

  • Flattening takes O(n^2) time and space.
  • Sorting takes O(n^2 \log n) time.

Knowing that both rows and columns are already sorted, we should try to exploit this structure to find a faster solution.

The key insight is that the smallest element is always in the top-left corner, and values never decrease as we move right along a row or down a column. Each row is therefore a sorted list, so finding the kth smallest element amounts to merging n sorted lists and stopping after k elements. A min-heap holding the current front of each row gives us the next smallest element efficiently.
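To make the "merge sorted lists" analogy concrete, the sketch below leans on Python's standard-library `heapq.merge`, which lazily merges already-sorted iterables using a heap. This is a swapped-in library shortcut rather than the hand-written heap used in the solutions: it treats each row as one sorted input and ignores the column ordering.

```python
import heapq
from itertools import islice
from typing import List


def kth_smallest_via_merge(matrix: List[List[int]], k: int) -> int:
    # heapq.merge yields values from all sorted rows in ascending order,
    # keeping only one pending value per row in its internal heap.
    merged = heapq.merge(*matrix)
    # Skip the first k - 1 values and return the next one.
    return next(islice(merged, k - 1, None))


print(kth_smallest_via_merge([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```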

Solution Approach

Let's break down the optimized approach step by step:

  1. Initialize a Min-Heap:
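    • Push the first element of each of the first min(k, n) rows into the min-heap. Each heap entry stores the value together with its coordinates (row, column).
    • Because each row is sorted, its first element is that row's smallest value. Rows with index k or higher can be skipped: every element in such a row has at least k elements above it in the same column that are no larger, so it is never needed to find the kth smallest.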
  2. Extract the Smallest Elements:
    • Pop the smallest element from the heap. This is the next smallest in the matrix.
    • For every element popped, push the next element in the same row (i.e., move right in that row) into the heap, if it exists.
    • Repeat this k-1 times. After k-1 pops, the top of the heap is the kth smallest element.
  3. Why a Heap?
    • The heap always gives us the next smallest element efficiently: each push or pop costs O(\log \min(k, n)).
    • The heap holds at most one entry per seeded row, so it never grows beyond min(k, n) entries.
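The three steps can be sketched with two small changes for illustration (the name `kth_smallest_with_stats` is hypothetical): the initial heap is built with `heapq.heapify`, which runs in linear time instead of pushing entries one by one, and the function also reports the largest heap size it observed, so the min(k, n) bound can be checked directly.

```python
import heapq
from typing import List, Tuple


def kth_smallest_with_stats(matrix: List[List[int]], k: int) -> Tuple[int, int]:
    """Heap approach; returns (kth smallest value, largest heap size seen)."""
    n = len(matrix)
    # Step 1: seed with the first element of each of the first min(k, n) rows.
    heap = [(matrix[r][0], r, 0) for r in range(min(k, n))]
    heapq.heapify(heap)  # linear-time heap construction
    max_size = len(heap)
    # Step 2: pop k - 1 times, pushing each popped element's right neighbor.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        max_size = max(max_size, len(heap))
    # The top of the heap is now the kth smallest value.
    return heap[0][0], max_size


m = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_with_stats(m, 8))  # (13, 3)
print(kth_smallest_with_stats(m, 2))  # (5, 2): only rows 0 and 1 are seeded
```

Because each pop is followed by at most one push, the heap can only shrink or stay the same size after seeding.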

Summary of Steps:

  • Insert the first element of each row into the heap.
  • Pop the smallest and push its right neighbor (if any).
  • Repeat k - 1 times.
  • The top of the heap is now your answer. Equivalently, pop k times and return the last popped value (the C++ version above does this).
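A short sketch of the "pop k times" variant, mirroring the C++ loop in Python (the function name is chosen here for illustration). It relies on the guarantee 1 ≤ k ≤ n^2, so the loop body always runs and the heap is never empty when popped.

```python
import heapq
from typing import List


def kth_smallest_last_popped(matrix: List[List[int]], k: int) -> int:
    """Pop k times and return the last popped value (assumes 1 <= k <= n^2)."""
    n = len(matrix)
    heap = [(matrix[r][0], r, 0) for r in range(min(k, n))]
    heapq.heapify(heap)
    for _ in range(k):
        # The i-th pop yields the i-th smallest value.
        value, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return value  # the kth pop is the kth smallest


print(kth_smallest_last_popped([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```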

Example Walkthrough

Let's use the sample input:
matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8

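  1. Initial Heap: Insert the first element of each row as (value, row, col). Heap contents below are listed in the order of the heap's underlying array, so they are not fully sorted: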
    • (1, 0, 0), (10, 1, 0), (12, 2, 0)
    • Heap: [1, 10, 12]
  2. Iteration 1: Pop 1 (row 0, col 0). Push 5 (row 0, col 1).
    • Heap: [5, 12, 10]
  3. Iteration 2: Pop 5 (row 0, col 1). Push 9 (row 0, col 2).
    • Heap: [9, 12, 10]
  4. Iteration 3: Pop 9 (row 0, col 2). No right neighbor.
    • Heap: [10, 12]
  5. Iteration 4: Pop 10 (row 1, col 0). Push 11 (row 1, col 1).
    • Heap: [11, 12]
  6. Iteration 5: Pop 11 (row 1, col 1). Push 13 (row 1, col 2).
    • Heap: [12, 13]
  7. Iteration 6: Pop 12 (row 2, col 0). Push 13 (row 2, col 1).
    • Heap: [13, 13]
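  8. Iteration 7: Pop 13 (row 1, col 2). The tie with the 13 at (row 2, col 1) is broken by row index, since entries are compared as (value, row, col) tuples. No right neighbor.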
    • Heap: [13]
  9. Now, after 7 pops, the top of the heap is 13.
    • This is the 8th smallest element.
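The walkthrough can be reproduced with a small tracing version of the Python solution (the name `trace_kth_smallest` is hypothetical). It pushes entries one at a time, as the solution does, and prints the heap's values in their internal array order after each iteration; with `heapq`'s tuple comparison these lines should match the Heap lines listed above.

```python
import heapq
from typing import List, Tuple


def trace_kth_smallest(matrix: List[List[int]], k: int) -> Tuple[List[int], int]:
    """Return the values popped in the k - 1 iterations, plus the answer."""
    n = len(matrix)
    heap = []
    for r in range(min(k, n)):
        heapq.heappush(heap, (matrix[r][0], r, 0))
    popped = []
    for _ in range(k - 1):
        value, r, c = heapq.heappop(heap)
        popped.append(value)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        # Heap values in the order of heapq's underlying list.
        print(f"popped {value} at ({r}, {c}); heap = {[entry[0] for entry in heap]}")
    return popped, heap[0][0]


popped, answer = trace_kth_smallest([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
print(popped)  # [1, 5, 9, 10, 11, 12, 13]
print(answer)  # 13
```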

Time and Space Complexity

Brute-force Approach:

  • Time: O(n^2 \log n) (flatten and sort the entire matrix)
  • Space: O(n^2) (for the flattened list)
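
Heap-based Optimized Approach:
  • Time: O(k \log \min(k, n)), which is at most O(k \log n): there are at most min(k, n) initial pushes and k - 1 pop-and-push rounds, each costing O(\log \min(k, n)).
  • Space: O(\min(k, n)), since the heap holds at most one entry per seeded row.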
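Adding up the heap work explicitly, with m = min(k, n) introduced here as shorthand for the number of seeded rows (and m ≤ k):

```latex
\underbrace{m \cdot O(\log m)}_{\text{initial pushes}}
+ \underbrace{(k-1) \cdot O(\log m)}_{\text{one pop and at most one push per round}}
= O\big((m + k)\log m\big) = O(k \log m) \subseteq O(k \log n)
```

If the initial heap is built with a linear-time heapify instead of repeated pushes, the first term drops to O(m), which does not change the overall bound.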

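The optimized approach is much faster when k is small relative to n^2. In the worst case (k close to n^2) it performs O(n^2 \log n) work, the same order as the brute force, but it still needs only O(n) extra space instead of O(n^2).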

Summary

By leveraging the sorted structure of the matrix, we avoid sorting all n^2 elements. A min-heap always exposes the smallest remaining candidate, exactly as when merging sorted lists, so the total work grows with k rather than with the size of the matrix. The key insights are:

  • Pop the smallest candidate and push its right neighbor, so the heap always holds the next unprocessed value from each active row.
  • The heap holds at most one entry per row, and at most min(k, n) rows are ever involved.
  • Each heap operation costs O(\log \min(k, n)), which keeps the solution efficient and clean.

This technique is a great example of how recognizing data structure properties (sorted rows/columns) can lead to dramatic performance improvements.