Created April 22, 2017 12:04
-
-
Save slchangtw/c08a4969cf640c60b7ade23ce03f6a5d to your computer and use it in GitHub Desktop.
This example shows how AdaBoost works by running three boosting rounds by hand on a 10-sample toy data set.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
def get_alpha(error, eps=1e-10):
    """Return the AdaBoost weight (alpha) of a weak classifier.

    Computes ``alpha = 0.5 * ln((1 - error) / error)``. A classifier whose
    weighted error is below 0.5 gets a positive weight. One above 0.5 gets
    a negative weight, and exactly 0.5 gives 0.

    Parameters
    ----------
    error : float or array_like
        Weighted misclassification rate(s) of the weak classifier, in [0, 1].
    eps : float, optional
        The error is clipped to ``[eps, 1 - eps]`` before use. A perfect
        (error 0) or perfectly wrong (error 1) classifier then yields a
        large but finite weight, instead of dividing by zero or taking
        ``log(0)``.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        The classifier weight(s).
    """
    # Guard the two singular endpoints of the formula.
    error = np.clip(error, eps, 1 - eps)
    return 0.5 * np.log((1 - error) / error)
def update_weights(y, weights, g, alpha):
    """Return the sample weights for the next AdaBoost round.

    Each weight is multiplied by ``exp(-alpha * g * y)``. With +1/-1 labels
    and predictions and ``alpha > 0``, correctly classified samples
    (``g * y == 1``) are scaled down and misclassified ones (``g * y == -1``)
    are scaled up. The result is normalized to sum to 1.

    Parameters
    ----------
    y : numpy.ndarray
        True labels, +1/-1.
    weights : numpy.ndarray
        Current sample weights.
    g : numpy.ndarray
        Predictions of the current weak classifier, +1/-1.
    alpha : float
        Weight of the current weak classifier.

    Returns
    -------
    numpy.ndarray
        New sample weights, normalized to sum to 1.
    """
    # Compute the unnormalized weights once rather than twice.
    unnormalized = weights * np.exp(-alpha * g * y)
    # np.sum reduces in C; the builtin sum would iterate element by element.
    return unnormalized / np.sum(unnormalized)
# Toy data set: 1 feature, 10 samples.
X = np.arange(10)
# Labels, +1/-1.
y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])

# Round 1: start with equal weights for every sample.
weights1 = np.full(X.shape[0], 1.0 / X.shape[0], dtype=np.float64)
# When X < 2.5, we get the smallest weighted error.
g1 = np.where(X < 2.5, 1, -1)
# The weighted error must use this round's weights (weights1).
e1 = np.sum((y != g1) * weights1)
alpha1 = get_alpha(e1)

# Round 2: adjust weights according to g1's mistakes, then fit the next stump.
weights2 = update_weights(y, weights1, g1, alpha1)
g2 = np.where(X < 8.5, 1, -1)
e2 = np.sum((y != g2) * weights2)
alpha2 = get_alpha(e2)

# Round 3: adjust weights according to g2's mistakes, then fit the next stump.
weights3 = update_weights(y, weights2, g2, alpha2)
g3 = np.where(X > 5.5, 1, -1)
e3 = np.sum((y != g3) * weights3)
alpha3 = get_alpha(e3)

# Final model: alpha-weighted vote of the three stumps.
f = alpha1 * g1 + alpha2 * g2 + alpha3 * g3
# Element-wise check of the ensemble's predictions against the labels.
# Printed, because a bare expression statement is discarded in a script.
print(np.sign(f) == y)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment