436. Find Right Interval - Leetcode Solution

Problem Description

Given an array of intervals, where intervals[i] = [start_i, end_i], find the right interval for every interval. The right interval for interval i is the interval j whose start is the smallest start that is greater than or equal to intervals[i].end; that is, intervals[j].start >= intervals[i].end, with intervals[j].start minimized. Note that j may equal i. If no such interval exists, the answer for interval i is -1.

The function should return an array result where result[i] is the index of the right interval for interval i, or -1 if it does not exist.

  • Each interval's start is unique.
  • Because starts are unique, each interval has at most one right interval.
  • An interval can be its own right interval; this happens when its start equals its end.

Thought Process

The initial idea is, for each interval, to scan all intervals and pick the one with the smallest start that is not less than the current interval's end. This brute-force approach works, but it checks every pair of intervals, so it takes O(n^2) time and becomes slow for large lists.
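For reference, the brute-force idea might look like the sketch below. The function name `find_right_interval_brute` is a hypothetical helper; like the solution code later in this article, it allows an interval to be its own right interval (j == i).

```python
from typing import List


def find_right_interval_brute(intervals: List[List[int]]) -> List[int]:
    """O(n^2) baseline: for each interval, scan every interval's start."""
    res = []
    for _, end in intervals:
        best_idx = -1
        best_start = None
        for j, (start, _) in enumerate(intervals):
            # A candidate must start at or after this interval's end;
            # keep the one with the smallest such start.
            if start >= end and (best_start is None or start < best_start):
                best_start = start
                best_idx = j
        res.append(best_idx)
    return res


print(find_right_interval_brute([[1, 2], [2, 3], [0, 1]]))  # [1, -1, 0]
```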

Instead, we can sort the interval starts once. In a sorted list, the smallest start that is greater than or equal to a given end can be found by binary search, which reduces the work per interval from O(n) to O(log n). Because all starts are unique, the position found identifies exactly one interval.

The key insight is to pre-process the intervals to allow fast lookups for the next suitable interval, just like looking up a word in a dictionary.
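In code, the "dictionary lookup" is a lower-bound binary search: given the starts in sorted order, return the first position whose start is at least the target end. The helper `lower_bound` below is a hypothetical name; it is meant to follow the same contract as Python's `bisect.bisect_left`, which the Python solution later relies on.

```python
from typing import List


def lower_bound(sorted_starts: List[int], target: int) -> int:
    """Return the first position p with sorted_starts[p] >= target,
    or len(sorted_starts) if every start is smaller than target."""
    lo, hi = 0, len(sorted_starts)  # the answer always lies in [lo, hi]
    while lo < hi:
        mid = (lo + hi) // 2
        if sorted_starts[mid] >= target:
            hi = mid      # mid qualifies; an earlier position might too
        else:
            lo = mid + 1  # mid and everything before it is too small
    return lo


starts = [0, 1, 2]
print(lower_bound(starts, 2))  # 2
print(lower_bound(starts, 3))  # 3 (no start >= 3)
```

Each iteration halves the range [lo, hi), so a lookup takes O(log n) comparisons.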

Solution Approach

  • Step 1: Record Original Indices
    • Since we need to return the indices of the right intervals, keep track of each interval's original index. Store tuples like (start, index).
  • Step 2: Sort Intervals by Start
    • Sort the list of (start, index) pairs by start. This allows us to use binary search to quickly find the smallest start that is not less than a given end.
  • Step 3: For Each Interval, Use Binary Search
    • For each interval, binary-search the sorted list for the first entry whose start is greater than or equal to the current interval's end.
    • The answer is that entry's original index; if every start is smaller than the end, the answer is -1.
  • Step 4: Build Result Array
    • Store the found indices in a result array corresponding to the original interval order.
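A minimal sketch of the four steps, assuming Python's standard `bisect` module. Unlike the Python solution at the end, which bisects a list of (start, index) tuples, this variant keeps the sorted starts and their original indices in two parallel lists, so `bisect_left` compares plain integers. The function name is illustrative.

```python
from bisect import bisect_left
from typing import List


def find_right_interval(intervals: List[List[int]]) -> List[int]:
    n = len(intervals)
    # Steps 1-2: original indices, ordered by their interval's start.
    order = sorted(range(n), key=lambda i: intervals[i][0])
    sorted_starts = [intervals[i][0] for i in order]
    # Steps 3-4: one binary search per interval, stored in original order.
    res = [-1] * n
    for i, (_, end) in enumerate(intervals):
        pos = bisect_left(sorted_starts, end)  # first start >= end
        if pos < n:
            res[i] = order[pos]  # map the sorted position back to an index
    return res


print(find_right_interval([[1, 2], [2, 3], [0, 1]]))  # [1, -1, 0]
```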

We use binary search because it allows us to find the right interval in O(log n) time for each interval. Sorting the intervals by start makes this possible.

Example Walkthrough

Suppose intervals = [[1,2], [2,3], [0,1]].

  1. Step 1: Record Indices
    We pair each interval with its index: [(1,0), (2,1), (0,2)].
  2. Step 2: Sort by Start
    After sorting: [(0,2), (1,0), (2,1)].
  3. Step 3: For Each Interval
    • For [1,2] (end = 2):
      Binary search for start >= 2. Found (2,1). So, result[0] = 1.
    • For [2,3] (end = 3):
      Binary search for start >= 3. Not found. So, result[1] = -1.
    • For [0,1] (end = 1):
      Binary search for start >= 1. Found (1,0). So, result[2] = 0.
  4. Final Output: [1, -1, 0]
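The walkthrough can be reproduced with a short script (a sketch using `bisect_left`; the variable names are illustrative). It prints the intermediate lists from Steps 1 and 2 and one line per interval for Step 3.

```python
from bisect import bisect_left

intervals = [[1, 2], [2, 3], [0, 1]]

# Step 1: pair each start with its original index.
pairs = [(start, i) for i, (start, _) in enumerate(intervals)]
print("pairs: ", pairs)   # list shown: [(1, 0), (2, 1), (0, 2)]

# Step 2: sort the pairs by start.
pairs.sort()
print("sorted:", pairs)   # list shown: [(0, 2), (1, 0), (2, 1)]

# Step 3: binary-search each interval's end among the sorted starts.
sorted_starts = [s for s, _ in pairs]
result = []
for i, (_, end) in enumerate(intervals):
    pos = bisect_left(sorted_starts, end)  # first start >= end
    right = pairs[pos][1] if pos < len(pairs) else -1
    print(f"interval {intervals[i]}: end={end} -> result[{i}] = {right}")
    result.append(right)

# Step 4: the collected answers, in original order.
print("final:", result)   # list shown: [1, -1, 0]
```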

This shows how sorting and binary search quickly lead us to the correct right intervals.

Time and Space Complexity

  • Brute-force Approach:
    • For each interval, scan all n intervals: O(n^2) time.
    • Space: O(1) extra beyond the O(n) result array.
  • Optimized Approach (Sorting + Binary Search):
    • Sorting takes O(n log n) time.
    • For each of n intervals, binary search takes O(log n), so total O(n log n).
    • Total time: O(n log n).
    • Space: O(n) for the sorted list and result array.

The optimized approach is much faster and suitable for large input sizes.
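The complexity gain only matters if both approaches return the same answers. One hedged way to check this is a randomized cross-check; `brute`, `fast`, and `random_intervals` are hypothetical helper names, and the generator assumes small integer coordinates with unique starts and start <= end.

```python
import random
from bisect import bisect_left
from typing import List


def brute(intervals: List[List[int]]) -> List[int]:
    # O(n^2): scan every start for each end.
    res = []
    for _, end in intervals:
        cands = [(s, j) for j, (s, _) in enumerate(intervals) if s >= end]
        res.append(min(cands)[1] if cands else -1)
    return res


def fast(intervals: List[List[int]]) -> List[int]:
    # O(n log n): sort (start, index) pairs once, then binary-search each end.
    starts = sorted((s, i) for i, (s, _) in enumerate(intervals))
    res = []
    for _, end in intervals:
        pos = bisect_left(starts, (end,))  # (end,) sorts before (end, i)
        res.append(starts[pos][1] if pos < len(starts) else -1)
    return res


def random_intervals(n: int, rng: random.Random) -> List[List[int]]:
    # Unique starts, as the problem guarantees; each end >= its start.
    starts = rng.sample(range(-50, 50), n)
    return [[s, s + rng.randint(0, 20)] for s in starts]


rng = random.Random(0)
for _ in range(500):
    case = random_intervals(rng.randint(1, 30), rng)
    assert brute(case) == fast(case), case
print("brute force and binary search agree on 500 random cases")
```

A fixed seed keeps any failure reproducible, and the failing input is attached to the assertion message.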

Summary

By sorting the intervals by their start times and using binary search, we efficiently find the right interval for each interval in O(n log n) time. The key insight is to pre-process the intervals for fast lookup, avoiding the inefficiency of brute-force scanning. This approach leverages both sorting and binary search, making it elegant and well-suited for large datasets.

Code Implementation

Python

from bisect import bisect_left

class Solution:
    def findRightInterval(self, intervals):
        n = len(intervals)
        # Store (start, original index)
        starts = sorted((interval[0], i) for i, interval in enumerate(intervals))
        res = []
        for interval in intervals:
            end = interval[1]
            # Binary search for the smallest start >= end
            idx = bisect_left(starts, (end,))
            if idx < n:
                res.append(starts[idx][1])
            else:
                res.append(-1)
        return res

C++

#include <vector>
#include <algorithm>
using namespace std;

class Solution {
public:
    vector<int> findRightInterval(vector<vector<int>>& intervals) {
        int n = intervals.size();
        vector<pair<int, int>> starts;
        for (int i = 0; i < n; ++i) {
            starts.push_back({intervals[i][0], i});
        }
        sort(starts.begin(), starts.end());
        vector<int> res;
        for (auto& interval : intervals) {
            int end = interval[1];
            auto it = lower_bound(starts.begin(), starts.end(), make_pair(end, 0));
            if (it != starts.end()) {
                res.push_back(it->second);
            } else {
                res.push_back(-1);
            }
        }
        return res;
    }
};

Java

import java.util.*;

class Solution {
    public int[] findRightInterval(int[][] intervals) {
        int n = intervals.length;
        int[][] starts = new int[n][2];
        for (int i = 0; i < n; ++i) {
            starts[i][0] = intervals[i][0];
            starts[i][1] = i;
        }
        Arrays.sort(starts, Comparator.comparingInt(a -> a[0]));
        int[] res = new int[n];
        for (int i = 0; i < n; ++i) {
            int end = intervals[i][1];
            int l = 0, r = n - 1;
            int idx = -1;
            while (l <= r) {
                int m = (l + r) / 2;
                if (starts[m][0] >= end) {
                    idx = starts[m][1];
                    r = m - 1;
                } else {
                    l = m + 1;
                }
            }
            res[i] = idx;
        }
        return res;
    }
}

JavaScript

var findRightInterval = function(intervals) {
    const n = intervals.length;
    const starts = [];
    for (let i = 0; i < n; ++i) {
        starts.push([intervals[i][0], i]);
    }
    starts.sort((a, b) => a[0] - b[0]);
    const res = [];
    for (let i = 0; i < n; ++i) {
        let end = intervals[i][1];
        let l = 0, r = n - 1, idx = -1;
        while (l <= r) {
            let m = Math.floor((l + r) / 2);
            if (starts[m][0] >= end) {
                idx = starts[m][1];
                r = m - 1;
            } else {
                l = m + 1;
            }
        }
        res.push(idx);
    }
    return res;
};