@slaypni
Last active September 24, 2021 17:35
A scikit-learn-compatible wrapper class for XGBoost
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import math

import numpy as np

sys.path.append('xgboost/wrapper/')
import xgboost as xgb


class XGBoostClassifier():
    def __init__(self, num_boost_round=10, **params):
        self.clf = None
        self.num_boost_round = num_boost_round
        self.params = params
        # Multiclass probabilities; callers are expected to pass num_class
        # (and any other booster parameters) through **params.
        self.params.update({'objective': 'multi:softprob'})

    def fit(self, X, y, num_boost_round=None):
        num_boost_round = num_boost_round or self.num_boost_round
        # Map arbitrary class labels to contiguous integer ids for xgboost.
        self.label2num = dict((label, i) for i, label in enumerate(sorted(set(y))))
        dtrain = xgb.DMatrix(X, label=[self.label2num[label] for label in y])
        self.clf = xgb.train(params=self.params, dtrain=dtrain,
                             num_boost_round=num_boost_round)
        return self

    def predict(self, X):
        # Pick the most probable class and map its id back to the original label.
        num2label = dict((i, label) for label, i in self.label2num.items())
        Y = self.predict_proba(X)
        y = np.argmax(Y, axis=1)
        return np.array([num2label[i] for i in y])

    def predict_proba(self, X):
        dtest = xgb.DMatrix(X)
        return self.clf.predict(dtest)

    def score(self, X, y):
        # Higher is better: the reciprocal of the multiclass log loss.
        Y = self.predict_proba(X)
        return 1 / logloss(y, Y)

    def get_params(self, deep=True):
        return self.params

    def set_params(self, **params):
        if 'num_boost_round' in params:
            self.num_boost_round = params.pop('num_boost_round')
        if 'objective' in params:
            # The objective is fixed to multi:softprob and must not be overridden.
            del params['objective']
        self.params.update(params)
        return self


def logloss(y_true, Y_pred):
    label2num = dict((name, i) for i, name in enumerate(sorted(set(y_true))))
    return -1 * sum(math.log(y[label2num[label]]) if y[label2num[label]] > 0 else -np.inf
                    for y, label in zip(Y_pred, y_true)) / len(Y_pred)
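
For reference, a minimal usage sketch showing the wrapper driven by scikit-learn's GridSearchCV, which is what get_params/set_params are there for. The dataset, the parameter grid, and the num_class value below are illustrative assumptions, not part of the gist.

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)

# num_class is required by the multi:softprob objective; the other
# hyperparameters are illustrative defaults.
clf = XGBoostClassifier(num_class=3, eta=0.1, max_depth=4)

grid = GridSearchCV(clf, {'max_depth': [2, 4, 6], 'num_boost_round': [10, 20]})
grid.fit(X, y)
print(grid.best_params_, grid.best_score_)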
ghost commented Oct 20, 2015

TODO: regression
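
One possible direction for that TODO, kept close to the classifier above; this is only a sketch, and the reg:linear objective (reg:squarederror in newer xgboost releases) and the score definition are my assumptions, not anything from the gist. It reuses the np and xgb imports from the code above.

class XGBoostRegressor():
    def __init__(self, num_boost_round=10, **params):
        self.clf = None
        self.num_boost_round = num_boost_round
        self.params = params
        self.params.update({'objective': 'reg:linear'})  # assumed objective

    def fit(self, X, y, num_boost_round=None):
        num_boost_round = num_boost_round or self.num_boost_round
        dtrain = xgb.DMatrix(X, label=y)
        self.clf = xgb.train(params=self.params, dtrain=dtrain,
                             num_boost_round=num_boost_round)
        return self

    def predict(self, X):
        return self.clf.predict(xgb.DMatrix(X))

    def score(self, X, y):
        # Negative mean squared error, so that higher is better.
        return -float(np.mean((self.predict(X) - y) ** 2))

    def get_params(self, deep=True):
        return self.params

    def set_params(self, **params):
        if 'num_boost_round' in params:
            self.num_boost_round = params.pop('num_boost_round')
        self.params.update(params)
        return self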

waryak commented Oct 4, 2016

I'm sorry, but does this make it possible to get non-strict (probability) predictions from XGBoost?
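
For what it's worth, predict_proba returns the raw multi:softprob output, i.e. one probability per class and row. A small illustration, assuming X_train/y_train/X_test already exist and num_class matches the data:

clf = XGBoostClassifier(num_class=3, eta=0.1)
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)   # shape (n_samples, 3); each row sums to ~1
labels = clf.predict(X_test)        # argmax of proba, mapped back to the original labels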
