You are given $n$ people and a list of hats, where each person has a list of hats they are willing to wear. There are 40 different hats, labeled from 1 to 40, and each person may like any subset of them.
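The goal is to determine the number of ways to assign hats to people such that every person wears exactly one hat from their own list and no two people wear the same hat. Since the answer can be very large, return it modulo $10^9 + 7$.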
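Key constraints: the number of people satisfies $1 \le n \le 10$, and every hat label is an integer between 1 and 40.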
At first glance, you might consider generating all possible assignments of hats to people and checking which ones are valid. However, since there are up to 40 hats and each person can have many choices, this approach quickly becomes infeasible due to the combinatorial explosion.
Instead, notice that the number of people ($n$) is small ($\le 10$), while the number of hats is fixed at 40. We'll use dynamic programming with bitmasking to represent assignment states efficiently.
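Let `mask` be a bitmask of length $n$ (the number of people), where the $i$-th bit is 1 if person $i$ has been assigned a hat. Let `dp[mask][hat]` represent the number of ways to hand out hats $1..\text{hat}$ such that exactly the people in `mask` have received a hat.

- Base case: `dp[0][0] = 1` — zero people assigned, zero hats considered.
- Answer: `dp[FULL_MASK][40]`, where `FULL_MASK` $= 2^n - 1$ means all people have been assigned hats.

Processing hats in order, each hat is either left unused or given to one person who likes it and does not have a hat yet. Because row `hat` depends only on row `hat - 1`, the implementations below keep a single array indexed by `mask` and rebuild it once per hat.

Suppose there are 2 people: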
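say person 0 likes hats 1 and 2, and person 1 likes hats 2 and 3.

- Start with `mask = 0b00` (no one assigned).
- Hat 1: give it to person 0, reaching `mask = 0b01`.
- Hat 2: give it to person 0, reaching `mask = 0b01` (another way). Or give it to person 1, reaching `mask = 0b10`. Or, from `mask = 0b01` (person 0 wearing hat 1), give it to person 1, reaching `mask = 0b11`.
- Hat 3: give it to person 1, reaching `mask = 0b10` (another way). Or, from either way of reaching `mask = 0b01`, give it to person 1, reaching `mask = 0b11`.

The final count at `mask = 0b11` (both assigned) gives the total number of valid assignments. In this case, the valid assignments are (person 0 → hat 1, person 1 → hat 2), (person 0 → hat 1, person 1 → hat 3) and (person 0 → hat 2, person 1 → hat 3), for 3 in total.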
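Complexity comparison. Brute-force enumeration of every assignment grows combinatorially with the number of hats each person likes. The bitmask DP performs at most $40 \cdot 2^n \cdot n$ transitions, i.e. $O(40 \cdot 2^n \cdot n)$ time, and keeps $O(2^n)$ memory.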
By leveraging the small number of people and the fixed number of hats, we use dynamic programming with bitmasking to efficiently count the number of valid assignments. The key insight is to process hats one by one, updating assignment states, and using bitmasks to represent which people have already been assigned hats. This approach is both elegant and computationally efficient, avoiding the pitfalls of brute-force enumeration.
MOD = 10**9 + 7


class Solution:
    """Bitmask-DP solution to the distinct-hats assignment counting problem."""

    def numberWays(self, hats):
        """Count the ways to give every person a distinct hat they like.

        Args:
            hats: ``hats[i]`` lists the hat labels person ``i`` is willing to
                wear. Labels are expected to be positive (the problem uses
                1..40). Non-positive labels are ignored, and a label repeated
                in one person's list counts once. The number of people must
                stay small (the problem bounds it by 10) because the DP table
                has ``2**n`` entries.

        Returns:
            The number of assignments in which each person wears exactly one
            liked hat and no hat is worn by two people, modulo ``MOD``.
        """
        n = len(hats)

        # Size the hat table from the input instead of assuming labels <= 40.
        max_hat = max((h for liked in hats for h in liked if h >= 1), default=0)

        # Invert the preference lists: for each hat, who is willing to wear it.
        hat_to_people = [[] for _ in range(max_hat + 1)]
        for person, liked in enumerate(hats):
            for h in set(liked):  # a repeated label must not count twice
                if h >= 1:
                    hat_to_people[h].append(person)

        full = (1 << n) - 1
        # dp[mask] = ways to hand out the hats processed so far so that exactly
        # the people whose bits are set in mask are wearing one.
        dp = [0] * (full + 1)
        dp[0] = 1
        for hat in range(1, max_hat + 1):
            people = hat_to_people[hat]
            if not people:
                continue  # nobody likes this hat: dp is unchanged
            ndp = dp[:]  # option: nobody wears this hat
            for mask, ways in enumerate(dp):
                if ways == 0:
                    continue
                for person in people:
                    bit = 1 << person
                    if not (mask & bit):  # person has no hat yet
                        ndp[mask | bit] = (ndp[mask | bit] + ways) % MOD
            dp = ndp
        return dp[full]
// Modulus required by the problem statement (10^9 + 7).
constexpr int MOD = 1000000007;

// Bitmask-DP solution to the distinct-hats assignment counting problem.
class Solution {
public:
    /// @brief Counts the ways to give every person a distinct hat they like.
    /// @param hats hats[i] lists the hat labels person i is willing to wear.
    ///        Labels are expected to be positive (the problem uses 1..40).
    ///        Non-positive labels are ignored, and a label repeated in one
    ///        person's list counts once. The number of people must stay small
    ///        (the problem bounds it by 10) because the DP table has 2^n entries.
    /// @return Number of assignments in which each person wears exactly one
    ///         liked hat and no hat is worn twice, modulo 10^9 + 7.
    int numberWays(std::vector<std::vector<int>>& hats) {
        const int n = static_cast<int>(hats.size());

        // Size the hat table from the input instead of assuming labels <= 40;
        // an out-of-range label would otherwise index past the table (UB).
        int max_hat = 0;
        for (const auto& liked : hats) {
            for (int h : liked) {
                if (h > max_hat) max_hat = h;
            }
        }

        // Invert the preference lists: for each hat, who is willing to wear it.
        std::vector<std::vector<int>> hat_to_people(max_hat + 1);
        for (int i = 0; i < n; ++i) {
            for (int h : hats[i]) {
                if (h < 1) continue;  // not a valid hat label
                auto& people = hat_to_people[h];
                // People are appended in increasing order, so a repeated label
                // in person i's list shows up as i already being the last entry.
                if (people.empty() || people.back() != i) people.push_back(i);
            }
        }

        const int full = (1 << n) - 1;
        // dp[mask] = ways to hand out the hats processed so far so that exactly
        // the people whose bits are set in mask are wearing one.
        std::vector<int> dp(full + 1, 0);
        dp[0] = 1;
        for (int hat = 1; hat <= max_hat; ++hat) {
            const auto& people = hat_to_people[hat];
            if (people.empty()) continue;  // nobody likes this hat: dp unchanged
            std::vector<int> ndp = dp;     // option: nobody wears this hat
            for (int mask = 0; mask <= full; ++mask) {
                if (dp[mask] == 0) continue;
                for (int person : people) {
                    const int bit = 1 << person;
                    if (mask & bit) continue;  // person already has a hat
                    int& slot = ndp[mask | bit];
                    // Both operands are < MOD, so the sum fits in an int.
                    slot = (slot + dp[mask]) % MOD;
                }
            }
            dp.swap(ndp);
        }
        return dp[full];
    }
};
class Solution {
static final int MOD = 1000000007;
public int numberWays(List<List<Integer>> hats) {
int n = hats.size();
List<List<Integer>> hatToPeople = new ArrayList<>();
for (int i = 0; i <= 40; ++i) hatToPeople.add(new ArrayList<>());
for (int i = 0; i < n; ++i) {
for (int h : hats.get(i)) {
hatToPeople.get(h).add(i);
}
}
int[] dp = new int[1 << n];
dp[0] = 1;
for (int hat = 1; hat <= 40; ++hat) {
int[] ndp = dp.clone();
for (int mask = 0; mask < (1 << n); ++mask) {
if (dp[mask] == 0) continue;
for (int person : hatToPeople.get(hat)) {
if ((mask & (1 << person)) == 0) {
ndp[mask | (1 << person)] = (ndp[mask | (1 << person)] + dp[mask]) % MOD;
}
}
}
dp = ndp;
}
return dp[(1 << n) - 1];
}
}
const MOD = 1e9 + 7;

/**
 * Counts the ways to give every person a distinct hat they like.
 *
 * @param {number[][]} hats hats[i] lists the hat labels person i is willing to wear.
 *   Labels are expected to be positive integers (the problem uses 1..40). Non-positive
 *   labels are ignored, and a label repeated in one person's list counts once. The number
 *   of people must stay small (the problem bounds it by 10) because the DP table has 2^n
 *   entries.
 * @return {number} Number of assignments in which each person wears exactly one liked hat
 *   and no hat is worn twice, modulo 10^9 + 7.
 */
var numberWays = function(hats) {
    const n = hats.length;

    // Size the hat table from the input instead of assuming labels never exceed 40.
    let maxHat = 0;
    for (const liked of hats) {
        for (const h of liked) {
            if (h > maxHat) maxHat = h;
        }
    }

    // Invert the preference lists: for each hat, who is willing to wear it.
    const hatToPeople = Array.from({length: maxHat + 1}, () => []);
    for (let i = 0; i < n; ++i) {
        for (const h of hats[i]) {
            if (h < 1) continue; // not a valid hat label
            const people = hatToPeople[h];
            // People are appended in increasing order, so a repeated label in person i's
            // list shows up as i already being the last entry.
            if (people.length === 0 || people[people.length - 1] !== i) people.push(i);
        }
    }

    const full = (1 << n) - 1;
    // dp[mask] = ways to hand out the hats processed so far so that exactly the
    // people whose bits are set in mask are wearing one.
    let dp = new Array(full + 1).fill(0);
    dp[0] = 1;
    for (let hat = 1; hat <= maxHat; ++hat) {
        const people = hatToPeople[hat];
        if (people.length === 0) continue; // nobody likes this hat: dp is unchanged
        const ndp = dp.slice(); // option: nobody wears this hat
        for (let mask = 0; mask <= full; ++mask) {
            if (dp[mask] === 0) continue;
            for (const person of people) {
                const bit = 1 << person;
                if ((mask & bit) !== 0) continue; // person already has a hat
                // Both operands are < MOD, so the sum stays exact in a double.
                ndp[mask | bit] = (ndp[mask | bit] + dp[mask]) % MOD;
            }
        }
        dp = ndp;
    }
    return dp[full];
};