1931. Painting a Grid With Three Different Colors - Leetcode Solution

Problem Description

You are given two integers, m and n, representing the number of rows and columns of a grid. Each cell of the grid can be painted using one of three different colors. The grid must be painted such that:

  • No two adjacent cells in the same row have the same color.
  • No two adjacent cells in the same column have the same color.
Your task is to return the total number of ways to paint the grid following the above rules, modulo 10^9 + 7.

Constraints:

  • 1 <= m <= 5
  • 1 <= n <= 1000

The problem requires finding all valid colorings for a grid of size m x n such that no two adjacent cells (either horizontally or vertically) share the same color. The answer can be large, so return it modulo 10^9 + 7.

Thought Process

At first glance, it may seem feasible to try every possible coloring of the grid and count those that satisfy the constraints. However, there are 3^(m × n) colorings, so this brute-force approach is infeasible for anything but tiny grids.

Let's examine the constraints:

  • Each cell can be colored in 3 ways.
  • Adjacent cells (in both rows and columns) must not have the same color.
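If we color the grid one column at a time, the choices for a column depend only on the column immediately to its left (because of the horizontal adjacency constraint), while the vertical constraint can be checked entirely within the column. Since m <= 5, a column has at most 5 cells, so its coloring makes a small state. This suggests a dynamic programming (DP) approach where the state is the coloring of the previous column.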

The key insight is to:

  • Precompute all valid colorings for a single column (m cells, no two vertically adjacent cells equal).
  • For each valid coloring of the current column, track how many ways it can follow each valid coloring of the previous column.
This reduces the problem to a DP over the n columns, with the state being the coloring of the most recent column.
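As a minimal sketch of the first precomputation step, assuming colors are labeled 0, 1, 2 and a column is stored as a tuple read top to bottom, the valid patterns can be enumerated with itertools.product and filtered for equal vertical neighbors. The helper name valid_patterns is hypothetical. The assertion checks the count 3 × 2^(m-1): 3 choices for the top cell and 2 for each cell below it.

```python
from itertools import product

def valid_patterns(m):
    """All colorings of an m-cell column (colors 0-2, top to bottom) in which
    no two vertically adjacent cells share a color."""
    return [p for p in product(range(3), repeat=m)
            if all(a != b for a, b in zip(p, p[1:]))]

for m in range(1, 6):
    pats = valid_patterns(m)
    # 3 choices for the top cell, then 2 for each of the m - 1 cells below it.
    assert len(pats) == 3 * 2 ** (m - 1)
    print(m, len(pats))
```

For m = 5 this gives 48 patterns, which is why the DP state stays small.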

Solution Approach

Let's break down the solution step by step:

  1. Encode Column Colorings:
    • For a column of m cells, generate all possible colorings where no two vertically adjacent cells share a color.
    • Each coloring can be represented as a tuple or packed into an integer (e.g., for m=3, the tuple (1,2,3)).
  2. Build Compatibility Map:
    • For each pair of valid column colorings, determine whether they are compatible, i.e., they differ in every row, so no two horizontally adjacent cells share a color.
    • Store which colorings can follow which others.
  3. Dynamic Programming:
    • Let dp[col][pattern] be the number of ways to color columns 0 through col with column col colored as pattern.
    • Initialize dp[0][pattern] = 1 for every valid pattern.
    • For each subsequent column, compute dp[col][curr] by summing dp[col-1][prev] over all patterns prev compatible with curr.
  4. Result:
    • Sum dp[n-1][pattern] over all patterns to get the total number of ways.
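Step 1 notes that a coloring can be stored as an integer. One possible encoding, sketched here as an illustration rather than the representation used in the code implementation, packs a column's colors (0-2, top cell first) into a base-3 number and then builds Steps 1 and 2 on those codes. The helpers encode, decode and build_compat are hypothetical names.

```python
def encode(colors):
    """Pack a column's colors (each 0-2, top cell first) into one base-3 integer."""
    code = 0
    for c in colors:
        code = code * 3 + c
    return code

def decode(code, m):
    """Unpack a base-3 integer into its m colors, top cell first."""
    colors = [0] * m
    for k in range(m - 1, -1, -1):
        code, colors[k] = divmod(code, 3)
    return colors

def build_compat(m):
    """Steps 1-2: valid column codes, and for each code the codes allowed next to it."""
    cols = {code: decode(code, m) for code in range(3 ** m)}
    # Step 1: keep codes whose digits have no equal vertical neighbors.
    valid = [code for code, col in cols.items()
             if all(a != b for a, b in zip(col, col[1:]))]
    # Step 2: two columns are compatible if they differ in every row.
    compat = {a: [b for b in valid
                  if all(x != y for x, y in zip(cols[a], cols[b]))]
              for a in valid}
    return valid, compat

valid, compat = build_compat(2)
print(valid)                    # [1, 2, 3, 5, 6, 7]
print(compat[encode([0, 1])])   # [3, 5, 6]
```

An integer key is convenient for arrays or bitmask tricks; tuples, as in the code implementation, are easier to read.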

Because the number of valid column colorings is small (m <= 5), the DP state fits in a plain array indexed by pattern, as in the code implementation; a hash map or dictionary keyed by the pattern works equally well.
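A minimal sketch of Steps 3 and 4 using dictionaries keyed by pattern tuples, assuming colors 0-2; count_ways is a hypothetical name. Only two dictionaries (previous and current column) are alive at any time.

```python
from itertools import product

MOD = 10**9 + 7

def count_ways(m, n):
    """Rolling-dictionary DP: keys are column patterns (tuples), values are counts."""
    patterns = [p for p in product(range(3), repeat=m)
                if all(a != b for a, b in zip(p, p[1:]))]
    # Precompute successors once so the inner loop only visits compatible pairs.
    nexts = {p: [q for q in patterns if all(a != b for a, b in zip(p, q))]
             for p in patterns}
    prev = dict.fromkeys(patterns, 1)  # first column: 1 way per pattern
    for _ in range(n - 1):
        curr = dict.fromkeys(patterns, 0)
        for p, ways in prev.items():
            for q in nexts[p]:
                curr[q] = (curr[q] + ways) % MOD
        prev = curr  # drop the older column's counts
    return sum(prev.values()) % MOD

print(count_ways(2, 3))  # 54
```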

Example Walkthrough

Let's consider m = 2 and n = 3 (a 2x3 grid).

  1. Step 1: Generate valid column colorings
    Valid colorings for a column of 2 cells: (1,2), (1,3), (2,1), (2,3), (3,1), (3,2)
  2. Step 2: Build compatibility map
    For example, (1,2) is compatible with (2,1), (2,3), and (3,1): in each, the top cell differs from 1 and the bottom cell differs from 2. (3,2) is not compatible, because its bottom cell repeats color 2. In fact, every pattern here has exactly 3 compatible neighbors.
  3. Step 3: DP Initialization
    For the first column, each valid coloring is a valid starting state: count = 1 for each (6 in total).
  4. Step 4: DP Transition
    For the second column, for each coloring, sum the counts from all compatible colorings of the first column. Each pattern has 3 compatible predecessors, so every count becomes 3 (18 in total).
  5. Step 5: Final result
    Repeating the transition for the third column gives 9 for each pattern; summing the counts of all colorings in the last column gives the answer.

For this example, the result is 6 × 3 × 3 = 54.
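The walkthrough can be reproduced with a short script that uses the same color labels 1-3. It prints the compatible successors of (1,2) and the total count after each column; no modulus is applied because the numbers stay tiny.

```python
from itertools import product

m, n = 2, 3
colors = (1, 2, 3)  # same labels as the walkthrough
patterns = [p for p in product(colors, repeat=m)
            if all(a != b for a, b in zip(p, p[1:]))]
compat = {p: [q for q in patterns if all(a != b for a, b in zip(p, q))]
          for p in patterns}
print(compat[(1, 2)])  # [(2, 1), (2, 3), (3, 1)]

counts = {p: 1 for p in patterns}   # Step 3: first column
totals = [sum(counts.values())]
for _ in range(n - 1):              # Step 4: remaining columns
    nxt = {p: 0 for p in patterns}
    for p, ways in counts.items():
        for q in compat[p]:
            nxt[q] += ways
    counts = nxt
    totals.append(sum(counts.values()))
print(totals)  # [6, 18, 54]
```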

Time and Space Complexity

Brute-force approach:

  • Total colorings of the grid: 3^(m × n).
  • Checking all of them is infeasible beyond tiny grids: a 5x5 grid already has 3^25 ≈ 8.5 × 10^11 colorings.
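For tiny grids, the brute-force count is still useful as a cross-check for the DP. This sketch (brute_force is a hypothetical name) enumerates all 3^(m·n) grids in row-major order and tests both adjacency rules, so it should only be run with m·n up to about 10.

```python
from itertools import product

def brute_force(m, n):
    """Count valid colorings by trying every one of the 3**(m*n) grids."""
    total = 0
    for cells in product(range(3), repeat=m * n):
        # Lay the flat tuple out row-major: m rows of n cells each.
        grid = [cells[r * n:(r + 1) * n] for r in range(m)]
        rows_ok = all(grid[r][c] != grid[r][c + 1]
                      for r in range(m) for c in range(n - 1))
        cols_ok = all(grid[r][c] != grid[r + 1][c]
                      for r in range(m - 1) for c in range(n))
        if rows_ok and cols_ok:
            total += 1
    return total

print(brute_force(2, 3))  # 54
```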
Optimized DP approach:
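  • Let P be the number of valid column colorings. The top cell has 3 choices and each cell below it has 2, so P = 3 × 2^(m-1), which is 48 for m = 5.
  • At each of the n columns, we process all P patterns, and for each pattern, all of its compatible patterns (at most P).
  • Time complexity: O(n × P^2) for the DP, plus O(m × P^2) to build the compatibility lists. With P = 48 and n = 1000, that is at most about 2.3 million transitions.
  • Space complexity: O(P^2) for the compatibility lists and O(P) for the DP itself, since only the current and previous DP arrays are kept.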
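To see how small the work actually is, this sketch counts the valid patterns and the compatible ordered pairs for each m from 1 to 5. The pair count is the number of inner-loop updates per column, and it also equals the answer for n = 2. pattern_stats is a hypothetical helper.

```python
from itertools import product

def pattern_stats(m):
    """Return (P, number of compatible ordered pattern pairs) for column height m."""
    pats = [p for p in product(range(3), repeat=m)
            if all(a != b for a, b in zip(p, p[1:]))]
    pairs = sum(1 for p in pats for q in pats
                if all(a != b for a, b in zip(p, q)))
    return len(pats), pairs

for m in range(1, 6):
    P, pairs = pattern_stats(m)
    print(f"m={m}: P={P}, P^2={P * P}, compatible pairs={pairs}")
```

The compatible-pair count is always below P^2, so O(n × P^2) is a loose upper bound.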

This is efficient and feasible for the given constraints.

Summary

The key to solving the "Painting a Grid With Three Different Colors" problem is recognizing that the coloring process can be broken into column-by-column states, with dynamic programming counting all valid configurations efficiently. By precomputing the valid column patterns and their compatibilities, and then updating the number of ways for each successive column, we avoid brute-force enumeration. This approach relies on the small grid height (m <= 5, so at most 48 patterns) and stays fast even for n = 1000. The elegance lies in reducing a seemingly complex 2D problem to a manageable 1D DP over patterns.

Code Implementation

MOD = 10**9 + 7

def colorTheGrid(m, n):
    # Generate all valid column colorings: m cells, no two vertically adjacent cells equal
    def valid_rows(m):
        def backtrack(pos, curr):
            if pos == m:
                res.append(tuple(curr))
                return
            for color in (0,1,2):
                if pos == 0 or curr[-1] != color:
                    curr.append(color)
                    backtrack(pos+1, curr)
                    curr.pop()
        res = []
        backtrack(0, [])
        return res

    valid = valid_rows(m)
    P = len(valid)
    # Build compatibility: two columns may be neighbors only if they differ in every row
    compat = [[] for _ in range(P)]
    for i, a in enumerate(valid):
        for j, b in enumerate(valid):
            if all(x != y for x, y in zip(a, b)):
                compat[i].append(j)

    # dp[i] = ways to color the columns processed so far, ending with pattern i
    dp = [1] * P
    for _ in range(n-1):
        ndp = [0] * P
        for i in range(P):
            for j in compat[i]:
                ndp[j] = (ndp[j] + dp[i]) % MOD
        dp = ndp
    return sum(dp) % MOD
      
#include <functional>
#include <vector>
using namespace std;

class Solution {
public:
    int colorTheGrid(int m, int n) {
        const int MOD = 1e9 + 7;
        vector<vector<int>> patterns;
        vector<int> curr;
        function<void(int)> backtrack = [&](int pos) {
            if (pos == m) {
                patterns.push_back(curr);
                return;
            }
            for (int c = 0; c < 3; ++c) {
                if (pos == 0 || curr.back() != c) {
                    curr.push_back(c);
                    backtrack(pos+1);
                    curr.pop_back();
                }
            }
        };
        // Enumerate every valid column pattern, starting from the top cell
        backtrack(0);
        int P = patterns.size();
        vector<vector<int>> compat(P);
        for (int i = 0; i < P; ++i) {
            for (int j = 0; j < P; ++j) {
                bool ok = true;
                for (int k = 0; k < m; ++k) {
                    if (patterns[i][k] == patterns[j][k]) {
                        ok = false;
                        break;
                    }
                }
                if (ok) compat[i].push_back(j);
            }
        }
        vector<int> dp(P, 1);
        for (int col = 1; col < n; ++col) {
            vector<int> ndp(P, 0);
            for (int i = 0; i < P; ++i) {
                for (int j : compat[i]) {
                    ndp[j] = (ndp[j] + dp[i]) % MOD;
                }
            }
            dp = ndp;
        }
        int res = 0;
        for (int x : dp) res = (res + x) % MOD;
        return res;
    }
};
      
import java.util.*;

class Solution {
    static final int MOD = 1000000007;
    public int colorTheGrid(int m, int n) {
        List<int[]> patterns = new ArrayList<>();
        backtrack(m, 0, new int[m], patterns);

        int P = patterns.size();
        List<List<Integer>> compat = new ArrayList<>();
        for (int i = 0; i < P; ++i) compat.add(new ArrayList<>());
        for (int i = 0; i < P; ++i) {
            for (int j = 0; j < P; ++j) {
                boolean ok = true;
                for (int k = 0; k < m; ++k) {
                    if (patterns.get(i)[k] == patterns.get(j)[k]) {
                        ok = false;
                        break;
                    }
                }
                if (ok) compat.get(i).add(j);
            }
        }
        int[] dp = new int[P];
        Arrays.fill(dp, 1);
        for (int col = 1; col < n; ++col) {
            int[] ndp = new int[P];
            for (int i = 0; i < P; ++i) {
                for (int j : compat.get(i)) {
                    ndp[j] = (ndp[j] + dp[i]) % MOD;
                }
            }
            dp = ndp;
        }
        int res = 0;
        for (int x : dp) res = (res + x) % MOD;
        return res;
    }
    private void backtrack(int m, int pos, int[] curr, List<int[]> patterns) {
        if (pos == m) {
            patterns.add(curr.clone());
            return;
        }
        for (int c = 0; c < 3; ++c) {
            if (pos == 0 || curr[pos-1] != c) {
                curr[pos] = c;
                backtrack(m, pos+1, curr, patterns);
            }
        }
    }
}
      
var colorTheGrid = function(m, n) {
    const MOD = 1e9 + 7;
    let patterns = [];
    function backtrack(pos, curr) {
        if (pos === m) {
            patterns.push(curr.slice());
            return;
        }
        for (let c = 0; c < 3; ++c) {
            if (pos === 0 || curr[pos-1] !== c) {
                curr.push(c);
                backtrack(pos+1, curr);
                curr.pop();
            }
        }
    }
    backtrack(0, []);
    let P = patterns.length;
    let compat = Array.from({length: P}, () => []);
    for (let i = 0; i < P; ++i) {
        for (let j = 0; j < P; ++j) {
            let ok = true;
            for (let k = 0; k < m; ++k) {
                if (patterns[i][k] === patterns[j][k]) {
                    ok = false;
                    break;
                }
            }
            if (ok) compat[i].push(j);
        }
    }
    let dp = Array(P).fill(1);
    for (let col = 1; col < n; ++col) {
        let ndp = Array(P).fill(0);
        for (let i = 0; i < P; ++i) {
            for (let j of compat[i]) {
                ndp[j] = (ndp[j] + dp[i]) % MOD;
            }
        }
        dp = ndp;
    }
    return dp.reduce((a, b) => (a + b) % MOD, 0);
};