Created
April 21, 2021 12:31
-
-
Save montali/eee9a607d98622a453bb042361ef8938 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class InitialPointShapeException(Exception):
    """Signals that an initial simplex point does not have the expected number of dimensions."""
class NoSimplexDefinedException(Exception):
    """Signals that fitting was requested before any simplex was initialized."""
class NelderMead:
    """Nelder-Mead simplex minimizer whose final points are rescaled to a fixed coordinate sum.

    The search itself is unconstrained apart from clipping accepted points to be
    non-negative; only after :meth:`fit` stops iterating are the simplex points
    rescaled (see :meth:`fix`) so that each one's coordinates add up to
    ``sum_constraint``.

    Args:
        n (int): number of dimensions of the search space.
        fn (callable): objective to minimize. It is called both with a single
            point (1-D array of length ``n``) and with the whole simplex
            transposed (shape ``(n, n + 1)``), so it should read coordinates as
            ``x[0]``, ``x[1]``, ... and operate elementwise.
        sum_constraint (float): target sum of each point's coordinates after fitting.
        reflection_parameter (float): reflection coefficient.
        expansion_parameter (float): expansion coefficient.
        contraction_parameter (float): contraction coefficient.
        shrinkage_parameter (float): shrink coefficient.
    """

    def __init__(self, n, fn, sum_constraint, reflection_parameter=1,
                 expansion_parameter=2, contraction_parameter=0.5,
                 shrinkage_parameter=0.5):
        self.reflection_parameter = reflection_parameter
        self.expansion_parameter = expansion_parameter
        self.contraction_parameter = contraction_parameter
        self.shrinkage_parameter = shrinkage_parameter
        self.n = n
        self.fn = fn
        self.sum_constraint = sum_constraint
        # Set by initialize_simplex(); fit() refuses to run while this is None.
        self.simplex_points = None

    def initialize_simplex(self, x_1=None):
        """Initializes the first simplex to begin iterations.

        Args:
            x_1 (array-like, optional): first point of the simplex. Defaults to
                None, which draws a random point uniformly from [0, 1)^n.

        Raises:
            InitialPointShapeException: the provided point does not have shape (n,).
        """
        # Validate before touching self.simplex_points so a bad call leaves the
        # previous simplex (or None) in place.
        if x_1 is not None:
            x_1 = np.asarray(x_1, dtype=float)
            if x_1.shape != (self.n,):
                raise InitialPointShapeException(
                    f"Please enter an initial point having {self.n} dimensions.")

        self.simplex_points = np.empty((self.n + 1, self.n))
        self.simplex_points[0] = np.random.rand(self.n) if x_1 is None else x_1

        # Remaining n points: shift the first point along each axis in turn.
        # A zero coordinate gets a smaller shift.
        base = self.simplex_points[0]
        shifts = np.where(base != 0, 0.05, 0.0025)
        self.simplex_points[1:] = base + np.diag(shifts)
        print(f"Succesfully initialized first simplex: {self.simplex_points}")

    def sort(self):
        """Evaluates fn on every simplex point and ranks the points.

        Side effects: stores the values in ``self.simplex_vals`` and the lowest
        value in ``self.min``.

        Returns:
            tuple: indices of the best (lowest), second-worst and worst (highest) points.
        """
        self.simplex_vals = np.array(self.fn(self.simplex_points.transpose()))
        sorted_indices = np.argsort(self.simplex_vals)
        self.min = self.simplex_vals[sorted_indices[0]]
        return sorted_indices[0], sorted_indices[-2], sorted_indices[-1]

    def iterate(self):
        """Performs one Nelder-Mead step, replacing the worst point or shrinking the simplex.

        Tries reflection, then expansion or contraction, and falls back to
        shrinking every point towards the best one. Points accepted from
        reflection, expansion or contraction are clipped to be non-negative.
        """
        best, sec_worst, worst = self.sort()
        f_best = self.simplex_vals[best]
        f_sec_worst = self.simplex_vals[sec_worst]
        f_worst = self.simplex_vals[worst]
        x_worst = self.simplex_points[worst].copy()

        # Centroid of every point except the worst; axis=0 removes a whole row
        # instead of a single element of the flattened array.
        centroid = np.mean(np.delete(self.simplex_points, worst, axis=0), axis=0)

        # Reflection
        x_reflected = centroid + self.reflection_parameter * (centroid - x_worst)
        y_reflected = self.fn(x_reflected)

        # Reflected point lies between best and second worst: accept it.
        if f_best <= y_reflected <= f_sec_worst:
            # Clip each coordinate: we don't want negative points.
            self.simplex_points[worst] = np.maximum(x_reflected, 0)
            print("✨ Reflected ✨")
            return

        # Better than the best: try going further in the same direction.
        if y_reflected < f_best:
            x_expanded = centroid + self.expansion_parameter * (x_reflected - centroid)
            y_expanded = self.fn(x_expanded)
            if y_expanded < y_reflected:
                self.simplex_points[worst] = np.maximum(x_expanded, 0)
                print("✨ Tried expansion and it worked! ✨")
            else:
                self.simplex_points[worst] = np.maximum(x_reflected, 0)
                print("✨ Tried expansion but reflection was better ✨")
            return

        # Worse than the second worst: contract the worst point towards the centroid.
        x_contracted = centroid + self.contraction_parameter * (x_worst - centroid)
        y_contracted = self.fn(x_contracted)
        if y_contracted < f_worst:
            self.simplex_points[worst] = np.maximum(x_contracted, 0)
            print("✨ Contracted ✨")
            return

        # Last resort: shrink every point towards the best one. The best row is
        # mapped onto itself, since (x_best - x_best) is zero.
        x_best = self.simplex_points[best].copy()
        self.simplex_points = x_best + self.shrinkage_parameter * (self.simplex_points - x_best)
        print("✨ Shrinked ✨")

    def fix(self):
        """Rescales every simplex point so its coordinates sum to sum_constraint.

        NOTE(review): a point whose coordinates sum to 0 yields inf/NaN here.
        """
        row_sums = np.sum(self.simplex_points, axis=1, keepdims=True)
        self.simplex_points = (self.simplex_points / row_sums) * self.sum_constraint

    def fit(self, target_stddev, max_iterations=50):
        """Iterates until the spread of function values over the simplex is small enough.

        Args:
            target_stddev (float): stop once the standard deviation of the
                simplex values is at or below this value.
            max_iterations (int, optional): hard cap on iterations. Defaults to 50.

        Returns:
            np.ndarray: the best (lowest-value) simplex point after rescaling with :meth:`fix`.

        Raises:
            NoSimplexDefinedException: initialize_simplex() has not been called.
        """
        # getattr: also handles instances built without __init__.
        if not isinstance(getattr(self, "simplex_points", None), np.ndarray):
            raise NoSimplexDefinedException
        self.sort()
        std_dev = np.std(self.simplex_vals)
        print(std_dev)
        i = 0
        while std_dev > target_stddev and i < max_iterations:
            self.iterate()
            # Re-evaluate so the stop test and the log reflect the updated simplex.
            self.sort()
            std_dev = np.std(self.simplex_vals)
            print(
                f"🚀 Performing iteration {i}\t🥴 Standard deviation={round(std_dev, 2)}\t🏅 Value={round(self.min, 3)}")
            i += 1
        self.fix()
        best, _, _ = self.sort()
        return self.simplex_points[best]
if __name__ == '__main__':
    def fn(x):
        # Booth function; not used by the run below.
        return (x[0] + 2 * x[1] - 7) ** 2 + (2 * x[0] + x[1] - 5) ** 2

    def fn2(x):
        # NOTE(review): x[2] is never read although n=6 below — confirm this is intended.
        return x[0] ** 2 + x[1] ** 2 + 4 * x[3] ** 2 - x[4] ** 2 - x[5] ** 4

    optimizer = NelderMead(
        6,
        fn2,
        1,
        reflection_parameter=4,
        expansion_parameter=4,
        contraction_parameter=0.05,
        shrinkage_parameter=0.05,
    )
    optimizer.initialize_simplex()
    print(optimizer.fit(0.00001))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment