Created
September 29, 2022 08:46
-
-
Save NickCH-K/05883b134322b9a345600d0af4fb4f16 to your computer and use it in GitHub Desktop.
Runs an evolutionary algorithm to try to do a LASSO
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(data.table) | |
library(glmnet) | |
library(ggplot2) | |
# Probability that a newly bred child has one of its parameters re-drawn
# before it enters the population.
MUTATION_RATE = 0.5
# Simulate a data set from a polynomial in a standard-normal regressor.
#
# N:     number of observations to draw.
# truth: true polynomial coefficients, passed to generate_predictions().
#
# Returns a data.table with columns `x` and `y`, where `y` is the polynomial
# prediction at `x` plus standard-normal noise.
generate_random_data = function(N = 1000, truth = c(.5, .5, -.5, 0, 0, 0, 0, 0, .1)) {
  # Draw the regressor first and the noise second so the random-number
  # stream is consumed in a fixed order.
  regressor = rnorm(N)
  noise = rnorm(N)
  data.table(x = regressor, y = generate_predictions(regressor, truth) + noise)
}
# Evaluate a polynomial without intercept at each value of `x`.
#
# x:      numeric vector of regressor values.
# params: numeric vector of coefficients; params[i] multiplies x^i. Any
#         length is accepted, so the polynomial degree is length(params).
#
# Returns a numeric vector the same length as `x` holding
# sum_i params[i] * x^i.
generate_predictions = function(x, params) {
  # Column i of the design matrix holds x^i; a single matrix product then
  # yields the weighted row sums.
  powers = outer(x, seq_along(params), `^`)
  as.vector(powers %*% params)
}
# Score a coefficient vector on a freshly simulated data set.
#
# params: candidate polynomial coefficients.
# lambda: weight on the L1 penalty.
#
# Returns sqrt(sum of squared residuals + lambda * sum(|params|)). Lower is
# better. Because a new data set is drawn on every call, the score of a given
# `params` varies from call to call.
model_fitness = function(params, lambda = .1) {
  sample_dat = generate_random_data()
  fitted = generate_predictions(sample_dat$x, params)
  squared_error = sum((sample_dat$y - fitted)^2)
  l1_penalty = lambda * sum(abs(params))
  sqrt(squared_error + l1_penalty)
}
# Breed a child by taking each parameter from one of two parents at random.
#
# params1, params2: parent parameter vectors of equal length.
#
# Returns a vector of the same length in which element i equals either
# params1[i] or params2[i], each chosen with probability 1/2.
crossover = function(params1, params2) {
  # Mismatched parents would otherwise silently yield NA genes.
  stopifnot(length(params1) == length(params2))
  picks = sample(1:2, length(params1), replace = TRUE)
  # Start from the first parent and overwrite the positions that were
  # assigned to the second; this also handles zero-length input.
  child = params1
  from_second = picks == 2L
  child[from_second] = params2[from_second]
  child
}
# Mutate a parameter vector by replacing one randomly chosen element with a
# fresh standard-normal draw.
#
# params: numeric parameter vector.
#
# Returns `params` with exactly one position overwritten; an empty vector is
# returned unchanged.
mutate_params = function(params) {
  n_params = length(params)
  # Nothing to mutate; indexing with 1:0 here would grow the vector.
  if (n_params == 0) {
    return(params)
  }
  # Draw the replacement value before the position, matching the order in
  # which `params[sample(...)] = rnorm(1)` consumes the random stream.
  replacement = rnorm(1)
  target = sample.int(n_params, 1)
  params[target] = replacement
  params
}
# Create the starting population.
#
# Returns a list with two aligned elements:
#   params:  list of 10 parameter vectors, each 9 standard-normal draws;
#   fitness: numeric vector with the model_fitness() score of each member.
initialize = function() {
  starting_params = replicate(10, rnorm(9), simplify = FALSE)
  list(
    params = starting_params,
    fitness = vapply(starting_params, model_fitness, numeric(1))
  )
}
# Choose two distinct parents, favouring members with lower fitness scores.
#
# fits: numeric vector of fitness scores (lower is better).
#
# Returns two distinct indices into `fits`, sampled without replacement with
# probability proportional to 1 / fits.
parent_picks = function(fits) {
  inverse_fits = 1 / fits
  weights = inverse_fits / sum(inverse_fits)
  sample.int(length(weights), size = 2, prob = weights)
}
# Advance the population by one generation.
#
# pop: list with aligned elements `params` (list of parameter vectors) and
#      `fitness` (numeric vector of their scores, lower is better).
#
# One child is bred per population member. Each child is produced by
# crossover of two fitness-weighted parents, mutated with probability
# MUTATION_RATE, and then replaces the member that currently has the worst
# (largest) score. Returns the updated population with `params` and
# `fitness` still aligned.
generation = function(pop) {
  # Iterate over population members; length(pop) would count the two list
  # slots (`params`, `fitness`) rather than the individuals.
  for (i in seq_along(pop[['params']])) {
    parents = parent_picks(pop[['fitness']])
    child = crossover(pop[['params']][[parents[1]]], pop[['params']][[parents[2]]])
    if (runif(1) < MUTATION_RATE) {
      child = mutate_params(child)
    }
    # Write the child and its score to the same slot so the two vectors
    # never drift apart; which.max() yields a single index even on ties.
    worst = which.max(pop[['fitness']])
    pop[['params']][[worst]] = child
    pop[['fitness']][worst] = model_fitness(child)
  }
  pop
}
# Summarise the best member of a population as a one-row table.
#
# pop: list with aligned elements `params` and `fitness` (lower is better).
#
# Returns a one-row data.table with columns param_1, ..., param_k holding the
# best member's parameters, followed by `fitness` holding its score.
best_of_generation = function(pop) {
  # which.min() returns exactly one index, so tied scores cannot turn the
  # `[[` lookup below into a multi-index subscript.
  thebest = which.min(pop[['fitness']])
  best_params = pop[['params']][[thebest]]
  dat = as.data.table(matrix(c(best_params, pop[['fitness']][thebest]), nrow = 1))
  # Name columns from the actual parameter count rather than assuming 9.
  setnames(dat, c(paste0('param_', seq_along(best_params)), 'fitness'))
  dat
}
# Run the evolutionary algorithm and record the best member of every
# generation. Generation 1 is the random starting population.
N_GENERATIONS = 10000
pop = initialize()
# Collect one row per generation in a preallocated list and bind once at the
# end; growing a table with rbind() inside the loop copies it every iteration.
history = vector('list', N_GENERATIONS)
history[[1]] = best_of_generation(pop)
for (i in seq_len(N_GENERATIONS)[-1]) {
  print(paste0('Generation ', i, ' at ', Sys.time()))
  pop = generation(pop)
  history[[i]] = best_of_generation(pop)
  print(paste0('Best fitness: ', history[[i]][['fitness']]))
}
results = rbindlist(history)
results[, Generation := seq_len(.N)]
# And compare to the average objective of 1000 lassos | |
# Fit one LASSO on a fresh simulated data set and score its coefficients
# with the same objective used for the evolutionary algorithm.
#
# iter: iteration number, used only in the progress message.
#
# Returns a one-row data.table with columns param_1, ..., param_9 holding the
# glmnet slope coefficients, followed by `fitness` holding their
# model_fitness() score.
# NOTE(review): glmnet is called with its default intercept setting, while
# only m$beta is scored — confirm whether that is intended.
run_lasso = function(iter) {
  print(paste0('Lasso ', iter, ' at ', Sys.time()))
  dat = generate_random_data()
  powers = paste0('x_', 2:9)
  dat[, (powers) := lapply(2:9, \(i) x^i)]
  # glmnet documents `x` as a matrix, so build one explicitly from the
  # regressor columns, selected by name in polynomial order.
  design = as.matrix(dat[, c('x', powers), with = FALSE])
  m = glmnet(design, dat$y, lambda = .1)
  betas = as.vector(m$beta)
  out = as.data.table(matrix(c(betas, model_fitness(betas)), nrow = 1))
  setnames(out, c(paste0('param_', 1:9), 'fitness'))
  out
}
# Benchmark: the mean objective value over 1000 independent LASSO fits.
lasso_results = rbindlist(lapply(seq_len(1000), run_lasso))
av_lasso = mean(lasso_results[['fitness']])

# Plot the best score of each generation against the LASSO average (dashed
# line), labelling the LASSO average and the final generation's score.
ggplot(results, aes(x = Generation, y = fitness)) +
  geom_line() +
  geom_hline(yintercept = av_lasso, linetype = 'dashed') +
  annotate(
    geom = 'label',
    x = 1,
    y = av_lasso,
    hjust = 0,
    label = paste(scales::number(av_lasso), 'LASSO average'),
    family = 'serif',
    size = 13/.pt
  ) +
  annotate(
    geom = 'text',
    x = max(results[['Generation']]),
    y = results[['fitness']][nrow(results)],
    hjust = 1,
    vjust = -1,
    family = 'serif',
    size = 13/.pt,
    label = scales::number(results[['fitness']][nrow(results)], big.mark = ',')
  ) +
  theme_classic() +
  theme(text = element_text(family = 'serif', size = 13)) +
  labs(
    y = 'Inverse Objective Function',
    title = 'Evolutionary Algorithm Best, vs. LASSO Average'
  )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment