Skip to content

Instantly share code, notes, and snippets.

@jphall663
Last active August 9, 2018 16:25
Show Gist options
Save jphall663/5eba1f2d3051c8c4fb192438d6fd716e to your computer and use it in GitHub Desktop.
Grid search for H2O penalized GLM
def glm_grid(X, y, train, valid, family):
    """Train penalized GLMs over a grid of alpha values, with lambda search.

    One H2O GLM is trained per ``alpha`` in the grid. Each uses
    ``lambda_search=True`` so H2O also searches over penalty strengths.
    The grid table, the selected model and the first rows of validation
    predictions are printed. For ``family == 'gaussian'``, a ranked
    predictions plot is also drawn.

    :param X: List of input column names.
    :param y: Name of the target column.
    :param train: Training H2OFrame. NOTE: when ``family == 'binomial'`` the
        ``y`` column of this frame is converted to a factor in place.
    :param valid: Validation H2OFrame (same in-place conversion as ``train``).
    :param family: 'gaussian' for linear regression; 'binomial' for logistic.
    :return: The first model returned by ``grid.get_grid()``, treated as the
        best model (an H2OGeneralizedLinearEstimator).
    """
    import h2o
    from h2o.estimators.glm import H2OGeneralizedLinearEstimator
    from h2o.grid.grid_search import H2OGridSearch

    # Ensure a cluster connection before any frame operation below.
    h2o.init()

    if family == 'binomial':
        # Logistic regression needs a categorical target.
        train[y] = train[y].asfactor()
        valid[y] = valid[y].asfactor()

    # Every alpha is strictly below 1, so some L2 penalty is always kept.
    alpha_opts = [0.01, 0.25, 0.5, 0.99]
    hyper_parameters = {'alpha': alpha_opts}

    # Initialize the grid search; the fixed seed makes runs repeatable.
    grid = H2OGridSearch(
        H2OGeneralizedLinearEstimator(
            family=family,
            lambda_search=True,
            seed=12345),
        hyper_params=hyper_parameters)

    grid.train(y=y,
               x=X,
               training_frame=train,
               validation_frame=valid)

    # show() prints the grid table itself and returns None, so it is not
    # wrapped in print() (which would emit a stray "None").
    grid.show()

    # Select the best model.
    # NOTE(review): this takes the first model in get_grid()'s default
    # ordering; pass sort_by= explicitly if a specific metric must decide.
    best = grid.get_grid()[0]
    print(best)

    # Print the first rows of actual vs. predicted values on validation data.
    yhat_frame = valid.cbind(best.predict(valid))
    print(yhat_frame[0:10, [y, 'predict']])

    # For regression, plot actuals and predictions ordered by prediction.
    if family == 'gaussian':
        yhat_frame_df = yhat_frame[[y, 'predict']].as_data_frame()
        yhat_frame_df.sort_values(by='predict', inplace=True)
        yhat_frame_df.reset_index(inplace=True, drop=True)
        _ = yhat_frame_df.plot(title='Ranked Predictions Plot')

    return best
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment