Last active: February 2, 2021 10:01
Store activations during the forward path using hooks.
from collections import OrderedDict

import torch.nn as nn


class MLPEncoder(nn.Module):
    def __init__(self):
        super(MLPEncoder, self).__init__()
        # TODO: Fix hard coding
        self.model = nn.Sequential(OrderedDict([
            ('layer1', nn.Linear(784, 400)),
            ('relu1', nn.ReLU()),
            ('layer2', nn.Linear(400, 400)),
            ('relu2', nn.ReLU()),
            ('layer3', nn.Linear(400, 200)),
            ('relu3', nn.ReLU()),
            ('layer4', nn.Linear(200, 200)),
            ('relu4', nn.ReLU())
        ]))
        print(self.model)

        # Register a forward hook on each layer to store its activations
        self.activations = {}

        def store_activations(model, input, output):
            self.activations[model.__name__] = output

        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(store_activations)

    def forward(self, x):
        return self.model(x.view(-1, 784))

    def get_activations(self, name):
        return self.activations[name]
Use hooks to capture the output of each intermediate layer in PyTorch. The stored activations are not detached, on the assumption that they will be used to compute a loss. If you only want to visualize them (and never backpropagate through them), it is better to store
self.activations[model.__name__] = output.detach()
instead, so the computation graph can be freed.
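A minimal, self-contained sketch of how the class is used: the encoder below is condensed to two layers for brevity (the layer sizes 784→400→200 are illustrative), but the hook registration is identical to the class above. Running a forward pass fires the hooks, after which `get_activations` returns each layer's output.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class MLPEncoder(nn.Module):
    def __init__(self):
        super(MLPEncoder, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('layer1', nn.Linear(784, 400)),
            ('relu1', nn.ReLU()),
            ('layer2', nn.Linear(400, 200)),
            ('relu2', nn.ReLU()),
        ]))
        self.activations = {}

        def store_activations(model, input, output):
            # Not detached: gradients can still flow through the stored tensors
            self.activations[model.__name__] = output

        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(store_activations)

    def forward(self, x):
        return self.model(x.view(-1, 784))

    def get_activations(self, name):
        return self.activations[name]


encoder = MLPEncoder()
x = torch.randn(8, 1, 28, 28)          # dummy batch of 8 MNIST-sized images
out = encoder(x)                       # hooks fire during this forward pass
h1 = encoder.get_activations('relu1')
h2 = encoder.get_activations('relu2')
print(h1.shape, h2.shape)              # torch.Size([8, 400]) torch.Size([8, 200])
```

Because the outputs are stored without `detach()`, `h1` and `h2` still carry `requires_grad=True` and can be used directly in an auxiliary loss.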