Last active
March 6, 2019 21:31
-
-
Save nathansutton/b8a34b6c9f310a7a3a299ae9b4b9d226 to your computer and use it in GitHub Desktop.
Translate an H2O model into SQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{python} | |
def translate(fit):
    """Return the SQL translation of an H2OGeneralizedLinearEstimator.

    The fitted coefficients are read from ``fit.coef()`` and rendered as a
    sum of ``(term * beta)`` products plus the bare ``(beta)`` of the
    ``Intercept`` entry.  For a ``binomial`` model the sum is wrapped in the
    logistic function so the expression yields a probability; for a
    ``gaussian`` model the sum itself is returned.

    Term names are emitted verbatim and are not quoted or escaped, so they
    must already be valid column identifiers in the target SQL dialect.

    Args:
        fit: a fitted ``h2o.estimators.glm.H2OGeneralizedLinearEstimator``
            (or a subclass) exposing ``family`` and ``coef()``.

    Returns:
        A string holding a single SQL scalar expression.

    Raises:
        TypeError: if ``fit`` is not an H2OGeneralizedLinearEstimator.
        ValueError: if the model family is not ``binomial`` or ``gaussian``,
            or if the model exposes no coefficients.
    """
    # Identify the estimator by its fully qualified class name so that h2o
    # does not have to be imported here; walking the MRO also accepts
    # subclasses of the estimator.
    glm_class = "h2o.estimators.glm.H2OGeneralizedLinearEstimator"
    lineage = {
        "{}.{}".format(
            getattr(klass, "__module__", ""),
            getattr(klass, "__qualname__", ""),
        )
        for klass in type(fit).__mro__
    }
    if glm_class not in lineage:
        raise TypeError("Only suitable H2OGeneralizedLinearEstimator")

    # Only families whose link function is translated below are supported.
    if fit.family not in ("binomial", "gaussian"):
        raise ValueError("Only suitable for logistic / linear regression")

    # extract coefficients
    coefficients = fit.coef()
    if not coefficients:
        raise ValueError("model has no coefficients")

    # translate each coefficient into (term * beta); the intercept has no
    # column to multiply by, so it contributes (beta) alone
    pieces = []
    for term, beta in coefficients.items():
        if term == "Intercept":
            pieces.append("(" + str(beta) + ")")
        else:
            pieces.append("(" + str(term) + " * " + str(beta) + ")")
    linear = " + ".join(pieces)

    # a classification model
    if fit.family == "binomial":
        # logistic function: 1 - 1 / (1 + e^x) == 1 / (1 + e^-x)
        return "1.0 - 1.0 / (1.0 + EXP(" + linear + "))"

    # otherwise a regression model
    return linear
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment