Natural variational distribution + tests
#!/usr/bin/env python3
import abc

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import CholLazyTensor
from gpytorch.variational._variational_distribution import \
    _VariationalDistribution

__all__ = ['NaturalVariationalDistribution', 'TrilNaturalVariationalDistribution']


class _AbstractNVD(_VariationalDistribution, metaclass=abc.ABCMeta):
    def __init__(self, num_inducing_points, batch_shape=torch.Size([]),
                 mean_init_std=1e-3, use_natgrad=True, **kwargs):
        super().__init__(num_inducing_points=num_inducing_points,
                         batch_shape=batch_shape, mean_init_std=mean_init_std)
        self._use_natgrad = use_natgrad
        m = torch.zeros(num_inducing_points).repeat(*batch_shape, 1)
        self.register_parameter("nat_mean", torch.nn.Parameter(m))
        self._register_nat_covar(num_inducing_points, batch_shape)

    @property
    def use_natgrad(self):
        return self._use_natgrad

    def use_natgrad_(self, use_natgrad=True):
        self._use_natgrad = use_natgrad

    @abc.abstractmethod
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        pass


class NaturalVariationalDistribution(_AbstractNVD):
    """
    A :obj:`~gpytorch.variational._VariationalDistribution` that is defined to
    be a multivariate normal distribution with a full covariance matrix.

    Parameterized in terms of its natural parameters: θ₁ = Σ⁻¹μ and θ₂ = -½ Σ⁻¹.
    """
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        cov = -.5 * torch.eye(num_inducing_points)
        cov = cov.repeat(*batch_shape, 1, 1)
        self.register_parameter("nat_covar", torch.nn.Parameter(cov))

    def forward(self):
        fun = (_NaturalToMuVarSqrt.apply if self.use_natgrad
               else _NaturalToMuVarSqrt._forward)
        mu, L = fun(self.nat_mean, self.nat_covar)
        return MultivariateNormal(mu, CholLazyTensor(L))

    def initialize_variational_distribution(self, prior_dist):
        chol = prior_dist.lazy_covariance_matrix.cholesky().evaluate()
        tril_nat_covar = _triangular_inverse(chol, upper=False)
        nat_covar = tril_nat_covar.transpose(-1, -2) @ tril_nat_covar
        nat_mean = (prior_dist.mean
                    .unsqueeze(-1)
                    .triangular_solve(chol, upper=False, transpose=False).solution
                    .triangular_solve(chol, upper=False, transpose=True).solution
                    .squeeze(-1))
        self.nat_mean.data.copy_(nat_mean)
        # -.5 because nat_covar = -0.5 Σ⁻¹; the extra .5 because we average
        # nat_covar with its transpose to symmetrize it.
        self.nat_covar.data.copy_((nat_covar + nat_covar.transpose(-1, -2)) * (-.5 * .5))

    def reparameterise(self):
        self.initialize_variational_distribution(self.forward())


class TrilNaturalVariationalDistribution(_AbstractNVD):
    """
    A :obj:`~gpytorch.variational._VariationalDistribution` that is defined to
    be a multivariate normal distribution with a full covariance matrix.

    Parameterized in terms of its natural mean and a decomposition of the
    natural covariance, to ensure the latter stays positive definite when
    optimizing with BFGS or SGD.

    The parameters are Σ⁻¹μ and L, where L is a lower-triangular matrix with
    Σ⁻¹ = LᵀL (this L is the ``tril_nat_covar`` parameter; note it is different
    from the Cholesky factorization, which is LLᵀ).

    Claim: any PD matrix Σ can be represented as LᵀL. Proof by construction:
    compute cholesky(Σ⁻¹) = Linv. Then
        Σ = (Σ⁻¹)⁻¹ = (Linv Linvᵀ)⁻¹ = (Linv⁻¹)ᵀ Linv⁻¹.
    Since Linv⁻¹ is also lower triangular, setting L = Linv⁻¹ gives such a
    representation for Σ.
    """
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        cov = torch.eye(num_inducing_points)
        cov = cov.repeat(*batch_shape, 1, 1)
        self.register_parameter("tril_nat_covar", torch.nn.Parameter(cov))

    def forward(self):
        fun = (_TrilNaturalToMuVarSqrt.apply if self.use_natgrad
               else _TrilNaturalToMuVarSqrt._forward)
        mu, L = fun(self.nat_mean, self.tril_nat_covar)
        return MultivariateNormal(mu, CholLazyTensor(L))

    def initialize_variational_distribution(self, prior_dist):
        chol = prior_dist.lazy_covariance_matrix.cholesky().evaluate()
        tril_nat_covar = _triangular_inverse(chol, upper=False)
        # nat_mean = prior_dist.mean
        nat_mean = (prior_dist.mean
                    .unsqueeze(-1)
                    .triangular_solve(chol, upper=False, transpose=False).solution
                    .triangular_solve(chol, upper=False, transpose=True).solution
                    .squeeze(-1))
        self.nat_mean.data.copy_(nat_mean)
        self.tril_nat_covar.data.copy_(tril_nat_covar)


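# Invert a triangular matrix A by solving A X = I with a triangular solve.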
def _triangular_inverse(A, upper=False):
    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    return eye.triangular_solve(A, upper=upper).solution

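# In-place ϕ operator from the Cholesky-differentiation literature (see the
# Murray reference below): keep the lower triangle of A and halve its diagonal.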
def _phi_for_cholesky_(A):
    A.tril_().diagonal(offset=0, dim1=-2, dim2=-1).mul_(0.5)
    return A

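# Reverse-mode derivative of the Cholesky factorization: given dout/dL, L and
# L⁻¹, return the (symmetrized) dout/dΣ, where Σ = LLᵀ.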
def _cholesky_backward(dout_dL, L, L_inverse):
    # c.f. https://github.com/pytorch/pytorch/blob/25ba802ce4cbdeaebcad4a03cec8502f0de9b7b3/tools/autograd/templates/Functions.cpp
    A = L.transpose(-1, -2) @ dout_dL
    phi = _phi_for_cholesky_(A)
    grad_input = (L_inverse.transpose(-1, -2) @ phi) @ L_inverse
    # Symmetrize gradient
    return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5)


class _NaturalToMuVarSqrt(torch.autograd.Function):
    @staticmethod
    def _forward(nat_mean, nat_covar):
        try:
            L_inv = torch.cholesky(-2.0 * nat_covar, upper=False)
        except RuntimeError as e:
            if str(e).startswith("cholesky"):
                raise RuntimeError(
                    "The natural covariance is not negative definite (i.e. "
                    "-2 * nat_covar is not positive definite). You probably "
                    "updated it using an optimizer other than SGD (such as "
                    "Adam), which is not supported.") from e
            else:
                raise
        L = _triangular_inverse(L_inv, upper=False)
        S = L.transpose(-1, -2) @ L
        mu = (S @ nat_mean.unsqueeze(-1)).squeeze(-1)
        # Two choleskys are annoying, but we don't have good support for a
        # LazyTensor of the form L.T @ L
        return mu, torch.cholesky(S, upper=False)

    @staticmethod
    def forward(ctx, nat_mean, nat_covar):
        mu, L = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
        ctx.save_for_backward(mu, L)
        return mu, L

    @staticmethod
    def _backward(dout_dmu, dout_dL, mu, L, C):
        """Calculate dout/d(η1, η2), the gradients with respect to the
        expectation parameters:
            η1 = μ
            η2 = μμᵀ + LLᵀ = μμᵀ + Σ
        Since Σ = η2 - η1η1ᵀ and L = chol(Σ):
            dout/dΣ  = _cholesky_backward(dout/dL, L)
            dout/dη2 = dout/dΣ
            dout/dη1 = dout/dμ - 2 (dout/dΣ) η1
        """
        dout_dSigma = _cholesky_backward(dout_dL, L, C)
        dout_deta1 = dout_dmu - 2*(dout_dSigma @ mu.unsqueeze(-1)).squeeze(-1)
        return dout_deta1, dout_dSigma

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        "Calculate the natural gradient with respect to nat_mean and nat_covar."
        mu, L = ctx.saved_tensors
        C = _triangular_inverse(L, upper=False)
        return _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)


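# Why backward() above returns gradients with respect to the expectation
# parameters: for an exponential family, the gradient with respect to the
# expectation parameters (η1, η2) equals the natural gradient with respect to
# the natural parameters (θ1, θ2),
#   F(θ)⁻¹ ∇_θ out = ∇_η out,
# where F(θ) is the Fisher information matrix. A plain SGD step on
# (nat_mean, nat_covar) using these gradients is therefore a natural-gradient step.
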
class _TrilNaturalToMuVarSqrt(torch.autograd.Function):
    @staticmethod
    def _forward(nat_mean, tril_nat_covar):
        L = _triangular_inverse(tril_nat_covar, upper=False)
        mu = L @ (L.transpose(-1, -2) @ nat_mean.unsqueeze(-1))
        return mu.squeeze(-1), L
        # return nat_mean, L

    @staticmethod
    def forward(ctx, nat_mean, tril_nat_covar):
        mu, L = _TrilNaturalToMuVarSqrt._forward(nat_mean, tril_nat_covar)
        ctx.save_for_backward(mu, L, tril_nat_covar)
        return mu, L

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        mu, L, C = ctx.saved_tensors
        dout_dnat1, dout_dnat2 = _NaturalToMuVarSqrt._backward(
            dout_dmu, dout_dL, mu, L, C)
        """
        Now we need the Jacobian-vector product for the transformation:
            C = inv(chol(inv(-2 θ_cov))),    i.e.    Cᵀ C = -2 θ_cov,
        so we do forward differentiation, starting with the sensitivity
            θ̇_cov = dout_dnat2
        and ending with the sensitivity Ċ.
        If B = inv(-2 θ_cov), then:
            Ḃ = d inv(-2 θ_cov)/dθ_cov ⋅ θ̇_cov = -B (-2 θ̇_cov) B
        If L = chol(B), B = LLᵀ, then (https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf):
            L̇ = L ϕ(L⁻¹ Ḃ L⁻ᵀ) = L ϕ(2 Lᵀ θ̇_cov L)
        Then C = inv(L), so
            Ċ = -C L̇ C = ϕ(-2 Lᵀ θ̇_cov L) C
        """
        A = L.transpose(-2, -1) @ dout_dnat2 @ L
        phi = _phi_for_cholesky_(-2*A)
        dout_dtril = phi @ C
        return dout_dnat1, dout_dtril
        # The code below is unreachable, apparently left over from development;
        # it propagates gradients for the variant where `_forward` returns
        # `nat_mean` directly (see the commented-out `return nat_mean, L` above).
        # dL = -L @ phi
        # # Sigma = L @ L.transpose(-1, -2)
        # # dSigma = dL @ L.transpose(-1, -2) + L @ dL.transpose(-1, -2)
        # # nat_mean = C.transpose(-1, -2) @ C @ mu
        # C_mu = C @ mu.unsqueeze(-1)
        # dout_dmu = (L @ (L.transpose(-1, -2) @ dout_dnat1.unsqueeze(-1))
        #             + dL @ C_mu
        #             + L @ (dL.transpose(-1, -2) @ (C.transpose(-1, -2) @ C_mu)))
        # return dout_dmu.squeeze(-1), dout_dtril

import unittest

import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import CholLazyTensor

from natural_variational_distribution import (
    NaturalVariationalDistribution, TrilNaturalVariationalDistribution)

torch.set_default_dtype(torch.float64)


class TestNatVariational(unittest.TestCase):
    def test_invertible_init(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        out_dist = v_dist()

        assert torch.allclose(out_dist.mean, dist.mean)
        assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)

    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))
        sample = dist.sample()

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()

        eta1 = mu.clone().requires_grad_(True)
        eta2 = (mu[:, None]*mu + cov@cov.t()).requires_grad_(True)
        L = torch.cholesky(eta2 - eta1[:, None]*eta1)
        dist2 = MultivariateNormal(eta1, CholLazyTensor(L))
        dist2.log_prob(sample).squeeze().backward()

        assert torch.allclose(v_dist.nat_mean.grad, eta1.grad)
        assert torch.allclose(v_dist.nat_covar.grad, eta2.grad)

    def test_optimization_zero_error(self, num_inducing=16, num_data=32, D=2):
        inducing_points = torch.randn(num_inducing, D)

        class SVGP(gpytorch.models.ApproximateGP):
            def __init__(self):
                v_dist = NaturalVariationalDistribution(num_inducing)
                v_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
                    self, inducing_points, v_dist)
                super().__init__(v_strat)
                self.mean_module = gpytorch.means.ZeroMean()
                self.covar_module = gpytorch.kernels.RBFKernel()

            def forward(self, x):
                return MultivariateNormal(self.mean_module(x), self.covar_module(x))

        model = SVGP().train()
        likelihood = gpytorch.likelihoods.GaussianLikelihood().train()
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data)

        X = torch.randn((num_data, D))
        y = torch.randn(num_data)

        def loss():
            return -mll(model(X), y)

        optimizer = torch.optim.SGD(
            model.variational_strategy._variational_distribution.parameters(),
            lr=float(num_data))
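        # With the conjugate Gaussian likelihood, a single unit-length
        # natural-gradient step should land exactly at the ELBO optimum;
        # lr=num_data presumably compensates for VariationalELBO scaling the
        # objective by 1/num_data.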
        model.variational_strategy._variational_distribution.use_natgrad_(False)
        optimizer.zero_grad()
        loss().backward()
        grad_nat_mean, grad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())

        model.variational_strategy._variational_distribution.use_natgrad_(True)
        optimizer.zero_grad()
        loss().backward()
        natgrad_nat_mean, natgrad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())

        assert not torch.allclose(grad_nat_mean, natgrad_nat_mean)
        assert not torch.allclose(grad_nat_covar, natgrad_nat_covar)

        optimizer.step()  # Now we should be at the optimum

        model.variational_strategy._variational_distribution.use_natgrad_(True)
        optimizer.zero_grad()
        loss().backward()
        natgrad_nat_mean2, natgrad_nat_covar2 = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())
        assert torch.allclose(natgrad_nat_mean2, torch.zeros(()))
        assert torch.allclose(natgrad_nat_covar2, torch.zeros(()))

        model.variational_strategy._variational_distribution.use_natgrad_(False)
        optimizer.zero_grad()
        loss().backward()
        grad_nat_mean, grad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())
        assert torch.allclose(grad_nat_mean, torch.zeros(()))
        assert torch.allclose(grad_nat_covar, torch.zeros(()))


class TestTrilNatVariational(unittest.TestCase):
    def test_invertible_init(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))

        v_dist = TrilNaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        out_dist = v_dist()

        assert torch.allclose(out_dist.mean, dist.mean)
        assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)

    def test_nat_jvp(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D)
        cov = cov @ cov.t()
        dist = MultivariateNormal(mu, CholLazyTensor(cov.cholesky()))
        sample = dist.sample()

        v_dist = TrilNaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()
        dout_dnat1 = v_dist.nat_mean.grad
        dout_dnat2 = v_dist.tril_nat_covar.grad

        v_dist_ref = NaturalVariationalDistribution(D)
        v_dist_ref.initialize_variational_distribution(dist)
        v_dist_ref().log_prob(sample).squeeze().backward()
        dout_dnat1_noforward_ref = v_dist_ref.nat_mean.grad
        dout_dnat2_noforward_ref = v_dist_ref.nat_covar.grad
        # Use JAX for forward-mode AD (JVPs).
        import os
        import jax.numpy as np
        from jax import jvp
        assert os.environ.get('JAX_ENABLE_X64', "") != "", \
            "JAX must run in float64; set the JAX_ENABLE_X64 environment variable."

        def f(nat_mean, nat_covar):
            "Transform (nat_mean, nat_covar) to (mu, tril_nat_covar)"
            Sigma = np.linalg.inv(-2*nat_covar)
            mu = nat_mean
            return mu, np.tril(np.linalg.inv(np.linalg.cholesky(Sigma)))

        (np_mu, np_tril_nat_covar), (np_dout_dmu_ref, np_dout_dnat2_ref) = jvp(
            f,
            (np.asarray(v_dist_ref.nat_mean.detach()), np.asarray(v_dist_ref.nat_covar.detach())),
            (np.asarray(dout_dnat1_noforward_ref), np.asarray(dout_dnat2_noforward_ref)))

        assert np.allclose(
            np_tril_nat_covar, v_dist.tril_nat_covar.detach().numpy()), "Sigma transformation"
        assert np.allclose(np_dout_dnat2_ref, dout_dnat2.numpy()), "Sigma gradient"
        assert np.allclose(np_mu, v_dist.nat_mean.detach().numpy()), "mu transformation"
        assert np.allclose(np_dout_dmu_ref, dout_dnat1.numpy()), "mu gradient"
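
For reference, a minimal usage sketch (not part of the gist's two files above). It mirrors the SVGP model built in test_optimization_zero_error and assumes the same gpytorch 1.x API; the model class, training data, and learning rates below are illustrative placeholders. Plain SGD on the variational parameters becomes natural-gradient descent thanks to the custom backward pass, while kernel and likelihood hyperparameters, whose gradients do not pass through that backward, can be trained with an ordinary optimizer such as Adam.

import torch
import gpytorch
from natural_variational_distribution import NaturalVariationalDistribution


class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        v_dist = NaturalVariationalDistribution(inducing_points.size(0))
        v_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, inducing_points, v_dist)
        super().__init__(v_strat)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))


train_x, train_y = torch.randn(100, 2), torch.randn(100)  # placeholder data
model = SVGPModel(inducing_points=train_x[:16].clone()).train()
likelihood = gpytorch.likelihoods.GaussianLikelihood().train()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

# Natural-gradient steps (plain SGD) for the variational parameters,
# an ordinary Adam optimizer for everything else.
variational_params = list(
    model.variational_strategy._variational_distribution.parameters())
natgrad_optimizer = torch.optim.SGD(variational_params, lr=0.1)
other_params = [p for p in list(model.parameters()) + list(likelihood.parameters())
                if not any(p is q for q in variational_params)]
hyper_optimizer = torch.optim.Adam(other_params, lr=0.01)

for _ in range(100):
    natgrad_optimizer.zero_grad()
    hyper_optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    natgrad_optimizer.step()
    hyper_optimizer.step()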