Skip to content

Instantly share code, notes, and snippets.

@agramfort
Forked from fabianp/gist:3097107
Created July 26, 2012 09:24
Show Gist options
  • Save agramfort/3181189 to your computer and use it in GitHub Desktop.
strong rules lasso
# -*- coding: utf-8 -*-
"""
Generalized linear models via coordinate descent
Author: Fabian Pedregosa <fabian@fseoane.net>
"""
import numpy as np
MAX_ITER = 100


def l1_coordinate_descent(X, y, alpha, warm_start=None, max_iter=MAX_ITER):
    """Solve the Lasso problem by cyclic coordinate descent.

    Minimizes ``0.5 * ||y - X beta||^2 + alpha * n_samples * ||beta||_1``
    (the loss is an unnormalized sum of squares, hence the rescaling of
    ``alpha`` by ``n_samples`` below).

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.
    y : ndarray, shape (n_samples,)
        Target vector.
    alpha : float
        Regularization strength (before n_samples rescaling).
    warm_start : ndarray, shape (n_features,), optional
        Initial coefficients. Copied, so the caller's array is never
        modified (the original implementation mutated it in place).
    max_iter : int
        Number of full coordinate sweeps.

    Returns
    -------
    beta : ndarray, shape (n_features,)
        Estimated coefficients.
    """
    if warm_start is not None:
        # Defensive copy: the in-place updates below must not clobber the
        # caller's warm-start array.
        beta = warm_start.copy()
    else:
        # `np.float` was removed in NumPy 1.24; the builtin is equivalent.
        beta = np.zeros(X.shape[1], dtype=float)
    alpha = alpha * X.shape[0]
    for _ in range(max_iter):
        for i in range(X.shape[1]):
            # Partial residual correlation for coordinate i: zero out
            # beta[i] and correlate feature i with what is left of y.
            bb = beta.copy()
            bb[i] = 0.
            residual = np.dot(X[:, i], y - np.dot(X, bb))
            # Soft-thresholding update, normalized by ||x_i||^2.
            beta[i] = np.sign(residual) * np.fmax(np.abs(residual) - alpha, 0) \
                / np.dot(X[:, i], X[:, i])
    return beta
def shrinkage(X, y, alpha, beta, active_set, max_iter):
    """Cyclic soft-thresholding sweeps restricted to ``active_set``.

    Performs ``max_iter`` passes of coordinate descent, but only over the
    indices listed in ``active_set``; every other coefficient is left at
    its value in ``beta``. Works on a copy, so ``beta`` is not modified.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    y : ndarray, shape (n_samples,)
    alpha : float
        Regularization strength (already scaled by n_samples).
    beta : ndarray, shape (n_features,)
        Starting coefficients.
    active_set : iterable of int
        Feature indices to update.
    max_iter : int
        Number of sweeps over the active set.

    Returns
    -------
    ndarray, shape (n_features,)
        Updated coefficient vector.
    """
    coef = beta.copy()
    for _ in range(max_iter):
        for j in active_set:
            # Correlate feature j with the partial residual (coef[j] zeroed).
            coef[j] = 0
            rho = np.dot(X[:, j], y - np.dot(X, coef).T)
            # Soft-threshold and normalize by the feature's squared norm.
            coef[j] = np.sign(rho) * np.fmax(np.abs(rho) - alpha, 0) \
                / np.dot(X[:, j], X[:, j])
    return coef
def l1_path(X, y, alphas, max_iter=MAX_ITER, verbose=False):
    """Compute the Lasso solution path over a sequence of alphas.

    The screening strategy follows "Strong rules for discarding predictors
    in lasso-type problems" (Tibshirani et al.): at each alpha, solve only
    on an active set, then verify the KKT conditions, first on the strong
    set and, if violated, on all predictors, enlarging the active set and
    re-solving as needed.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    y : ndarray, shape (n_samples,)
    alphas : sequence of float
        Regularization parameters; per the original docstring, must be an
        increasing sequence.
    max_iter : int
        Coordinate-descent sweeps per call to ``shrinkage``.
    verbose : bool
        Print active-set diagnostics when True.

    Returns
    -------
    beta : ndarray, shape (len(alphas), n_features)
        One coefficient vector per alpha.

    WARNING: does not compute an intercept.
    """
    # `np.float` was removed in NumPy 1.24; the builtin is equivalent.
    beta = np.zeros((len(alphas), X.shape[1]), dtype=float)
    # Loss is an unnormalized sum of squares, so rescale by n_samples
    # (consistent with l1_coordinate_descent).
    alphas_scaled = np.array(alphas).copy() * X.shape[0]
    active_set = np.arange(X.shape[1]).tolist()
    for k, a in enumerate(alphas_scaled):
        if verbose:
            print('Current active set ', active_set)
        if k > 0:
            # Strong rule: discard predictor j when
            # |x_j'(y - X beta_{k-1})| < 2*alpha_k - alpha_{k-1}.
            tmp = np.abs(np.dot(X.T, y - np.dot(X, beta[k - 1])))
            strong_active_set = tmp < 2 * alphas_scaled[k] - alphas_scaled[k - 1]
            strong_active_set = np.where(strong_active_set)[0]
        else:
            strong_active_set = np.arange(X.shape[1])
        # solve restricted to the current active set
        beta[k] = shrinkage(X, y, a, beta[k], active_set, max_iter)
        # .. check KKT on the strong active set ..
        # Fixed aggregation: the original's `else` clause set
        # kkt_violations=False based only on the last index examined;
        # a violation at ANY index must keep the flag raised.
        kkt_violations = False
        for i in strong_active_set:
            tmp = np.dot(X[:, i], y - np.dot(X, beta[k]))
            # NOTE(review): tmp is signed here; a strict KKT check would
            # compare abs(tmp) -- condition kept as in the original.
            if beta[k, i] != 0 and not np.allclose(tmp, np.abs(alphas_scaled[k])):
                active_set.append(i)
                kkt_violations = True
            elif beta[k, i] == 0 and abs(tmp) >= np.abs(alphas_scaled[k]):
                active_set.append(i)
                kkt_violations = True
        if not kkt_violations:
            # Passed KKT on the strong set: shrink the active set to the
            # support of the current solution and move to the next alpha.
            active_set = np.where(beta[k] != 0)[0].tolist()
            if verbose:
                print('No KKT violations on active set')
        else:
            if verbose:
                print('KKT violated on strong active set')
            # .. recompute with the enlarged active set ..
            beta[k] = shrinkage(X, y, a, beta[k], active_set, max_iter)
            # .. then check KKT over all predictors ..
            kkt_violations = False
            for i in range(X.shape[1]):
                tmp = np.dot(X[:, i], y - np.dot(X, beta[k]))
                # NOTE(review): exact float inequality kept from the
                # original; it is almost always True for nonzero coefs.
                if beta[k, i] != 0 and tmp != np.abs(alphas_scaled[k]):
                    active_set.append(i)
                    kkt_violations = True
                elif beta[k, i] == 0 and abs(tmp) >= np.abs(alphas_scaled[k]):
                    active_set.append(i)
                    kkt_violations = True
            if not kkt_violations:
                # Passed KKT for all variables: keep only the support.
                active_set = np.where(beta[k] != 0)[0].tolist()
            else:
                if verbose:
                    print('KKT violated on full active set')
                beta[k] = shrinkage(X, y, a, beta[k], active_set, max_iter)
    return beta
def check_kkt_lasso(xr, coef, penalty, tol=1e-3):
    """Check the Lasso KKT (optimality) conditions.

    Parameters
    ----------
    xr : ndarray
        Gradient correlations, i.e. ``X'(y - X coef)``.
    coef : ndarray
        Candidate coefficient vector.
    penalty : float
        L1 penalty strength.
    tol : float
        Tolerance on the stationarity condition for nonzero coefficients.

    Returns
    -------
    bool-like
        True when both KKT conditions hold.
    """
    nonzero_mask = coef != 0
    # Active coefficients: gradient must equal penalty * sign(coef).
    stationarity = np.all(
        np.abs(xr[nonzero_mask] - np.sign(coef[nonzero_mask]) * penalty) < tol)
    # Inactive coefficients: subgradient must stay inside [-penalty, penalty].
    subgradient = np.all(np.abs(xr[~nonzero_mask] / penalty) <= 1)
    return stationarity and subgradient
if __name__ == '__main__':
    # Demo: run the strong-rules path on the diabetes dataset and compare
    # against scikit-learn's reference lasso_path.
    np.random.seed(0)
    from sklearn import datasets
    diabetes = datasets.load_diabetes()
    X = diabetes.data
    y = diabetes.target
    alphas = np.linspace(.001, .1, 5)
    coefs_sr_ = l1_path(X, y, alphas, verbose=True, max_iter=1000)
    print(coefs_sr_)
    # compare with sklearn
    # NOTE(review): this targets an old scikit-learn API in which
    # lasso_path returned fitted model objects and accepted `normalize`;
    # recent versions return (alphas, coefs, dual_gaps) instead -- confirm
    # against the installed sklearn version before running.
    from sklearn.linear_model import lasso_path
    models = lasso_path(X, y, eps=0.001, tol=1e-8, alphas=alphas,
                        fit_intercept=False, normalize=False, copy_X=True)
    # Reversed to line up with our alphas ordering -- presumably sklearn
    # solves the path in the opposite (decreasing) alpha order.
    coefs_skl_ = np.array([m.coef_ for m in models])[::-1]
    print(coefs_sr_ - coefs_skl_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment