@madebyollin
Created July 30, 2023 17:47
Helper for logging output activation-map statistics for a PyTorch module, using forward hooks
def summarize_tensor(x):
    # one-line, ANSI-colorized summary of a tensor's shape / min / mean / max
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"


class ModelActivationPrinter:
    """Context manager that registers forward hooks on the given submodules
    and prints a summary of each submodule's output activations."""

    def __init__(self, module, submodules_to_log):
        # map module ids to their names within the parent module, for labeling
        self.id_to_name = {
            id(module): str(name) for name, module in module.named_modules()
        }
        self.submodules = submodules_to_log
        self.hooks = []

    def __enter__(self, *args, **kwargs):
        def log_activations(m, m_in, m_out):
            label = self.id_to_name.get(id(m), "(unnamed)") + " output"
            if isinstance(m_out, (tuple, list)):
                # modules that return multiple outputs: summarize just the first
                m_out = m_out[0]
                label += "[0]"
            print(label.ljust(48) + summarize_tensor(m_out))

        for m in self.submodules:
            self.hooks.append(m.register_forward_hook(log_activations))
        return self

    def __exit__(self, *args, **kwargs):
        for hook in self.hooks:
            hook.remove()


if __name__ == "__main__":
    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(1, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
    )
    # iterating a Sequential yields its children, so this hooks every layer
    with ModelActivationPrinter(model, model):
        y = model(torch.zeros(1, 1))
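For reference (a sketch, not part of the gist itself), the same context manager can hook just a subset of submodules of a larger model by passing an explicit list; the resnet18 / Conv2d filter below is only illustrative:

import torch
import torchvision

# Sketch: print output stats for only the Conv2d layers of a torchvision ResNet-18
# (assumes a recent torchvision; any nn.Module works the same way).
resnet = torchvision.models.resnet18(weights=None)
conv_layers = [m for m in resnet.modules() if isinstance(m, torch.nn.Conv2d)]

with ModelActivationPrinter(resnet, conv_layers):
    resnet(torch.randn(1, 3, 224, 224))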
@madebyollin (Author) commented:
If you also want to plot small preview images of the activation maps:

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from IPython.display import display  # display() is already a builtin inside Jupyter notebooks

def summarize_tensor(x):
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"


class ModelActivationPrinter:
    def __init__(self, module, submodules_to_log):
        self.id_to_name = {
            id(module): str(name) for name, module in module.named_modules()
        }
        self.submodules = submodules_to_log
        self.hooks = []

    def __enter__(self, *args, **kwargs):
        def log_activations(m, m_in, m_out):
            label = self.id_to_name.get(id(m), "(unnamed)") + " output"
            if isinstance(m_out, (tuple, list)):
                m_out = m_out[0]
                label += "[0]"
            print(label.ljust(48) + summarize_tensor(m_out))
            if m_out.ndim == 4:
                # visualize the first three channels of a few batch items as a small image strip
                size = max(16, m_out.shape[-1])
                m_out_im = F.interpolate(m_out[:1024 // size, :3].div(m_out.abs().max() + 1e-5).add(0.5).clamp(0, 1), (size, size))
                display(TF.to_pil_image(torch.cat(tuple(m_out_im), -1)))

        for m in self.submodules:
            self.hooks.append(m.register_forward_hook(log_activations))
        return self

    def __exit__(self, *args, **kwargs):
        for hook in self.hooks:
            hook.remove()
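Usage is unchanged; for instance (a minimal sketch, assuming you are running inside a Jupyter notebook so display() can render the PIL images), a small convolutional stack produces one preview strip per hooked layer:

import torch

# Sketch: exercise the image previews with 4-D (N, C, H, W) activations.
conv_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, 3, padding=1),
)
with ModelActivationPrinter(conv_model, conv_model):
    conv_model(torch.randn(1, 3, 32, 32))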
[attached image: example output showing per-layer activation previews]
