Last active
June 7, 2021 22:22
-
-
Save cscherrer/ad80243833fa0a30b448ac58b77b6e5f to your computer and use it in GitHub Desktop.
OnlineStat for log-weighted Gaussian
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Cholesky factorization that tolerates indefinite/near-singular matrices.
using PositiveFactorizations
using LinearAlgebra
using StatsFuns          # provides logistic and logaddexp used below
using Random
using OnlineStatsBase
using Statistics

# Short alias for the online-stats framework.
const OSB = OnlineStatsBase
# NOTE(review): `bessel` is aliased but never used in this file — presumably
# Bessel's bias correction from OnlineStatsBase; confirm before removing.
const bessel = OSB.bessel
# Online estimator of a Gaussian fitted to log-weighted observations.
# Each fitted value is a pair (x, ℓ): the sample `x` and its log-weight `ℓ`.
#
# FIX: the original declared `... <: OnlineStat{...}, } where T<:Number`,
# which parses the `where` clause as part of the supertype expression instead
# of bounding the struct's own type parameter (and left a stray trailing
# comma in the supertype's parameter list). The bound belongs on the
# parameter: `Gaussian{T<:Number}`.
mutable struct Gaussian{T<:Number} <: OnlineStat{Union{Tuple, NamedTuple, AbstractVector}}
    value::Matrix{T}   # cached covariance estimate (see OSB.value)
    A::Matrix{T}       # weighted second moment, x'x/n
    b::Vector{T}       # weighted mean, 1'x/n
    log∑w::T           # log of the running sum of weights
    log∑w²::T          # log of the running sum of squared weights
    L::Matrix{T}       # lower Cholesky factor of the covariance
    Lsync::Bool        # true iff L is up to date with `value`
    weight             # OnlineStatsBase weight callable, e.g. EqualWeight()
    n::Int             # number of observations seen
end
# Kish's Effective Sample Size, computed in log space:
# (∑w)² / ∑w² = exp(2·log∑w − log∑w²).
# See https://en.wikipedia.org/wiki/Effective_sample_size
function n_eff(o::Gaussian)
    return exp(2 * o.log∑w - o.log∑w²)
end
# Construct an empty `Gaussian{T}` estimator for `p`-dimensional data.
# With the default `p = 0`, buffers are allocated lazily on the first fit.
function Gaussian(::Type{T}, p::Int=0; weight = EqualWeight()) where T<:Number
    # Start the log-sums at the smallest representable value, i.e. log(≈0).
    ℓ0 = nextfloat(typemin(T))
    zmat() = zeros(T, p, p)
    return Gaussian(zmat(), zmat(), zeros(T, p), ℓ0, ℓ0, zmat(), false, weight, 0)
end
# Convenience constructor: default the element type to Float64.
function Gaussian(p::Int=0; weight = EqualWeight())
    return Gaussian(Float64, p; weight=weight)
end
# The computation for γ, | |
# | |
# γ = min(o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w), 1.0) | |
# | |
# could use some explanation. Let's first look at the first argument inside the
# `min`, | |
# | |
# o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w) | |
# | |
# The first factor, | |
# | |
# o.weight(o.n += 1) | |
# | |
# is to allow the use of built-in weights from OnlineStats.jl. The last factor, | |
# | |
# logistic(ℓ - o.log∑w) | |
# | |
# takes into account that each observation has its own log-weight. But now
# we've counted the weights twice, so there's a question of how to correct for | |
# this. We could take a square root to get the geometric mean of the two | |
# effects, but this would leave neither of them working as expected. | |
# | |
# With OnlineStats.EqualWeight, the weight for the nth observation would be 1/n. | |
# So we multiply by n (the second factor). | |
# | |
# The `logistic` factor is typically O(1/n), but there are no guarantees about | |
# this. So the `min` is to account for the (hopefully rare) case where
# | |
# o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w) > 1 | |
# | |
# Finally, allowing γ=1 causes problems, because this makes the distribution | |
# collapse to a point. So as a heuristic, we set this to 0.99 to prevent this | |
# problem. | |
# Incorporate one observation `xℓ = (x, ℓ)`, where `x` is the sample and
# `ℓ` its log-weight. See the long comment above for the derivation of γ.
function OSB._fit!(o::Gaussian{T}, xℓ) where {T}
    o.Lsync = false           # cached Cholesky factor is now stale
    (x, ℓ) = xℓ
    # Step size: built-in weight × n × this observation's share of the total
    # weight, capped at 0.99 so the distribution cannot collapse to a point.
    # γ is deliberately computed *before* log∑w is updated.
    γ = min(o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w), 0.99)
    o.log∑w = logaddexp(o.log∑w, ℓ)
    # BUG FIX: the running log-sum of squared weights must accumulate into
    # log∑w²; the original read `logaddexp(o.log∑w, 2ℓ)`, silently discarding
    # the previous value and corrupting n_eff.
    o.log∑w² = logaddexp(o.log∑w², 2ℓ)
    # Lazily size the buffers on the first observation (p was unknown until now).
    if isempty(o.A)
        p = length(x)
        o.b = zeros(T, p)
        o.A = zeros(T, p, p)
        o.L = zeros(T, p, p)
        o.value = zeros(T, p, p)
    end
    OSB.smooth!(o.b, x, γ)       # update weighted mean
    OSB.smooth_syr!(o.A, x, γ)   # rank-1 update of weighted second moment
end
# Dimensionality of the observation space (side length of the moment matrix).
function OSB.nvars(o::Gaussian)
    return size(o.A, 1)
end
# Covariance estimate: the symmetrized E[xx'] − E[x]E[x]', written into the
# cached `o.value` buffer (which is also returned).
function OSB.value(o::Gaussian)
    Σ = Symmetric(o.A - o.b * o.b')
    o.value[:] = Matrix(Σ)
    return o.value
end
# Merge a second estimator `o2` into `o`, weighting each side by its total
# (exponentiated) weight.
function OSB._merge!(o::Gaussian, o2::Gaussian)
    o.Lsync = false
    o.n += o2.n
    # smooth!(a, b, γ) moves `a` a fraction γ toward `b`, so γ must be o2's
    # share of the combined weight: exp(log∑w₂)/(exp(log∑w₁)+exp(log∑w₂)).
    # BUG FIX: the original used logistic(o.log∑w - o2.log∑w), which is o's
    # own share — the two operands were swapped, over-weighting the smaller
    # side.
    γ = logistic(o2.log∑w - o.log∑w)
    o.log∑w = logaddexp(o.log∑w, o2.log∑w)
    o.log∑w² = logaddexp(o.log∑w², o2.log∑w²)
    OSB.smooth!(o.A, o2.A, γ)
    OSB.smooth!(o.b, o2.b, γ)
end
# Covariance matrix of the fitted Gaussian (delegates to `value`).
function Statistics.cov(o::Gaussian)
    return value(o)
end

# Weighted mean vector.
function Statistics.mean(o::Gaussian)
    return o.b
end

# Per-coordinate variances: the diagonal of the covariance estimate.
function Statistics.var(o::Gaussian; kw...)
    return diag(value(o; kw...))
end
# Correlation matrix: rescale the covariance by inverse standard deviations,
# D⁻¹ Σ D⁻¹, performed in place on the cached `o.value` buffer.
function Statistics.cor(o::Gaussian; kw...)
    value(o; kw...)                      # refresh o.value with the covariance
    scale = 1.0 ./ sqrt.(diag(o.value))  # 1/σᵢ for each coordinate
    D = Diagonal(scale)
    rmul!(o.value, D)
    lmul!(D, o.value)
    return o.value
end
# Lower Cholesky factor of the covariance, cached in `o.L`. Uses
# PositiveFactorizations' `Positive` so a nearly-singular covariance still
# factors. Recomputed only when the cache is stale.
function LinearAlgebra.cholesky(o::Gaussian)
    if !o.Lsync
        copyto!(o.L, value(o))
        C = cholesky!(Positive, o.L)
        o.Lsync = true
        return C
    end
    return Cholesky(LowerTriangular(o.L), :L, 0)
end
# Draw one sample in place: x ← μ + L·z with z ~ N(0, I).
function Random.rand!(rng::AbstractRNG, x::AbstractArray, o::Gaussian{T}) where {T}
    randn!(rng, x)             # z ~ N(0, I)
    lmul!(cholesky(o).L, x)    # scale by the Cholesky factor
    x .+= mean(o)              # shift by the mean
    return x
end
# Allocating variant: draw a fresh sample vector from the fitted Gaussian.
function Base.rand(rng::AbstractRNG, o::Gaussian{T}) where {T}
    buf = Vector{T}(undef, OSB.nvars(o))
    return rand!(rng, buf, o)
end
# Some checks | |
# | |
# o = Gaussian() | |
# for j in 1:1000 | |
# fit!(o, (randn(3), randn())) | |
# end | |
# value(o) | |
# mean(o) | |
# cov(o) | |
# cholesky(o) | |
# rand!(zeros(3), o) | |
# rand(o) | |
using MeasureTheory
using TransformVariables

# Demo: fit a Normal approximation to Beta(4,2), working on the real line
# via the (0,1) transform. NOTE(review): the MeasureTheory/TransformVariables
# APIs used here (Pullback, logdensity with two measures) are not visible in
# this file — confirm against the installed versions.
t = as𝕀                        # presumably the ℝ → (0,1) bijection
p = Pullback(t, Beta(4,2))      # target measure pulled back to ℝ
q0 = Normal(0,10)               # broad initial proposal
logdensity(p, q0, randn())      # one-off smoke test of the log-density call
o = Gaussian(weight=EqualWeight())
# Warm-up: sample from the fixed proposal until we have at least 10
# observations AND an effective sample size of at least 10.
while min(nobs(o), n_eff(o)) < 10
    x = rand(q0)
    ℓ = logdensity(p, q0, x)
    fit!(o, (x, ℓ))
end
# Adaptive refinement: repeatedly refit the proposal from the current
# estimate, then sample until the effective sample size grows by ~100,
# stopping once n_eff reaches 1000.
while n_eff(o) < 1000
    n = n_eff(o)
    μ = mean(o)[1]     # current 1-d mean estimate
    σ = std(o)[1]      # current 1-d standard deviation estimate
    q = Normal(μ,σ)    # refreshed proposal
    # Train in minibatches
    while n_eff(o) < n + 100
        x = rand(q)
        ℓ = logdensity(p, q, x)
        fit!(o, (x, ℓ))
    end
end
using UnicodePlots

# Compare the target and fitted densities on (0,1). The fitted curve is
# rescaled by the ratio of the sums so the two curves sit on a comparable
# scale (the densities here are only proportional, not normalized).
xx = 0.01:0.01:0.99
q = Normal(mean(o)[1], std(o)[1])
μμ = [density(Pushforward(t,p), x) for x in xx];    # target, pushed to (0,1)
νν = [density(Pushforward(t, q), x) for x in xx];   # fitted approximation
factor = sum(μμ) / sum(νν)
plt = lineplot(xx, νν*factor);
lineplot!(plt, xx, μμ)
# julia> lineplot!(plt, xx, μμ) | |
# ┌────────────────────────────────────────┐ | |
# 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠀⠀⠀⠀⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠛⠥⡉⢢⡀⠀⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠤⡟⠉⠀⠀⠀⠀⠈⢢⡑⡄⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡗⠊⠀⠀⠀⠀⠀⠀⠀⠀⠑⣵⠀⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡮⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⡇⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢹⠀⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣔⠕⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⣇⠀⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢻⡄⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⣧⠀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣤⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢿⡀│ | |
# │⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡶⠛⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⢇│ | |
# 0 │⣀⣀⣀⣀⣀⣤⠴⠞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ | |
# └────────────────────────────────────────┘ | |
# 0 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment