310. Minimum Height Trees - Leetcode Solution

Problem Description

The Minimum Height Trees problem asks you to find every node that, when chosen as the root of an undirected tree, gives the rooted tree the smallest possible height (the number of edges on the longest downward path from the root to a leaf).
Given an integer n, the number of nodes labeled from 0 to n - 1, and a list edges in which each element [u, v] is an undirected edge, return a list of all nodes that can serve as the root of a tree with the smallest height.

  • The input graph is guaranteed to be a tree (i.e., connected and acyclic).
  • There can be multiple roots resulting in the same minimum height.
  • The output is a list of node labels and may be returned in any order.

Constraints:

  • 1 <= n <= 2 * 10^4
  • edges.length == n - 1
  • 0 <= u, v < n
  • u != v
  • All pairs [u, v] are unique (no duplicate edges).

Thought Process

At first glance, it may seem that you could try every node as a root, compute the height of the resulting tree, and keep those with the minimum height. However, this brute-force approach is slow: each height computation is a full BFS or DFS that costs O(n), so trying all n roots costs O(n^2), roughly 4 * 10^8 steps at the upper limit n = 2 * 10^4.
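For reference, the brute-force idea can be sketched in Python as below. The helper names `height_from` and `min_height_roots_brute` are hypothetical, chosen for this illustration; the sketch only makes the O(n^2) cost concrete and is not a recommended solution.

```python
from collections import deque
from typing import List


def height_from(root: int, graph: List[List[int]]) -> int:
    """Hypothetical helper: height (in edges) of the tree rooted at `root`, via BFS."""
    dist = [-1] * len(graph)
    dist[root] = 0
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nxt in graph[node]:
            if dist[nxt] == -1:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return max(dist)


def min_height_roots_brute(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical brute force: n BFS runs of O(n) each, O(n^2) in total."""
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)
    heights = [height_from(r, graph) for r in range(n)]
    best = min(heights)
    return [r for r in range(n) if heights[r] == best]


print(min_height_roots_brute(6, [[0, 1], [0, 2], [0, 3], [3, 4], [4, 5]]))  # [3]
```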

To do better, notice that a tree's height is smallest when the root sits in the "center" of the tree. The height of the tree rooted at r equals r's eccentricity, the largest distance from r to any other node, and that farthest node is always a leaf. The center is the node (or pair of nodes) with the smallest eccentricity; in graph theory these are called the tree centers.
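As an illustration of what "center" means, here is a minimal sketch that uses a different technique than this article's peeling method: in a tree, the center(s) are the middle node(s) of any longest path, and a longest path can be found with two breadth-first searches. The function names `bfs_farthest` and `centers_via_diameter` are hypothetical.

```python
from collections import deque
from typing import List, Tuple


def bfs_farthest(start: int, graph: List[List[int]]) -> Tuple[int, List[int]]:
    """Hypothetical helper: return (a farthest node from start, BFS parent array)."""
    parent = [-1] * len(graph)
    seen = [False] * len(graph)
    seen[start] = True
    queue = deque([start])
    last = start
    while queue:
        node = queue.popleft()
        last = node  # BFS dequeues nodes in nondecreasing distance order
        for nxt in graph[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                queue.append(nxt)
    return last, parent


def centers_via_diameter(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical: centers as the middle node(s) of a longest path."""
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)
    a, _ = bfs_farthest(0, graph)       # one end of a longest path
    b, parent = bfs_farthest(a, graph)  # the other end, with parents leading back to a
    path = [b]
    while path[-1] != a:
        path.append(parent[path[-1]])
    k = len(path)
    # Odd node count on the path: one middle node; even: two adjacent middle nodes.
    return path[(k - 1) // 2 : k // 2 + 1]


print(centers_via_diameter(6, [[0, 1], [0, 2], [0, 3], [3, 4], [4, 5]]))  # [3]
```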

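Instead of trying every root, we can repeatedly remove all current leaves (nodes with exactly one neighbor) layer by layer, like peeling an onion, until only one or two nodes remain. In a tree with at least three nodes, removing every leaf lowers each remaining node's eccentricity by exactly one, so the nodes with the smallest eccentricity stay the same from layer to layer. The one or two survivors are therefore the centers of the tree, and they are the roots of the Minimum Height Trees.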
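A small numeric check of why peeling preserves the centers, on the same tree used in the walkthrough later in this article. The sketch assumes the tree is stored as a dict of neighbor sets (a representation chosen here), and the helpers `eccentricities` and `strip_leaves` are hypothetical. Each surviving node's eccentricity should drop by exactly one after the leaves are stripped.

```python
from collections import deque
from typing import Dict, Set


def eccentricities(graph: Dict[int, Set[int]]) -> Dict[int, int]:
    """Hypothetical helper: largest BFS distance from each node to any other node."""
    ecc = {}
    for src in graph:
        dist = {src: 0}
        queue = deque([src])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if nxt not in dist:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        ecc[src] = max(dist.values())
    return ecc


def strip_leaves(graph: Dict[int, Set[int]]) -> Dict[int, Set[int]]:
    """Hypothetical helper: copy of the tree with every current leaf removed."""
    leaves = {node for node, nbrs in graph.items() if len(nbrs) == 1}
    return {node: nbrs - leaves for node, nbrs in graph.items() if node not in leaves}


edges = [[0, 1], [0, 2], [0, 3], [3, 4], [4, 5]]
tree = {i: set() for i in range(6)}
for u, v in edges:
    tree[u].add(v)
    tree[v].add(u)

before = eccentricities(tree)
after = eccentricities(strip_leaves(tree))
print(before)  # {0: 3, 1: 4, 2: 4, 3: 2, 4: 3, 5: 4}
print(after)   # {0: 2, 3: 1, 4: 2}
```

Node 3 has the smallest eccentricity both before and after the strip, which is why it survives the peeling.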

Solution Approach

Let's break down the optimized algorithm step by step:

  1. Build the adjacency list:
    • Construct a graph where each node knows its neighbors. This allows for efficient lookup and removal of edges.
  2. Identify all leaves:
    • Leaves are nodes with only one neighbor (degree 1).
  3. Peel off leaves layer by layer:
    • At each iteration, remove the current leaves from the graph.
    • After removing, some of their neighbors may become new leaves (since their degree drops to 1).
    • Repeat this process, updating the list of leaves each time.
  4. Stop when 1 or 2 nodes remain:
    • These nodes are the centers. A tree always has either one center or two adjacent centers.
  5. Return the remaining node(s):
    • These are the roots of the Minimum Height Trees.
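The five steps above can also be written with an integer degree counter and a queue instead of mutable neighbor sets. This is a sketch of that variant, not the article's reference code; the function name `find_min_height_trees_degrees` is hypothetical. Each pass of the outer loop drains exactly one layer of leaves.

```python
from collections import deque
from typing import List


def find_min_height_trees_degrees(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical variant: leaf peeling with degree counters instead of neighbor sets."""
    if n <= 2:
        return list(range(n))  # one or two nodes are already the center(s)
    # Step 1: adjacency list (never mutated) plus a degree counter per node.
    graph = [[] for _ in range(n)]
    degree = [0] * n
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)
        degree[u] += 1
        degree[v] += 1
    # Step 2: the initial leaves are the nodes of degree 1.
    queue = deque(i for i in range(n) if degree[i] == 1)
    remaining = n
    # Steps 3-4: peel one full layer per pass until at most two nodes remain.
    while remaining > 2:
        layer_size = len(queue)
        remaining -= layer_size
        for _ in range(layer_size):
            leaf = queue.popleft()
            degree[leaf] = 0  # mark the leaf as removed
            for nb in graph[leaf]:
                if degree[nb] > 0:  # skip neighbors already removed
                    degree[nb] -= 1
                    if degree[nb] == 1:  # nb just became a leaf of the next layer
                        queue.append(nb)
    # Step 5: whatever is left in the queue is the set of centers.
    return list(queue)


print(find_min_height_trees_degrees(6, [[0, 1], [0, 2], [0, 3], [3, 4], [4, 5]]))  # [3]
```

Tracking degrees avoids allocating a set per node, at the cost of an explicit "removed" marker.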

This approach is efficient because each edge and node is visited and removed at most once, leading to a linear time solution.

Example Walkthrough

Example:
n = 6, edges = [[0,1],[0,2],[0,3],[3,4],[4,5]]

  1. Build the graph:
    • 0: [1, 2, 3]
    • 1: [0]
    • 2: [0]
    • 3: [0, 4]
    • 4: [3, 5]
    • 5: [4]
  2. Find initial leaves:
    • Nodes with only one neighbor: 1, 2, 5
  3. First round of removal:
    • Remove 1, 2, 5. Their neighbors (0 and 4) lose connections.
    • After update: 0: [3], 3: [0, 4], 4: [3]
  4. New leaves:
    • 0 and 4 (now have degree 1)
  5. Second round of removal:
    • Remove 0 and 4. Their neighbor 3 is left.
    • After update: 3: []
  6. Only node 3 remains.
    • Return [3] as the root of the Minimum Height Tree; rooted at 3, the tree has height 2.

If there were two nodes left at the end, both would be returned.
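To trace the walkthrough mechanically, here is a sketch that records each peeled layer alongside the surviving centers. It follows the same set-based peeling as the implementation below, with the hypothetical helper name `peel_layers` and sorted layers for readability; the second call uses a tree chosen for illustration that ends with two centers.

```python
from collections import defaultdict
from typing import List, Tuple


def peel_layers(n: int, edges: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: return (layers removed in order, surviving centers)."""
    if n == 1:
        return [], [0]
    graph = defaultdict(set)
    for u, v in edges:
        graph[u].add(v)
        graph[v].add(u)
    leaves = sorted(i for i in range(n) if len(graph[i]) == 1)
    layers = []
    remaining = n
    while remaining > 2:
        layers.append(leaves)  # record this round before removing it
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            neighbor = graph[leaf].pop()
            graph[neighbor].remove(leaf)
            if len(graph[neighbor]) == 1:
                new_leaves.append(neighbor)
        leaves = sorted(new_leaves)
    return layers, leaves


print(peel_layers(6, [[0, 1], [0, 2], [0, 3], [3, 4], [4, 5]]))
print(peel_layers(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))
```

For the walkthrough input this prints `([[1, 2, 5], [0, 4]], [3])`, matching the two rounds above; the second tree loses one layer and prints `([[0, 1, 2, 5]], [3, 4])`, so both remaining nodes are returned.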

Time and Space Complexity

  • Brute-force:
    • For each node, run a BFS/DFS to compute the height: O(n) per root, so O(n^2) over all n roots.
    • Impractical for large n.
  • Optimized (Peeling leaves):
    • Each node is removed once and each of the n - 1 edges is deleted once: O(n) time.
    • Space for adjacency list and leaves: O(n).

Because it runs in linear time, the optimized solution handles the largest allowed inputs (n = 2 * 10^4) comfortably.
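One way to gain confidence in the O(n) method is to cross-check it against the O(n^2) brute force on many small random trees. A minimal sketch, assuming a simple random-tree generator (each new node attaches to a uniformly chosen earlier node); the names `build`, `brute_force`, `peeling`, and `random_tree` are hypothetical.

```python
import random
from collections import deque
from typing import List


def build(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Hypothetical helper: adjacency list for an undirected graph."""
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)
    return graph


def brute_force(n: int, edges: List[List[int]]) -> List[int]:
    """O(n^2): BFS from every root, keep the roots with the smallest height."""
    graph = build(n, edges)
    heights = []
    for root in range(n):
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        heights.append(max(dist))
    best = min(heights)
    return [r for r in range(n) if heights[r] == best]


def peeling(n: int, edges: List[List[int]]) -> List[int]:
    """O(n): remove leaf layers until at most two nodes remain."""
    if n <= 2:
        return list(range(n))
    graph = build(n, edges)
    degree = [len(nbrs) for nbrs in graph]
    leaves = [i for i in range(n) if degree[i] == 1]
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            for nb in graph[leaf]:
                degree[nb] -= 1  # removed neighbors only ever fall from 1 to 0
                if degree[nb] == 1:
                    new_leaves.append(nb)
        leaves = new_leaves
    return sorted(leaves)


def random_tree(n: int, rng: random.Random) -> List[List[int]]:
    """Attach each node i >= 1 to a uniformly chosen earlier node."""
    return [[i, rng.randrange(i)] for i in range(1, n)]


rng = random.Random(0)
for _ in range(300):
    n = rng.randint(1, 40)
    edges = random_tree(n, rng)
    assert brute_force(n, edges) == peeling(n, edges), (n, edges)
print("all random trees agree")
```

A fixed seed keeps the check reproducible; raising the node limit makes the quadratic cost of `brute_force` noticeable quickly.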

Code Implementation

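Python

from collections import defaultdict
from typing import List


class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        # A single node is its own height-0 tree.
        if n == 1:
            return [0]
        # Adjacency sets make removing a leaf's edge O(1).
        graph = defaultdict(set)
        for u, v in edges:
            graph[u].add(v)
            graph[v].add(u)
        # Initial leaves: nodes with exactly one neighbor.
        leaves = [i for i in range(n) if len(graph[i]) == 1]
        remaining = n
        # Peel one layer of leaves per pass until at most two nodes remain.
        while remaining > 2:
            remaining -= len(leaves)
            new_leaves = []
            for leaf in leaves:
                neighbor = graph[leaf].pop()   # a leaf's only neighbor
                graph[neighbor].remove(leaf)
                if len(graph[neighbor]) == 1:  # neighbor just became a leaf
                    new_leaves.append(neighbor)
            leaves = new_leaves
        # The survivors are the tree's one or two centers.
        return leaves
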
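C++

#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> findMinHeightTrees(int n, vector<vector<int>>& edges) {
        if (n == 1) return {0};  // a single node is its own center
        // Adjacency sets make removing a leaf's edge O(1) on average.
        vector<unordered_set<int>> graph(n);
        for (auto& edge : edges) {
            graph[edge[0]].insert(edge[1]);
            graph[edge[1]].insert(edge[0]);
        }
        // Initial leaves: nodes with exactly one neighbor.
        vector<int> leaves;
        for (int i = 0; i < n; ++i)
            if (graph[i].size() == 1) leaves.push_back(i);
        int remaining = n;
        // Peel one layer of leaves per pass until at most two nodes remain.
        while (remaining > 2) {
            remaining -= static_cast<int>(leaves.size());
            vector<int> new_leaves;
            for (int leaf : leaves) {
                int neighbor = *graph[leaf].begin();  // a leaf's only neighbor
                graph[neighbor].erase(leaf);
                if (graph[neighbor].size() == 1)      // neighbor just became a leaf
                    new_leaves.push_back(neighbor);
            }
            leaves = std::move(new_leaves);
        }
        return leaves;  // the one or two centers
    }
};
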
Java

import java.util.*;

class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n == 1) return Collections.singletonList(0);  // a single node is its own center
        // Adjacency sets make removing a leaf's edge O(1) on average.
        List<Set<Integer>> graph = new ArrayList<>();
        for (int i = 0; i < n; ++i) graph.add(new HashSet<>());
        for (int[] edge : edges) {
            graph.get(edge[0]).add(edge[1]);
            graph.get(edge[1]).add(edge[0]);
        }
        // Initial leaves: nodes with exactly one neighbor.
        List<Integer> leaves = new ArrayList<>();
        for (int i = 0; i < n; ++i)
            if (graph.get(i).size() == 1) leaves.add(i);
        int remaining = n;
        // Peel one layer of leaves per pass until at most two nodes remain.
        while (remaining > 2) {
            remaining -= leaves.size();
            List<Integer> newLeaves = new ArrayList<>();
            for (int leaf : leaves) {
                int neighbor = graph.get(leaf).iterator().next();  // a leaf's only neighbor
                graph.get(neighbor).remove(leaf);  // Set.remove(Object): removes the value
                if (graph.get(neighbor).size() == 1)  // neighbor just became a leaf
                    newLeaves.add(neighbor);
            }
            leaves = newLeaves;
        }
        return leaves;  // the one or two centers
    }
}

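JavaScript

var findMinHeightTrees = function(n, edges) {
    if (n === 1) return [0];  // a single node is its own center
    // Adjacency sets make removing a leaf's edge O(1).
    const graph = Array.from({length: n}, () => new Set());
    for (const [u, v] of edges) {
        graph[u].add(v);
        graph[v].add(u);
    }
    // Initial leaves: nodes with exactly one neighbor.
    let leaves = [];
    for (let i = 0; i < n; ++i) {
        if (graph[i].size === 1) leaves.push(i);
    }
    let remaining = n;
    // Peel one layer of leaves per pass until at most two nodes remain.
    while (remaining > 2) {
        remaining -= leaves.length;
        const newLeaves = [];
        for (const leaf of leaves) {
            const neighbor = graph[leaf].values().next().value;  // a leaf's only neighbor
            graph[neighbor].delete(leaf);
            if (graph[neighbor].size === 1) newLeaves.push(neighbor);  // new leaf
        }
        leaves = newLeaves;
    }
    return leaves;  // the one or two centers
};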

Summary

The Minimum Height Trees problem is elegantly solved by using the structure of tree centers. Rather than brute-forcing every root, we iteratively remove leaves, converging on the central node(s) that minimize the tree's height. This approach runs in O(n) time and space and scales well even for large trees, making it a good example of how understanding the structure of a problem leads to an optimal solution.