@pilotgsm
Created June 5, 2016 13:44
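
```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

# Load the training data: '_id' becomes the row index, 'gain' is the target.
train = pd.read_csv('train.csv').set_index('_id')
X = np.array(train.drop(columns=['gain']))
y = np.array(train['gain'])

# Fit a regression tree; max_depth=7 limits its complexity.
model = DecisionTreeRegressor(max_depth=7)
model.fit(X, y)

def get_prediction(row):
    # Turn one row into a 2-D array of shape (1, n_features) for predict().
    x = np.array(row).reshape(1, -1)
    prediction = model.predict(x)[0]
    return prediction

# Score every test row and write the predictions, keyed by '_id'.
test = pd.read_csv('test.csv').set_index('_id')
test['prediction'] = test.apply(get_prediction, axis=1)
test[['prediction']].to_csv('submission.csv')
```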
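
Calling `model.predict` once per row through `DataFrame.apply` works, but scikit-learn estimators also accept a whole 2-D feature matrix, so the test set can be scored in a single call. The minimal sketch below checks that the two approaches agree. It is an illustration under stated assumptions: small synthetic frames with hypothetical feature columns `f1`, `f2`, `f3` stand in for `train.csv` and `test.csv`, and `random_state=0` (not in the original) is set only so the tree is reproducible.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

# Hypothetical synthetic stand-ins for train.csv / test.csv:
# three numeric features plus a 'gain' target, indexed by '_id'.
rng = np.random.default_rng(0)
train = pd.DataFrame(rng.normal(size=(200, 3)), columns=['f1', 'f2', 'f3'])
train['gain'] = 2 * train['f1'] - train['f2'] + rng.normal(scale=0.1, size=200)
train.index.name = '_id'
test = pd.DataFrame(rng.normal(size=(50, 3)), columns=['f1', 'f2', 'f3'])
test.index.name = '_id'

# Select features by name so the column order is identical at fit and predict time.
feature_cols = train.columns.drop('gain')
model = DecisionTreeRegressor(max_depth=7, random_state=0)
model.fit(train[feature_cols].to_numpy(), train['gain'].to_numpy())

# Row-by-row scoring, as in the script: one predict() call per test row.
def get_prediction(row):
    return model.predict(np.array(row).reshape(1, -1))[0]

row_wise = test[feature_cols].apply(get_prediction, axis=1)

# Batch scoring: a single predict() call on the whole feature matrix.
batch = pd.Series(
    model.predict(test[feature_cols].to_numpy()),
    index=test.index,
    name='prediction',
)

# Both approaches should give the same predictions.
assert np.allclose(row_wise.to_numpy(), batch.to_numpy())

# to_csv() with no path returns the CSV text instead of writing a file.
csv_text = batch.to_frame().to_csv()
```

The batch call avoids building a one-row array for every test row, which matters once the test set is large. Selecting `feature_cols` by name also guards against the test file listing its columns in a different order than the training file.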