Last active
August 9, 2018 16:25
-
-
Save jphall663/5eba1f2d3051c8c4fb192438d6fd716e to your computer and use it in GitHub Desktop.
Grid search for H2O penalized GLM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def glm_grid(X, y, train, valid, family):
    """Grid-search a penalized H2O GLM over alpha, with lambda search enabled.

    :param X: List of input column names.
    :param y: Name of the target column.
    :param train: Training H2OFrame.
    :param valid: Validation H2OFrame.
    :param family: 'gaussian' for linear regression; 'binomial' for logistic
        regression.
    :return: The top model of the grid (first entry of ``grid.get_grid()``),
        an H2OGeneralizedLinearEstimator.
    """
    import h2o
    from h2o.estimators.glm import H2OGeneralizedLinearEstimator
    from h2o.grid.grid_search import H2OGridSearch

    # Connect to (or start) the H2O cluster before any frame operations.
    h2o.init()

    if family == 'binomial':
        # A logistic GLM needs a categorical (enum) target.
        # NOTE: this reassigns the target column in the caller's frames.
        train[y] = train[y].asfactor()
        valid[y] = valid[y].asfactor()

    # alpha mixes the L1 and L2 penalties (alpha=1 would be pure L1), so every
    # option, including 0.99, keeps some L2. lambda is searched per alpha
    # because lambda_search=True.
    alpha_opts = [0.01, 0.25, 0.5, 0.99]
    hyper_parameters = {'alpha': alpha_opts}

    # Initialize and train the grid.
    grid = H2OGridSearch(
        H2OGeneralizedLinearEstimator(
            family=family,
            lambda_search=True,
            seed=12345),
        hyper_params=hyper_parameters)
    grid.train(y=y,
               x=X,
               training_frame=train,
               validation_frame=valid)

    # show() prints the summary itself and returns None, so it is not wrapped
    # in print() (which would emit a stray "None").
    grid.show()

    # NOTE(review): get_grid() is called without sort_by, so "best" relies on
    # H2O's default grid ordering. Consider an explicit sort_by on a
    # validation metric appropriate to `family`.
    best = grid.get_grid()[0]
    print(best)

    # Print the first 10 rows of actual vs. predicted values.
    yhat_frame = valid.cbind(best.predict(valid))
    print(yhat_frame[0:10, [y, 'predict']])

    # For regression, plot actuals and predictions ordered by prediction.
    if family == 'gaussian':
        yhat_frame_df = yhat_frame[[y, 'predict']].as_data_frame()
        yhat_frame_df.sort_values(by='predict', inplace=True)
        yhat_frame_df.reset_index(inplace=True, drop=True)
        _ = yhat_frame_df.plot(title='Ranked Predictions Plot')

    return best
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment