Last active
August 9, 2019 23:44
Gram matrix calculator for neural style transfer
import torch


def gram_matrix(tensor):
    # get the batch_size, depth, height, and width of the tensor
    # (the reshape below only works when batch_size == 1)
    _, d, h, w = tensor.size()
    # flatten each channel's feature map into one row: shape (d, h * w)
    tensor = tensor.view(d, h * w)
    # gram matrix: dot products between every pair of channels, shape (d, d)
    gram = torch.mm(tensor, tensor.t())
    return gram
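
The function above reshapes to (d, h * w), which only works for a batch of one image. As a minimal sketch of how the same calculation might be extended to a whole batch, the variant below keeps the batch dimension and uses torch.bmm. The name gram_matrix_batched and the input shape are assumptions made for illustration; they are not part of the original gist.

```python
import torch


def gram_matrix_batched(tensor):
    # hypothetical batched variant, not part of the original gist
    b, d, h, w = tensor.size()
    # flatten each channel's feature map, keeping the batch dimension: (b, d, h * w)
    features = tensor.reshape(b, d, h * w)
    # batched matrix product of the features with their transpose: (b, d, d)
    return torch.bmm(features, features.transpose(1, 2))


# example input: 2 feature maps with 3 channels of size 4 x 5 (arbitrary shape)
x = torch.randn(2, 3, 4, 5)
grams = gram_matrix_batched(x)
print(grams.shape)
```

Each slice grams[i] is the Gram matrix of sample i, so for a batch of one it should match the output of the single-image calculation.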