@ksivaman
Created August 9, 2019 23:42
Gram matrix calculator for neural style transfer
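For reference, here is the quantity that `gram_matrix` computes. Write F for the feature map after reshaping, so that each of the d channels becomes a row of length hw:

$$
F \in \mathbb{R}^{d \times hw}, \qquad G = F F^{\top} \in \mathbb{R}^{d \times d}, \qquad G_{ij} = \sum_{k=1}^{hw} F_{ik} F_{jk}
$$

Entry G_ij is the inner product of channel i's activations with channel j's across all spatial positions, so G is symmetric. As written, the function applies no normalization to this sum.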
```python
import torch

def gram_matrix(tensor):
    # get the batch_size, depth, height, and width of the Tensor
    _, d, h, w = tensor.size()
    # reshape so each row holds one channel's activations over all h * w positions
    # (flattening to (d, h * w) assumes a batch size of 1)
    tensor = tensor.view(d, h * w)
    # calculate the gram matrix: (d, h*w) @ (h*w, d) -> (d, d)
    gram = torch.mm(tensor, tensor.t())
    return gram
```
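The function above flattens to `(d, h * w)`, so it only works for a batch of one image. Below is a minimal sketch of a batched variant, assuming the input has shape `(b, d, h, w)`. The name `batched_gram_matrix` is hypothetical and does not come from the original gist. It uses `torch.bmm` to compute F Fᵀ for every image in the batch at once.

```python
import torch


def batched_gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Gram matrix per image in a (b, d, h, w) batch, returns (b, d, d)."""
    b, d, h, w = tensor.size()
    # flatten the spatial dimensions so each channel becomes a row vector
    features = tensor.reshape(b, d, h * w)
    # batched F @ F^T: (b, d, h*w) x (b, h*w, d) -> (b, d, d)
    return torch.bmm(features, features.transpose(1, 2))


# usage: a random batch of 2 feature maps with 3 channels on a 4x5 grid
x = torch.randn(2, 3, 4, 5)
g = batched_gram_matrix(x)
print(g.shape)  # torch.Size([2, 3, 3])
```

This version uses `reshape` instead of `view`, so it also accepts non-contiguous inputs. For a batch of one, `g[0]` should match the output of `gram_matrix` up to floating-point error.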