@ksivaman
Last active August 9, 2019 23:44
Gram matrix calculator for neural style transfer
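import torch


def gram_matrix(tensor):
    # get the batch size, depth (channels), height, and width of the tensor;
    # the batch size is discarded, so this assumes a batch of one image
    _, d, h, w = tensor.size()
    # flatten each channel's feature map into a row: (d, h * w)
    tensor = tensor.view(d, h * w)
    # the Gram matrix holds the dot product between every pair of channels: (d, d)
    gram = torch.mm(tensor, tensor.t())
    return gram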
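
The function above only works for a batch of one image, because `view(d, h * w)` fails when the batch dimension is larger than 1. Below is a minimal sketch of a batched variant, assuming the input has shape `(batch, channels, height, width)`. The name `gram_matrix_batched` and the division by `c * h * w` are illustrative assumptions and are not part of the original gist; style-transfer implementations differ on whether and how they normalize.

```python
import torch


def gram_matrix_batched(features):
    # hypothetical helper (not in the original gist)
    # features: (batch, channels, height, width)
    b, c, h, w = features.size()
    # flatten each channel's feature map: (b, c, h * w)
    flat = features.reshape(b, c, h * w)
    # batched product (b, c, hw) @ (b, hw, c) -> one (c, c) Gram matrix per image
    gram = torch.bmm(flat, flat.transpose(1, 2))
    # assumed normalization so values do not grow with feature-map size
    return gram / (c * h * w)


# usage: two fake "images" with 3 channels and 2x2 feature maps
features = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).reshape(2, 3, 2, 2)
grams = gram_matrix_batched(features)
print(grams.shape)
```

Each Gram matrix is symmetric, since entry `(i, j)` is the dot product of channels `i` and `j`. Dropping the final division gives the same unnormalized values as `gram_matrix` applied to each image separately.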