Last active
May 29, 2024 16:18
-
-
Save theptrk/47bb9bff60ef411f3929cd50a7decdcb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example of a simple nurse scheduling problem.""" | |
from ortools.sat.python import cp_model | |
class Nurse:
    """A nurse and their per-day availability.

    Attributes:
        name: Identifier of the nurse; used as part of the shift-variable key.
        start: One availability marker per day, indexed by day. Elsewhere in
            this file only the value "on" is treated as available.
    """

    def __init__(self, name: str, start: "list[str]") -> None:
        self.name = name
        self.start = start
class NursesPartialSolutionPrinter(cp_model.CpSolverSolutionCallback):
    """Solution callback that prints each schedule and stops at a limit."""

    def __init__(self, shifts, all_nurses, all_days, all_shift_types, limit):
        """Store the model data needed to print a solution.

        Args:
            shifts: Mapping from (nurse name, day, shift type) to the
                decision variable of that assignment.
            all_nurses: Iterable of nurse objects exposing a ``name``.
            all_days: Iterable of days to print.
            all_shift_types: Iterable of shift type names.
            limit: Number of solutions after which the search is stopped.
        """
        cp_model.CpSolverSolutionCallback.__init__(self)
        self._solution_limit = limit
        self._solution_count = 0
        self._shifts = shifts
        self._all_nurses = all_nurses
        self._all_days = all_days
        self._all_shift_types = all_shift_types

    def _print_nurse_day(self, name, day):
        """Print what nurse ``name`` does on ``day`` in the current solution."""
        worked = [
            shift
            for shift in self._all_shift_types
            if self.value(self._shifts[(name, day, shift)])
        ]
        if not worked:
            print(f" Nurse {name} does not work")
            return
        for shift in worked:
            print(f" Nurse {name} works shift {shift}")

    def on_solution_callback(self):
        """Print the solution just found; stop the search once at the limit."""
        self._solution_count += 1
        print(f"Solution {self._solution_count}")
        for day in self._all_days:
            print(f"Day {day}")
            for nurse in self._all_nurses:
                self._print_nurse_day(nurse.name, day)
        if self._solution_count < self._solution_limit:
            return
        print(f"Stop search after {self._solution_limit} solutions")
        self.stop_search()

    def solutionCount(self):
        """Return how many solutions the callback has received so far."""
        return self._solution_count
class Modeler:
    """Build the CP-SAT model for the nurse scheduling example.

    Construction creates one Boolean variable per (nurse name, day, shift
    type), registers every constraint and sets the objective. The finished
    model is available as ``self.model`` and the variables as ``self.shifts``.

    Conventions the methods rely on:
      * ``nurse.start[d]`` is the nurse's availability on day ``d``; only the
        value "on" permits an assignment.
      * ``coverage_needed[s][d] == 1`` means shift ``s`` wants a nurse on
        day ``d``; any other value means the shift must stay empty.
      * A shift type whose name starts with "E" is treated as an evening
        shift; every other shift type is treated as a day shift.
      * Days are used as 0-based indices, so ``d + 1`` is the following day.
    """

    def __init__(self, all_nurses, all_days, all_shift_types, coverage_needed):
        """Create the model, its variables, constraints and objective.

        Args:
            all_nurses: Iterable of nurse objects with ``name`` and ``start``.
            all_days: Sized iterable of day indices.
            all_shift_types: Iterable of shift type names.
            coverage_needed: Mapping shift type -> per-day list of 0/1 flags.
        """
        self.model = cp_model.CpModel()
        self.all_nurses = all_nurses
        self.all_days = all_days
        self.all_shift_types = all_shift_types
        self.coverage_needed = coverage_needed
        self.shifts = {}

        self.create_shift_variables()
        self.enf_employee_availability()
        self.enf_coverage_and_one_nurse()
        self.enf_single_shift_assigned_per_day()
        self.enf_preventing_eve_to_day()
        self.enf_preventing_eve_to_on_unassigned_to_day()
        self.maximum_assigned_shifts()

    def create_shift_variables(self):
        """Create shifts[(n, d, s)]: nurse 'n' works shift 's' on day 'd'."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                for s in self.all_shift_types:
                    self.shifts[(n, d, s)] = self.model.new_bool_var(
                        f"shift_n{n}_d{d}_s{s}"
                    )

    def enf_employee_availability(self):
        """Enforce availability: force (n, d, s) to 0 unless the nurse is "on"."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                if nurse_obj.start[d] == "on":
                    continue
                for s in self.all_shift_types:
                    self.model.add(self.shifts[(n, d, s)] == 0)

    def enf_coverage_and_one_nurse(self):
        """Enforce coverage: at most one nurse per needed shift, none otherwise."""
        for d in self.all_days:
            for s in self.all_shift_types:
                every_shift_for_d_s = [
                    self.shifts[(nurse_obj.name, d, s)]
                    for nurse_obj in self.all_nurses
                ]
                if self.coverage_needed[s][d] == 1:
                    # At most one (not exactly one): the shift may be
                    # impossible to fill, and the objective rewards filling it.
                    self.model.add_at_most_one(every_shift_for_d_s)
                else:
                    # The shift is not needed, so nobody may be assigned.
                    self.model.add(sum(every_shift_for_d_s) == 0)

    def enf_single_shift_assigned_per_day(self):
        """Enforce no double booking: at most one shift per nurse per day."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                self.model.add_at_most_one(
                    [self.shifts[(n, d, s)] for s in self.all_shift_types]
                )

    def enf_preventing_eve_to_day(self):
        """Forbid an evening shift followed by a day shift the next day."""
        evening_types = [s for s in self.all_shift_types if s.startswith("E")]
        day_types = [s for s in self.all_shift_types if not s.startswith("E")]
        if not day_types:
            return
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            # Skip the last day as it doesn't have a 'next day'.
            for d in range(len(self.all_days) - 1):
                tomorrows_day_shifts = [
                    self.shifts[(n, d + 1, s_next)] for s_next in day_types
                ]
                for s in evening_types:
                    self.model.add(
                        sum(tomorrows_day_shifts) == 0
                    ).only_enforce_if(self.shifts[(n, d, s)])

    def enf_preventing_eve_to_on_unassigned_to_day(self):
        """Forbid evening shift -> "on" but unassigned day -> day shift.

        When a nurse works an evening shift on day d and is available ("on")
        on day d + 1, a day shift on d + 2 is only allowed if the nurse also
        works some shift on d + 1. Expressed linearly as
        ``sum(day shifts on d + 2) <= sum(all shifts on d + 1)``.

        The previous formulation required both sums to be zero, which forced
        every nurse to be idle the day after an evening shift, and it tested
        the availability of day d instead of day d + 1.
        """
        evening_types = [s for s in self.all_shift_types if s.startswith("E")]
        day_types = [s for s in self.all_shift_types if not s.startswith("E")]
        if not day_types:
            return
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            # The pattern spans three days, so the last two days cannot
            # start one.
            for d in range(len(self.all_days) - 2):
                # The rule targets a middle day on which the nurse is
                # available; an "off" day is not "on and unassigned".
                if nurse_obj.start[d + 1] != "on":
                    continue
                tomorrows_shifts = [
                    self.shifts[(n, d + 1, s_tom)]
                    for s_tom in self.all_shift_types
                ]
                day_after_day_shifts = [
                    self.shifts[(n, d + 2, s_next)] for s_next in day_types
                ]
                for s in evening_types:
                    self.model.add(
                        sum(day_after_day_shifts) <= sum(tomorrows_shifts)
                    ).only_enforce_if(self.shifts[(n, d, s)])

    def maximum_assigned_shifts(self):
        """Set the objective: maximize the number of needed shifts filled."""
        needed_assignments = [
            self.shifts[(nurse_obj.name, d, s)]
            for d in self.all_days
            for s in self.all_shift_types
            if self.coverage_needed[s][d] == 1
            for nurse_obj in self.all_nurses
        ]
        self.model.maximize(sum(needed_assignments))
def get_data():
    """Return the hard-coded problem instance.

    Returns:
        A tuple ``(all_nurses, all_days, all_shift_types, coverage_needed)``:
        the list of nurses, ``range(num_days)``, the list of shift type names
        (the keys of ``coverage_needed``) and the mapping from shift type to
        its per-day list of 0/1 coverage flags.

    Raises:
        ValueError: If a nurse's availability list or a coverage list does
            not have exactly one entry per day.
    """
    num_days = 3

    # Alternative data set (test case 1):
    #   all_nurses = [Nurse("Alice", ["on", "on", "on"])]
    #   coverage_needed = {"A": [0, 1, 1], "EVE-A": [1, 0, 0]}

    # Test case 2:
    all_nurses = [
        Nurse("Alice", ["on", "on", "on"]),
        Nurse("Bobby", ["on", "off", "on"]),
        Nurse("Cindy", ["on", "on", "off"]),
    ]
    coverage_needed = {"A": [1, 1, 1], "EVE-A": [1, 1, 1], "B": [1, 1, 1]}

    # Both structures are indexed by day later on, so a wrong length must be
    # reported explicitly; an assert would be stripped under ``python -O``.
    for n_obj in all_nurses:
        if len(n_obj.start) != num_days:
            raise ValueError(
                f"Nurse {n_obj.name} has {len(n_obj.start)} availability "
                f"entries, expected {num_days}"
            )
    for shift_type, needs in coverage_needed.items():
        if len(needs) != num_days:
            raise ValueError(
                f"Shift {shift_type} has {len(needs)} coverage entries, "
                f"expected {num_days}"
            )

    all_days = range(num_days)
    # All shift types that exist in a given day.
    all_shift_types = list(coverage_needed)
    return all_nurses, all_days, all_shift_types, coverage_needed
def main() -> None:
    """Build the scheduling model, solve it and print solutions and statistics."""
    all_nurses, all_days, all_shift_types, coverage_needed = get_data()
    modeler = Modeler(all_nurses, all_days, all_shift_types, coverage_needed)

    # Create the solver and configure it.
    solver = cp_model.CpSolver()
    solver.parameters.linearization_level = 0
    # NOTE(review): CP-SAT documents enumerate_all_solutions for models
    # without an objective; this model has one - confirm it is still wanted.
    solver.parameters.enumerate_all_solutions = True

    # Print at most this many solutions, then stop the search.
    solution_limit = 3
    solution_printer = NursesPartialSolutionPrinter(
        modeler.shifts, all_nurses, all_days, all_shift_types, solution_limit
    )
    status = solver.solve(modeler.model, solution_printer)

    # Statistics.
    print("\nStatistics")
    print(f"Solve status: {solver.status_name(status)}")
    if status == cp_model.OPTIMAL:
        print(f"Maximum of objective function: {solver.objective_value}\n")
    elif status == cp_model.FEASIBLE:
        # The callback may stop the search before optimality is proven, so
        # the value is only the best one found so far.
        print(f"Best objective value found: {solver.objective_value}\n")
    else:
        print("No solution found\n")
    print(f" - conflicts : {solver.num_conflicts}")
    print(f" - branches : {solver.num_branches}")
    print(f" - wall time : {solver.wall_time} s")
    print(f" - solutions found: {solution_printer.solutionCount()}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment