Skip to content

Instantly share code, notes, and snippets.

@myrddian
Last active July 19, 2019 04:53
Show Gist options
  • Save myrddian/b08fdccb8d38d73d928b50cf4e3eb1ad to your computer and use it in GitHub Desktop.
Save myrddian/b08fdccb8d38d73d928b50cf4e3eb1ad to your computer and use it in GitHub Desktop.
library(MCMCpack)
# Example configuration ----
# Prior for the model mean, in the list(mean, sd) shape the functions in
# this file expect for their prior/target model arguments.
example_prior <- list(mean = 3, sd = 2)
# Inverse-gamma hyperparameters (alpha, beta) consumed when drawing a new sd.
hyper_parameters <- list(alpha = 1, beta = 1)
# Synthetic observations: 100 draws from a normal with mean 6 and sd 3.
data_samples <- rnorm(n = 100, mean = 6, sd = 3)
# 100 draws from a normal with mean 20 and sd 5.
# NOTE(review): not used by the visible code; confirm it is still needed.
random_samples <- rnorm(n = 100, mean = 20, sd = 5)
# Normal density of `target` under a candidate model.
#
# target_model: list with `mean` and `sd` elements.
# target: value(s) at which to evaluate the density.
# Returns dnorm(target, target_model$mean, target_model$sd).
find_likelyhood <- function(target_model, target) {
  mu <- target_model$mean
  sigma <- target_model$sd
  dnorm(target, mean = mu, sd = sigma)
}
# Unnormalised log-posterior of a normal model given observed data.
#
# Adds the log-density of every observation under `target_model` to the
# log-density of `target_model$mean` under the normal prior `prior_model`.
# There is no prior term for the model sd.
#
# data_samples: numeric vector of observations (may be empty, in which case
#   only the prior term is returned).
# prior_model: list(mean, sd) describing the normal prior on the mean.
# target_model: list(mean, sd) candidate parameters to score.
# Returns a single numeric value (NaN, with a warning, if an sd is negative).
parameter_log_likelyhood <- function(data_samples, prior_model, target_model) {
  # log = TRUE evaluates the log-density directly, so observations far in the
  # tails give a large negative finite value instead of dnorm() underflowing
  # to 0 and log(0) becoming -Inf.
  log_prior <- dnorm(
    target_model$mean, prior_model$mean, prior_model$sd,
    log = TRUE
  )
  # Vectorised over all observations; sum() of an empty vector is 0.
  log_likelihood <- sum(dnorm(
    data_samples, target_model$mean, target_model$sd,
    log = TRUE
  ))
  log_likelihood + log_prior
}
# Draw a new sd for a normal model given its current mean, then shift the
# mean by the drawn sd.
#
# With an inverse-gamma(alpha, beta) prior on the variance and n
# observations x, the conditional for the variance given the mean is
# inverse-gamma(alpha + n / 2, beta + sum((x - mean)^2) / 2).
#
# data_samples: numeric vector of observations (may be empty).
# target_model: list(mean, sd); `mean` is read, `sd` is overwritten.
# hyper_parameter: list(alpha, beta) inverse-gamma hyperparameters.
# Returns target_model with `sd` replaced and `mean` increased by the new sd.
new_sd_based_on_mean <- function(data_samples, target_model, hyper_parameter) {
  # Deviations must be squared. An unsquared sum cancels around the mean and
  # can be negative, which would make `beta` an invalid gamma rate.
  half_sum_sq <- sum((data_samples - target_model$mean)^2) / 2
  alpha <- hyper_parameter$alpha + length(data_samples) / 2
  beta <- hyper_parameter$beta + half_sum_sq
  if (beta < 1) {
    # Fallback retained for small hyperparameter beta: use a unit sd.
    target_model$sd <- 1
  } else {
    # 1 / rgamma(shape, rate = beta) is how MCMCpack::rinvgamma(1, alpha, beta)
    # draws. The draw is a variance, so take its square root to get an sd.
    variance_draw <- 1 / rgamma(1, shape = alpha, rate = beta)
    target_model$sd <- sqrt(variance_draw)
  }
  # NOTE(review): adding the new sd to the mean biases find_map()'s random
  # walk upward; kept for compatibility, confirm this is intended.
  target_model$mean <- target_model$mean + target_model$sd
  target_model
}
# Stochastic hill-climb for the maximum a posteriori (mean, sd) of a normal
# model.
#
# Each iteration moves the current best mean by a uniform step in
# [0, max_jump] in a random direction, redraws sd with
# new_sd_based_on_mean(), and keeps the candidate when its
# parameter_log_likelyhood() score beats the best seen so far. Candidates
# must have sd > 0 and mean >= 1.
#
# data_samples: numeric vector of observations.
# prior: list(mean, sd); the starting point and the normal prior on the mean.
# hyper_parameters: list(alpha, beta) passed to new_sd_based_on_mean().
# iter: number of proposals to evaluate (0 returns `prior` unchanged).
# max_jump: largest absolute step applied to the mean per proposal.
# Returns the best list(mean, sd) found; prints a line for each improvement.
find_map <- function(data_samples, prior, hyper_parameters, iter = 1000,
                     max_jump = 5) {
  ret_val <- prior
  # -Inf instead of a finite sentinel: with many observations a valid
  # log-posterior can be below -999999 and would otherwise never be accepted.
  best_log <- -Inf
  for (iteration in seq_len(iter)) {
    permute <- ret_val
    # Change mean: pick a direction, then a step size (same RNG order as
    # before: direction first, then the uniform step).
    step_down <- sample(0:1, 1) == 0
    step <- runif(1, 0, max_jump)
    if (step_down) {
      permute$mean <- permute$mean - step
    } else {
      permute$mean <- permute$mean + step
    }
    # Change sd based on new mean.
    permute <- new_sd_based_on_mean(data_samples, permute, hyper_parameters)
    # sd must be strictly positive for a proper normal density; isTRUE()
    # rejects NA/NaN instead of erroring inside if().
    if (!isTRUE(permute$sd > 0 && permute$mean >= 1)) {
      next
    }
    current_log <- parameter_log_likelyhood(data_samples, prior, permute)
    # is.na() covers both NA and NaN; the length check guards against a
    # zero-length score, which would make if() error.
    if (length(current_log) == 1L && !is.na(current_log) &&
        current_log > best_log) {
      best_log <- current_log
      ret_val <- permute
      # Flat paste() with the default " " separator yields exactly the same
      # text as the previous nested paste() calls.
      print(paste(
        "New Attr Found (mean): ", permute$mean,
        " (sd): ", permute$sd,
        " Log: ", current_log
      ))
    }
  }
  ret_val
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment