# Created August 8, 2024 17:11
""" Computing zeroth matrix powers via Lakic 1998.
paper: "On the Computation of the Matrix k-th Root"
Suppose we have a matrix G = USV^T and we want to compute
G^0, defined via G^0 = UV^T. We might want to do this to run
"stochastic spectral descent" of Carlson et al. (2015). The
naive way to do this is via the SVD. But we can also compute
(GG^T)^(-1/2) G or, alternatively, G (G^TG)^(-1/2), and get the
inverse square root from the iterative method of Lakic 1998.
This assumes G has full rank, so that the smaller of GG^T and
G^TG is invertible.
In particular, we implement the first special case of Alg 1
in that paper.
"""
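import torch


def zeroth_power_via_newton(G, steps=20):
    """Approximate G^0 = U V^T using the Lakic (1998) inverse-square-root iteration."""
    d1, d2 = G.shape
    d = min(d1, d2)
    # store the smaller of the two Gram matrices as S
    S = G @ G.t() if d1 < d2 else G.t() @ G
    S_norm = torch.linalg.matrix_norm(S, ord='fro')  # there is freedom here. See Lakic (1998) Thm 2.3
    S /= S_norm  # the Frobenius norm bounds the spectral norm, so eigenvalues of S now lie in (0, 1]
    # Now let's set up the state for the Lakic (1998) method.
    # Match G's device and dtype so the matmuls below work for non-float32 inputs too.
    I = torch.eye(d, device=G.device, dtype=G.dtype)
    N = S
    X = I.clone()
    # Now let's run the iteration
    for _ in range(steps):
        U = (3 * I - N) / 2
        X = X @ U
        N = N @ U @ U
    # X now approximates S^(-1/2) for the normalised S, i.e. sqrt(S_norm) times
    # (G G^T)^(-1/2) or (G^T G)^(-1/2); dividing by sqrt(S_norm) undoes the scaling
    return X @ G / S_norm.sqrt() if d1 < d2 else G @ X / S_norm.sqrt()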
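
# Scalar sketch of the same recurrence (an illustration, not part of the original
# gist). X and N are polynomials in S, so they share its eigenvectors and each
# normalised eigenvalue s of S evolves independently:
#     u = (3 - n) / 2,   x <- x * u,   n <- n * u**2,   starting from x = 1, n = s.
# The ratio n / x**2 stays equal to s at every step, so n -> 1 forces x -> s**(-1/2).
# Near n = 1 the error shrinks quadratically, but for tiny s each step only
# multiplies n by roughly 9/4, which suggests why small singular values of G need
# many more steps. The helper name `lakic_scalar` is hypothetical.

def lakic_scalar(s, steps=20):
    """Run the scalar version of the loop in zeroth_power_via_newton; returns (x, n)."""
    x, n = 1.0, s
    for _ in range(steps):
        u = (3.0 - n) / 2.0
        x *= u
        n *= u * u
    return x, n


# x * sqrt(s) should approach 1 once the iteration has converged for this s
for s in (0.5, 1e-2, 1e-6):
    for n_steps in (5, 10, 20):
        x, n = lakic_scalar(s, n_steps)
        print(f"s={s:g} steps={n_steps:2d}  x*sqrt(s)={x * s ** 0.5:.6f}")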
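

def zeroth_power_via_svd(G):
    # torch.linalg.svd returns V^T (Vh) directly; full_matrices=False gives the
    # reduced SVD, so U @ Vh has the same shape as G for rectangular inputs too
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh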
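
# Sanity check of the identity used in the docstring (an illustrative sketch, not
# from the original gist): for a full-rank wide G = U S V^T,
#     (G G^T)^(-1/2) G = U S^(-1) U^T U S V^T = U V^T.
# Here the inverse square root is formed exactly from an eigendecomposition
# (torch.linalg.eigh) instead of by iteration. `inv_sqrt_via_eigh` is a
# hypothetical helper name; float64 keeps round-off negligible, and a private
# generator avoids changing the global random seed.
import torch


def inv_sqrt_via_eigh(S):
    """Inverse square root of a symmetric positive-definite matrix."""
    evals, evecs = torch.linalg.eigh(S)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


gen = torch.Generator().manual_seed(0)
G_wide = torch.randn(30, 50, generator=gen, dtype=torch.float64)  # G G^T is 30x30, invertible
lhs = inv_sqrt_via_eigh(G_wide @ G_wide.T) @ G_wide
U_w, _, Vh_w = torch.linalg.svd(G_wide, full_matrices=False)
rhs = U_w @ Vh_w
print("identity check, relative error:",
      (torch.linalg.matrix_norm(lhs - rhs) / torch.linalg.matrix_norm(rhs)).item())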


# Let's test it on a random Gaussian matrix
G = torch.randn(100, 100)
G_zero_newton = zeroth_power_via_newton(G)
G_zero_svd = zeroth_power_via_svd(G)
# Check the singular values are all one
sv = torch.linalg.svdvals(G_zero_newton)
print(f"singular values of G_zero_newton lie in [{sv.min().item():.4f}, {sv.max().item():.4f}]")
# Print the error
# Seems like relative Frobenius error is sensible here
error = torch.linalg.matrix_norm(G_zero_newton - G_zero_svd, ord='fro')
error /= torch.linalg.matrix_norm(G_zero_svd, ord='fro')
print(f"relative Frobenius error: {error.item():.2e}")
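
# A variant with a stopping rule (a sketch under assumptions, not from the original
# gist or from Lakic 1998). With S_hat = S / ||S||_F, the iterates satisfy
# N_k = S_hat X_k^2, so N_k -> I exactly when X_k -> S_hat^(-1/2); ||N_k - I||_F is
# then a residual for deciding when to stop instead of running a fixed number of
# steps, at the cost of one extra norm per step. It also exercises both branches
# (wide and tall G), whereas the test above uses a square matrix only. The names
# `zeroth_power_via_newton_tol`, `tol` and `max_steps` are hypothetical.
import torch


def zeroth_power_via_newton_tol(G, tol=1e-6, max_steps=100):
    """Approximate U V^T for a full-rank G; returns (result, steps_used)."""
    d1, d2 = G.shape
    wide = d1 < d2
    S = G @ G.T if wide else G.T @ G
    S_norm = torch.linalg.matrix_norm(S, ord='fro')
    N = S / S_norm
    I = torch.eye(N.shape[0], device=G.device, dtype=G.dtype)
    X = I.clone()
    steps = 0
    # stop once N is within tol of the identity (Frobenius norm)
    while steps < max_steps and torch.linalg.matrix_norm(N - I) > tol:
        U = (3 * I - N) / 2
        X = X @ U
        N = N @ U @ U
        steps += 1
    X = X / S_norm.sqrt()  # undo the normalisation of S
    return (X @ G if wide else G @ X), steps


gen = torch.Generator().manual_seed(0)
for shape in [(40, 60), (60, 40)]:
    A = torch.randn(*shape, generator=gen, dtype=torch.float64)
    O, used = zeroth_power_via_newton_tol(A)
    U_a, _, Vh_a = torch.linalg.svd(A, full_matrices=False)
    rel_err = torch.linalg.matrix_norm(O - U_a @ Vh_a) / torch.linalg.matrix_norm(U_a @ Vh_a)
    print(shape, "steps used:", used, "relative error:", rel_err.item())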