Created November 18, 2020 09:15
Save vittorio-nardone/716832a9abbafb7526ff863909c42119 to your computer and use it in GitHub Desktop.
Metaflow steps to perform hyperparameter tuning in Prophet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@step
def hyper_tuning(self):
    """
    Build the hyperparameter search space and fan out over it.

    Every combination of the candidate values below becomes one dict in
    ``self.all_params``. The flow then branches into one
    ``cross_validation`` task per dict via Metaflow's ``foreach``.
    """
    # Candidate values for each Prophet constructor argument.
    search_space = {
        'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5],
        'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0],
    }
    names = list(search_space)

    # Cartesian product of the candidate lists, following the key order above.
    combinations = []
    for combo in itertools.product(*search_space.values()):
        combinations.append(dict(zip(names, combo)))
    self.all_params = combinations

    # Launch one cross-validation branch per parameter dict.
    self.next(self.cross_validation, foreach='all_params')
@step
def cross_validation(self):
    """
    Score one hyperparameter combination with cross-validation.

    ``self.input`` is the parameter dict assigned to this foreach branch.
    A Prophet model built from it is fitted on ``self.df`` and
    cross-validated. Its RMSE is stored in ``self.rmses`` for the join step.
    """
    model = Prophet(**self.input)
    model.fit(self.df)

    # The bare name ``cross_validation`` below resolves to the module-level
    # function (presumably Prophet's diagnostics helper), not to this step:
    # names defined in a class body are not visible inside method bodies.
    cv_results = cross_validation(
        model,
        initial='730 days',
        period='180 days',
        horizon='365 days',
        parallel="processes",
    )

    # rolling_window=1 summarises over the whole horizon; only the first row is read.
    metrics = performance_metrics(cv_results, rolling_window=1)
    self.rmses = metrics['rmse'].values[0]

    self.next(self.train)
@step
def train(self, inputs):
    """
    Pick the best hyperparameters from the cross-validation branches and
    fit the final Prophet model with them.

    This is the join step of the ``cross_validation`` foreach. Each branch
    carries its own RMSE in ``rmses``; all other artifacts (``df``,
    ``all_params``, ...) are merged onto this step.

    Sets:
        self.hyperparameters: the entry of ``self.all_params`` whose branch
            reported the lowest non-NaN RMSE.
        self.m: a Prophet model built from those parameters and fitted on
            ``self.df``.

    Raises:
        ValueError: if every branch reported a NaN RMSE (from np.nanargmin).
    """
    # ``rmses`` differs between branches, so exclude it from the merge;
    # per Metaflow's docs, merge_artifacts rejects conflicting values.
    self.merge_artifacts(inputs, exclude=['rmses'])

    # One RMSE per branch. The loop variable is ``inp`` to avoid shadowing
    # the ``input`` builtin.
    rmses = [inp.rmses for inp in inputs]

    # NOTE(review): indexing all_params by branch position assumes ``inputs``
    # arrive in the same order as the foreach list; confirm against Metaflow.
    # nanargmin skips NaN scores. Plain argmin would return the index of the
    # first NaN and so select a failed configuration as "best".
    best_idx = int(np.nanargmin(rmses))
    self.hyperparameters = self.all_params[best_idx]

    # Refit on the full dataset using the winning parameters.
    self.m = Prophet(**self.hyperparameters)
    self.m.fit(self.df)
    self.next(self.end)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.