1825. Finding MK Average - LeetCode Solution

Problem Description

The "Finding MK Average" problem from LeetCode asks you to design a data structure called MKAverage that maintains a stream of integers and supports two main operations:

  • addElement(num): Adds the integer num to the data structure.
  • calculateMKAverage(): Calculates and returns the MK Average for the last m elements in the stream. The MK Average is defined as the average of the last m elements after removing the smallest k elements and the largest k elements. The result should be the integer part of the average (i.e., floor division).

If there are fewer than m elements in the stream, calculateMKAverage() should return -1.

Constraints:

  • Each addElement operation adds a single integer to the stream.
  • Only the last m elements are considered for the MK Average.
  • Elements removed from the stream (when the window exceeds m) must also be removed from all data structures tracking the k smallest and largest elements.
  • All operations must be efficient, as the number of operations can be large (up to 10^5).

Thought Process

At first glance, the problem seems straightforward: just keep adding elements, and when asked, sort the last m elements, remove the smallest and largest k, and average the rest. However, sorting for every query would be too slow for large streams.

The challenge lies in efficiently maintaining a sliding window of the last m elements, and being able to quickly remove the k smallest and largest elements for the average calculation. We need a data structure that supports:

  • Fast insertions and deletions as the window slides.
  • Quick access to the k smallest and largest elements.
  • Efficient calculation of the sum of the middle elements.

Brute force would involve sorting the window every time, which is O(m log m) per query. We need to do better.
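
For reference, the brute-force idea described above can be sketched as follows. This is a minimal baseline rather than the optimized solution; the class name BruteForceMKAverage is made up for illustration, and the code uses only the Python standard library.

```python
from collections import deque


class BruteForceMKAverage:
    """Baseline that re-sorts the window on every query."""

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        # A deque with maxlen=m drops the oldest element automatically.
        self.window = deque(maxlen=m)

    def addElement(self, num: int) -> None:
        self.window.append(num)  # O(1)

    def calculateMKAverage(self) -> int:
        if len(self.window) < self.m:
            return -1
        ordered = sorted(self.window)              # O(m log m) per query
        middle = ordered[self.k:self.m - self.k]   # drop k smallest and k largest
        return sum(middle) // len(middle)
```

A baseline like this is still useful for cross-checking a faster implementation on small inputs.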

Solution Approach

To solve the problem efficiently, we use a combination of data structures:

  • A queue to store the last m elements (for the sliding window).
  • Three balanced multisets (or similar data structures) to partition the window into:
    • The k smallest elements (left set)
    • The middle m - 2k elements (middle set)
    • The k largest elements (right set)
  • A running sum of the middle set for quick average calculation.
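
To make the target partition concrete, here is a small sketch that builds it by sorting a full window once. It only illustrates the invariant that the three sets must satisfy, not how they are maintained incrementally; the helper name partition_window is hypothetical.

```python
def partition_window(window, k):
    """Split a full window into (left, mid, right, mid_sum) by sorting it once."""
    ordered = sorted(window)
    left = ordered[:k]                     # the k smallest elements
    mid = ordered[k:len(ordered) - k]      # the m - 2k middle elements
    right = ordered[len(ordered) - k:]     # the k largest elements
    return left, mid, right, sum(mid)


left, mid, right, mid_sum = partition_window([3, 1, 10, 5, 5], k=1)
print(left, mid, right, mid_sum)  # [1] [3, 5, 5] [10] 13
```

Every element of left is at most every element of mid, which in turn is at most every element of right. Rebalancing exists to restore exactly this ordering and these sizes after each update.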

The algorithm works as follows:

  1. When a new element is added:
    • Insert it into the appropriate set (left, middle, or right) based on its value.
    • Rebalance the sets so that left and right always contain exactly k elements each, and the middle contains the rest.
  2. If the window exceeds size m, remove the oldest element from the window and from its corresponding set, then rebalance.
  3. When calculateMKAverage() is called, if there are fewer than m elements, return -1. Otherwise, return the integer division of the sum of the middle set by its size.

In Python, we can use the SortedList from the sortedcontainers module to efficiently maintain ordered multisets.
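
The steps above can be sketched with the standard library alone. As an assumption made purely to keep the example self-contained, this version swaps SortedList for plain sorted lists managed with bisect, so insertions and removals cost O(m) in the worst case instead of O(log m). The class and helper names are illustrative, and the old element is evicted before the new one is placed.

```python
from bisect import bisect_left, insort
from collections import deque


class MKAverageSketch:
    def __init__(self, m, k):
        self.m, self.k = m, k
        self.q = deque()                               # sliding window, oldest first
        self.left, self.mid, self.right = [], [], []   # each list is kept sorted
        self.mid_sum = 0

    def _to_mid(self, val):
        insort(self.mid, val)
        self.mid_sum += val

    def _from_mid(self, index):
        val = self.mid.pop(index)
        self.mid_sum -= val
        return val

    def _discard(self, val):
        # Remove one occurrence of val from whichever partition holds it.
        for part in (self.left, self.right, self.mid):
            i = bisect_left(part, val)
            if i < len(part) and part[i] == val:
                part.pop(i)
                if part is self.mid:
                    self.mid_sum -= val
                return

    def _rebalance(self):
        # Push overflow into mid first, then refill any partition that is short.
        while len(self.left) > self.k:
            self._to_mid(self.left.pop())       # largest of left
        while len(self.right) > self.k:
            self._to_mid(self.right.pop(0))     # smallest of right
        while len(self.left) < self.k:
            insort(self.left, self._from_mid(0))
        while len(self.right) < self.k:
            insort(self.right, self._from_mid(-1))

    def addElement(self, num):
        self.q.append(num)
        if len(self.q) > self.m:
            self._discard(self.q.popleft())     # evict the oldest element
        if len(self.q) < self.m:
            self._to_mid(num)                   # window not full yet
            return
        # Place num by value; left or right may be empty right after an eviction.
        if self.left and num <= self.left[-1]:
            insort(self.left, num)
        elif self.right and num >= self.right[0]:
            insort(self.right, num)
        else:
            self._to_mid(num)
        self._rebalance()

    def calculateMKAverage(self):
        if len(self.q) < self.m:
            return -1
        return self.mid_sum // (self.m - 2 * self.k)
```

The emptiness checks before reading self.left[-1] and self.right[0] matter when k = 1, because evicting the oldest element can leave one of the outer partitions temporarily empty.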

Example Walkthrough

Let's say m = 5 and k = 1. We perform the following operations:

  1. addElement(3) → window: [3]
  2. addElement(1) → window: [3, 1]
  3. addElement(10) → window: [3, 1, 10]
  4. addElement(5) → window: [3, 1, 10, 5]
  5. addElement(5) → window: [3, 1, 10, 5, 5]
  6. calculateMKAverage():
    • Sort: [1, 3, 5, 5, 10]
    • Remove smallest 1: [3, 5, 5, 10]
    • Remove largest 1: [3, 5, 5]
    • Average: (3 + 5 + 5) / 3 = 13 / 3 = 4 (integer division)
  7. addElement(7) → window: [1, 10, 5, 5, 7]
  8. calculateMKAverage():
    • Sort: [1, 5, 5, 7, 10]
    • Remove smallest 1: [5, 5, 7, 10]
    • Remove largest 1: [5, 5, 7]
    • Average: (5 + 5 + 7) / 3 = 17 / 3 = 5

This example shows how the window slides and which elements are trimmed before averaging. The optimized solution returns the same values without re-sorting the window on each query.
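
The walkthrough can be replayed with a few lines of Python. This sketch simply sorts the window for each query, which is fine at this size; the helper name mk_average is made up for illustration.

```python
from collections import deque


def mk_average(window, k):
    """Sort the window, drop the k smallest and k largest, floor-divide the rest."""
    ordered = sorted(window)
    middle = ordered[k:len(ordered) - k]
    return sum(middle) // len(middle)


window = deque(maxlen=5)          # m = 5: the oldest element is dropped automatically
for num in [3, 1, 10, 5, 5]:
    window.append(num)
print(mk_average(window, 1))      # 4

window.append(7)                  # 3 leaves the window
print(list(window))               # [1, 10, 5, 5, 7]
print(mk_average(window, 1))      # 5
```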

Time and Space Complexity

Brute-force approach:

  • Each calculateMKAverage() call would sort the last m elements: O(m log m)
  • Adding an element: O(1)
  • Overall: Too slow for large m and frequent queries.

Optimized approach (using balanced multisets):

  • Each insertion/removal in a balanced multiset: O(log m)
  • Each addElement: O(log m)
  • Each calculateMKAverage(): O(1) (since we maintain the sum of the middle set)
  • Space: O(m) (we store at most m elements in the window and the sets)

This makes the optimized solution efficient enough for large inputs.
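
To get a feel for the gap, the bounds above can be turned into a rough step count. This is an order-of-magnitude estimate that ignores constant factors; the even split between adds and queries is an assumption chosen only for illustration, and the function name is hypothetical.

```python
import math


def rough_operation_counts(m: int, adds: int, queries: int):
    """Order-of-magnitude step counts; constant factors are ignored."""
    log_m = math.log2(m)
    brute_force = adds * 1 + queries * m * log_m   # O(1) add, O(m log m) query
    optimized = adds * log_m + queries * 1         # O(log m) add, O(1) query
    return brute_force, optimized


brute, fast = rough_operation_counts(m=10**5, adds=50_000, queries=50_000)
print(f"brute force ~{brute:.2e} steps, optimized ~{fast:.2e} steps")
```

Under these assumptions the brute-force estimate is several orders of magnitude larger than the optimized one.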

Summary

The MK Average problem is a classic example of efficient sliding window computation with order statistics. The key insight is to use three balanced multisets (or similar structures) to maintain the smallest, middle, and largest elements, allowing for quick updates and queries. By maintaining a running sum of the middle set, we avoid re-sorting or rescanning the window, leading to a highly efficient solution. This approach demonstrates the power of using the right data structures for the right job.

Code Implementation

Python:

from collections import deque
from sortedcontainers import SortedList

class MKAverage:
    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        self.q = deque()
        self.left = SortedList()
        self.mid = SortedList()
        self.right = SortedList()
        self.mid_sum = 0

    def addElement(self, num: int) -> None:
        self.q.append(num)
        if len(self.q) <= self.m:
            self.mid.add(num)
            self.mid_sum += num
            if len(self.q) == self.m:
                # Partition mid into left, mid, right
                while len(self.left) < self.k:
                    val = self.mid[0]
                    self.left.add(val)
                    self.mid.remove(val)
                    self.mid_sum -= val
                while len(self.right) < self.k:
                    val = self.mid[-1]
                    self.right.add(val)
                    self.mid.remove(val)
                    self.mid_sum -= val
        else:
            # Remove oldest
            old = self.q.popleft()
            # Remove old from the correct set
            if old in self.left:
                self.left.remove(old)
            elif old in self.right:
                self.right.remove(old)
            else:
                self.mid.remove(old)
                self.mid_sum -= old
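            # Insert new. After the removal above, left or right can be empty
            # (for example when k == 1), so check before indexing into them.
            if self.left and num <= self.left[-1]:
                self.left.add(num)
            elif self.right and num >= self.right[0]:
                self.right.add(num)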
            else:
                self.mid.add(num)
                self.mid_sum += num
            # Rebalance
            while len(self.left) > self.k:
                val = self.left[-1]
                self.left.remove(val)
                self.mid.add(val)
                self.mid_sum += val
            while len(self.left) < self.k:
                val = self.mid[0]
                self.mid.remove(val)
                self.mid_sum -= val
                self.left.add(val)
            while len(self.right) > self.k:
                val = self.right[0]
                self.right.remove(val)
                self.mid.add(val)
                self.mid_sum += val
            while len(self.right) < self.k:
                val = self.mid[-1]
                self.mid.remove(val)
                self.mid_sum -= val
                self.right.add(val)

    def calculateMKAverage(self) -> int:
        if len(self.q) < self.m:
            return -1
        return self.mid_sum // (self.m - 2 * self.k)

C++:

#include <iterator>  // std::prev
#include <queue>
#include <set>
using namespace std;

class MKAverage {
    int m, k;
    queue<int> q;
    multiset<int> left, mid, right;
    long long mid_sum = 0;

    void balance() {
        while (left.size() < k) {
            auto it = mid.begin();
            left.insert(*it);
            mid_sum -= *it;
            mid.erase(it);
        }
        while (left.size() > k) {
            auto it = prev(left.end());
            mid.insert(*it);
            mid_sum += *it;
            left.erase(it);
        }
        while (right.size() < k) {
            auto it = prev(mid.end());
            right.insert(*it);
            mid_sum -= *it;
            mid.erase(it);
        }
        while (right.size() > k) {
            auto it = right.begin();
            mid.insert(*it);
            mid_sum += *it;
            right.erase(it);
        }
    }

public:
    MKAverage(int m, int k) : m(m), k(k) {}

    void addElement(int num) {
        q.push(num);
        if (q.size() <= m) {
            mid.insert(num);
            mid_sum += num;
            if (q.size() == m) {
                balance();
            }
        } else {
            int old = q.front(); q.pop();
            // Remove old
            if (left.count(old)) left.erase(left.find(old));
            else if (right.count(old)) right.erase(right.find(old));
            else {
                mid.erase(mid.find(old));
                mid_sum -= old;
            }
            // Insert new
            if (!left.empty() && num <= *left.rbegin()) left.insert(num);
            else if (!right.empty() && num >= *right.begin()) right.insert(num);
            else {
                mid.insert(num);
                mid_sum += num;
            }
            balance();
        }
    }

    int calculateMKAverage() {
        if (q.size() < m) return -1;
        return mid_sum / (m - 2 * k);
    }
};

Java:

import java.util.*;

class MKAverage {
    int m, k;
    Queue<Integer> q = new LinkedList<>();
    TreeMap<Integer, Integer> left = new TreeMap<>();
    TreeMap<Integer, Integer> mid = new TreeMap<>();
    TreeMap<Integer, Integer> right = new TreeMap<>();
    int leftSize = 0, midSize = 0, rightSize = 0;
    long midSum = 0;

    public MKAverage(int m, int k) {
        this.m = m;
        this.k = k;
    }

    private void add(TreeMap<Integer, Integer> map, int num) {
        map.put(num, map.getOrDefault(num, 0) + 1);
    }

    private void remove(TreeMap<Integer, Integer> map, int num) {
        map.put(num, map.get(num) - 1);
        if (map.get(num) == 0) map.remove(num);
    }

    private int getFirst(TreeMap<Integer, Integer> map) {
        return map.firstKey();
    }

    private int getLast(TreeMap<Integer, Integer> map) {
        return map.lastKey();
    }

    private void balance() {
        while (leftSize < k) {
            int val = getFirst(mid);
            remove(mid, val); midSize--;
            add(left, val); leftSize++;
            midSum -= val;
        }
        while (leftSize > k) {
            int val = getLast(left);
            remove(left, val); leftSize--;
            add(mid, val); midSize++;
            midSum += val;
        }
        while (rightSize < k) {
            int val = getLast(mid);
            remove(mid, val); midSize--;
            add(right, val); rightSize++;
            midSum -= val;
        }
        while (rightSize > k) {
            int val = getFirst(right);
            remove(right, val); rightSize--;
            add(mid, val); midSize++;
            midSum += val;
        }
    }

    public void addElement(int num) {
        q.offer(num);
        if (q.size() <= m) {
            add(mid, num); midSize++;
            midSum += num;
            if (q.size() == m) balance();
        } else {
            int old = q.poll();
            if (left.containsKey(old)) { remove(left, old); leftSize--; }
            else if (right.containsKey(old)) { remove(right, old); rightSize--; }
            else { remove(mid, old); midSize--; midSum -= old; }
            if (!left.isEmpty() && num <= getLast(left)) { add(left, num); leftSize++; }
            else if (!right.isEmpty() && num >= getFirst(right)) { add(right, num); rightSize++; }
            else { add(mid, num); midSize++; midSum += num; }
            balance();
        }
    }

    public int calculateMKAverage() {
        if (q.size() < m) return -1;
        return (int)(midSum / (m - 2 * k));
    }
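}

JavaScript:

// Array-backed sorted list. Binary search finds a position in O(log n),
// but splice shifts elements, so add and remove are O(n) in the worst case.
class SortedList {
    constructor() {
        this.arr = [];
    }
    // Index of the first element >= val (lower bound).
    lowerBound(val) {
        let lo = 0, hi = this.arr.length;
        while (lo < hi) {
            const mid = (lo + hi) >> 1;
            if (this.arr[mid] < val) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }
    add(val) {
        this.arr.splice(this.lowerBound(val), 0, val);
    }
    // Removes one occurrence of val, if present.
    remove(val) {
        const idx = this.lowerBound(val);
        if (this.arr[idx] === val) this.arr.splice(idx, 1);
    }
    get length() { return this.arr.length; }
    get(idx) { return this.arr[idx]; }
    first() { return this.arr[0]; }
    last() { return this.arr[this.arr.length - 1]; }
}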

class MKAverage {
    constructor(m, k) {
        this.m = m;
        this.k = k;
        this.q = [];
        this.left = new SortedList();
        this.mid = new SortedList();
        this.right = new SortedList();
        this.midSum = 0;
    }

    addElement(num) {
        this.q.push(num);
        if (this.q.length <= this.m) {
            this.mid.add(num);
            this.midSum += num;
            if (this.q.length === this.m) {
                while (this.left.length < this.k) {
                    let val = this.mid.first();
                    this.left.add(val);
                    this.mid.remove(val);
                    this.midSum -= val;
                }
                while (this.right.length < this.k) {
                    let val = this.mid.last();
                    this.right.add(val);
                    this.mid.remove(val);
                    this.midSum -= val;
                }
            }
        } else {
            let old = this.q.shift();
            if (this.left.arr.includes(old)) this.left.remove(old);
            else if (this.right.arr.includes(old)) this.right.remove(old);
            else {
                this.mid.remove(old);
                this.midSum -= old;
            }
            if (this.left.length && num <= this.left.last()) this.left.add(num);
            else if (this.right.length && num >= this.right.first()) this.right.add(num);
            else {
                this.mid.add(num);
                this.midSum += num;
            }
            while (this.left.length > this.k) {
                let val = this.left.last();
                this.left.remove(val);
                this.mid.add(val);
                this.midSum += val;
            }
            while (this.left.length < this.k) {
                let val = this.mid.first();
                this.mid.remove(val);
                this.midSum -= val;
                this.left.add(val);
            }
            while (this.right.length > this.k) {
                let val = this.right.first();
                this.right.remove(val);
                this.mid.add(val);
                this.midSum += val;
            }
            while (this.right.length < this.k) {
                let val = this.mid.last();
                this.mid.remove(val);
                this.midSum -= val;
                this.right.add(val);
            }
        }
    }

    calculateMKAverage() {
        if (this.q.length < this.m) return -1;
        return Math.floor(this.midSum / (this.m - 2 * this.k));
    }
}