Toy implementation of "Space-Time Correspondence as a Contrastive Random Walk" (Jabri et al., 2020)
# pseudocode impl
# Algorithm 1: pseudocode in a PyTorch-like style, from
# "Space-Time Correspondence as a Contrastive Random Walk".
#
# for x in loader:  # x: batch with B sequences
#     # Split image into patches
#     # B x C x T x H x W -> B x C x T x N x h x w
#     x = unfold(x, (patch_size, patch_size))
#     x = spatial_jitter(x)
#     # Embed patches (B x C x T x N)
#     v = l2_norm(resnet(x))
#     # Transitions from t to t+1 (B x T-1 x N x N)
#     A = einsum("bcti,bctj->btij",
#                v[:,:,:-1], v[:,:,1:]) / temperature
#     # Transition energies for palindrome graph
#     AA = cat((A, A[:,::-1].transpose(-1,-2)), 1)
#     AA[rand(AA) < drop_rate] = -1e10  # Edge dropout
#     At = eye(P)  # Init. position
#     # Compute walks
#     for t in range(2*T-2):
#         At = bmm(softmax(AA[:,t], dim=-1), At)
#     # Target is the original node
#     loss = At[[range(P)]*B].log()
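
# The implementation below mirrors the pseudocode: frame_to_patches plays the
# role of unfold, MockResnetEmbedding stands in for l2_norm(resnet(.)),
# palindromed_transitions builds the palindrome energies AA, transition_probs
# applies softmax(./temperature), matrix_product composes the walk, and
# cross_entropy_loss scores the return-to-start (diagonal) probabilities.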

import torch
import torch.nn.functional as F
import random
import math

# made-up dimensions
NUM_BATCHES = 100
B = 4           # num. sequences in batch
C = 1           # original channel dim.
T = 50          # num. timesteps
H = 6           # frame height
W = 6           # frame width
h, w = 3, 3     # patch height and width
PATCH_STEP = 1  # step between patches
D = 2           # mock resnet embedding dim
TAU = 1.0       # temperature

# simple random-noise loader (unused by default;
# run_pseudocode_impl uses circularly_polarized_loader below)
mock_loader = [torch.randn(B, C, T, H, W)
               for _ in range(NUM_BATCHES)]

def circularly_polarized_loader():
    """ 'Circularly polarized' sequences: one bright pixel per frame,
    tracing a circle over time with a random per-sequence phase offset. """
    loader = []
    for _ in range(NUM_BATCHES):
        # dim constant background intensity
        x = torch.full((B, C, T, H, W), 0.01)
        for b in range(B):
            offset = 2 * math.pi * random.randint(0, T - 1) / T
            for t in range(T):
                theta = 2 * math.pi * t / T + offset
                sinv = math.sin(theta)
                cosv = math.cos(theta)
                h_ix = min(int((sinv + 1.) * H / 2), H - 1)
                w_ix = min(int((cosv + 1.) * W / 2), W - 1)
                x[b, :, t, h_ix, w_ix] = 1.
        loader.append(x)
    return loader
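
# Illustrative sanity check (not in the original gist): each frame from the
# loader should contain exactly one bright (1.0) pixel on the 0.01 background.
def _check_polarized_loader():
    x = circularly_polarized_loader()[0]
    assert x.shape == (B, C, T, H, W)
    assert ((x == 1.).sum(dim=(-1, -2)) == 1).all()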

def pseudocode_impl(loader):
    phi = MockResnetEmbedding()
    optimizer = torch.optim.SGD(
        phi.parameters(),
        lr=0.1
    )
    max_iters = 10000
    for it in range(max_iters):
        random.shuffle(loader)
        cross_ent_loss = 0.
        for x in loader:
            optimizer.zero_grad()
            # x is a batch of B sequences, shape (B, C, T, H, W):
            # B num. sequences, C num. channels,
            # T num. timesteps, H frame height, W frame width.
            # Split each frame into patches:
            # B x C x T x H x W -> B x C x T x N x h x w
            xp = frame_to_patches(x,
                                  dim_h=3,
                                  dim_w=4,
                                  patch_height=h,
                                  patch_width=w,
                                  patch_step=PATCH_STEP)
            xp = spatial_jitter(xp)
            # Embed patches: B x C x T x N x h x w -> B x D x T x N
            v = phi(xp)
            # Transition energies (all inner products) from t to t+1 (B x T-1 x N x N)
            E = torch.einsum("bcti,bctj->btij",
                             v[:, :, :-1], v[:, :, 1:])
            # Palindrome-graph energies (edge dropout from the pseudocode
            # is omitted in this toy implementation)
            EP = palindromed_transitions(E)
            AA = transition_probs(EP, out_dim=-1)
            MP = matrix_product(AA)
            loss = cross_entropy_loss(MP)
            cross_ent_loss += loss.item()
            # backpropagate
            loss.backward()
            optimizer.step()
        cross_ent_loss /= NUM_BATCHES
        print(f'(i={it}) cross entropy loss: {cross_ent_loss}')

def cross_entropy_loss(MP):
    """ Negative log-probability that each walk returns to its start node
    (the diagonal of MP), summed over nodes and averaged over the batch. """
    nll = torch.mean(torch.einsum('bii->b', -torch.log(MP)), dim=0)
    return nll
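
# Illustrative check (an addition, not from the gist): for a uniform walk
# over N nodes, P(return to start) = 1/N, so the loss should be N * log(N).
def _check_cross_entropy_loss():
    N = 16
    MP = torch.full((B, N, N), 1. / N)
    expected = torch.tensor(N * math.log(N))
    assert torch.isclose(cross_entropy_loss(MP), expected)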

def matrix_product(AA):
    """ Returns the matrix product of the individual transition probs.
    Given AA of shape B x T x N x N, computes the product of all
    matrices along the T axis and returns a tensor of shape B x N x N.
    """
    num_steps = AA.shape[1]
    MP = AA[:, 0, :, :]
    for t in range(1, num_steps):
        MP = torch.matmul(MP, AA[:, t, :, :])
    return MP
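
# Illustrative property check: a product of row-stochastic matrices is itself
# row-stochastic, so each row of MP should still sum to 1 (up to float error).
def _check_matrix_product():
    AA = F.softmax(torch.randn(B, 2 * (T - 1), 16, 16), dim=-1)
    MP = matrix_product(AA)
    assert torch.allclose(MP.sum(dim=-1), torch.ones(B, 16), atol=1e-4)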

def transition_probs(E, out_dim=-1):
    """ Returns normalized transition probs: softmax with temperature
    over out_dim (normalized by construction of the softmax). """
    return F.softmax(E / TAU, dim=out_dim)

def palindromed_transitions(E):
    """
    Transition energies with the time-reversed transition energies appended.
    """
    # reverse the time direction; i -> j becomes j -> i, hence the transpose
    E_reversed = torch.flip(E, dims=[1]).transpose(-1, -2)
    palindromed = torch.cat([E, E_reversed], dim=1)
    return palindromed
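
# Illustrative check: the palindrome graph has 2*(T-1) steps, and its last
# step is the transpose of the first forward step (j -> i back to i -> j).
def _check_palindromed_transitions():
    E = torch.randn(B, T - 1, 16, 16)
    EP = palindromed_transitions(E)
    assert EP.shape == (B, 2 * (T - 1), 16, 16)
    assert torch.equal(EP[:, -1], E[:, 0].transpose(-1, -2))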

def spatial_jitter(xp):
    # TODO: not implemented; identity for now (see the sketch below)
    return xp
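
# A minimal sketch of what spatial_jitter could do. This is an assumption,
# not the paper's exact augmentation: here we just apply a random horizontal
# flip to all patches as a placeholder.
def spatial_jitter_sketch(xp, flip_prob=0.5):
    if random.random() < flip_prob:
        return torch.flip(xp, dims=[-1])  # flip along the patch-width dim
    return xp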

class MockResnetEmbedding(torch.nn.Module):
    """ Mock patch embedding: maps x of shape B x C x T x N x h x w
    to l2-normalized embeddings of shape B x D x T x N,
    where D is the embedding dim.
    """
    # TODO("probably want to make this convolutional")
    def __init__(self):
        super(MockResnetEmbedding, self).__init__()
        mock_embed_matrix = torch.randn(size=(D, C, h * w))
        self.embed_matrix = torch.nn.Parameter(mock_embed_matrix)

    def forward(self, x):
        # flatten each h x w patch to a vector of length h*w
        xf = x.flatten(4, 5)
        # contract over channel and patch dims: (B, C, T, N, h*w) -> (B, D, T, N)
        xe = torch.tensordot(self.embed_matrix, xf, dims=([1, 2], [1, 4])) \
            .permute(1, 0, 2, 3)
        # l2-normalize along the embedding dim (guarding against zero norm)
        nm = torch.linalg.norm(xe, dim=1, keepdim=True)
        xe_normalized = xe.div(nm.clamp_min(1e-12))
        return xe_normalized
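
# Illustrative shape and normalization check for the mock embedding:
# (B, C, T, N, h, w) in, (B, D, T, N) out, unit l2 norm along dim 1.
def _check_mock_embedding():
    phi = MockResnetEmbedding()
    v = phi(torch.randn(B, C, T, 16, h, w))
    assert v.shape == (B, D, T, 16)
    assert torch.allclose(torch.linalg.norm(v, dim=1),
                          torch.ones(B, T, 16), atol=1e-5)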

def frame_to_patches(x, dim_h, dim_w, patch_height, patch_width, patch_step):
    """
    Given a tensor x of shape ... x H x W, with H, W the frame height and
    width at dims dim_h and dim_w, splits each frame into patches of size
    patch_height x patch_width taken every patch_step pixels.
    Returns a tensor of shape ... x N x patch_height x patch_width, where N
    is the number of patches created per frame.
    :param x: input tensor of frames
    :param dim_h: index of the frame-height dim
    :param dim_w: index of the frame-width dim (must be dim_h + 1)
    :param patch_height: patch height h
    :param patch_width: patch width w
    :param patch_step: step between consecutive patches
    :return: tensor of patches
    """
    assert (dim_w == dim_h + 1), "error: expected dim_w to follow dim_h"
    return x.unfold(dim_h, patch_height, patch_step)\
        .unfold(dim_w, patch_width, patch_step)\
        .flatten(dim_h, dim_h + 1)
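
# Usage example: with H = W = 6, h = w = 3, and step 1, each frame yields
# N = (6 - 3 + 1) ** 2 = 16 patches.
def _check_frame_to_patches():
    xp = frame_to_patches(torch.randn(B, C, T, H, W),
                          dim_h=3, dim_w=4,
                          patch_height=h, patch_width=w,
                          patch_step=PATCH_STEP)
    assert xp.shape == (B, C, T, 16, h, w)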

def run_pseudocode_impl():
    loader = circularly_polarized_loader()
    pseudocode_impl(loader)

if __name__ == "__main__":
    run_pseudocode_impl()