ShoufaChen / model_activation_printer.py
Created September 25, 2023 09:02 — forked from madebyollin/model_activation_printer.py
Helper for logging output activation-map statistics for a PyTorch module, using forward hooks
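
To make the forward-hook idea concrete, here is a minimal, self-contained sketch. It is not the gist's own implementation: `log_output_stats` and `records` are hypothetical names chosen for illustration, and the sketch assumes every module of interest returns a single tensor. It relies only on `nn.Module.register_forward_hook`, `named_modules`, and `modules` from PyTorch.

```python
import torch
from torch import nn


def log_output_stats(model, records):
    """Hypothetical helper: hook every submodule, return the hook handles."""
    # Hooks receive the module object, so map id() back to a readable name.
    id_to_name = {id(m): name for name, m in model.named_modules()}

    def hook(module, inputs, output):
        # Only plain tensor outputs are summarized in this sketch.
        if isinstance(output, torch.Tensor):
            out = output.detach().float()
            records.append(
                (id_to_name[id(module)], tuple(out.shape),
                 out.min().item(), out.mean().item(), out.max().item())
            )

    return [m.register_forward_hook(hook) for m in model.modules()]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
records = []
handles = log_output_stats(model, records)
with torch.no_grad():
    model(torch.randn(3, 4))
for handle in handles:
    handle.remove()  # detach hooks so later forward passes are not logged

for name, shape, lo, mean, hi in records:
    print(f"{name or '<root>':8} {shape} min {lo:+.4f} / mean {mean:+.4f} / max {hi:+.4f}")
```

Each hook fires when its module finishes its forward pass, so child modules are recorded before the root container. Keeping the handles returned by `register_forward_hook` lets the caller remove the hooks once logging is no longer wanted.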
def summarize_tensor(x):
    # One-line, ANSI-colored summary of a tensor: shape, then min / mean / max.
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"


class ModelActivationPrinter:
    def __init__(self, module, submodules_to_log):
        # Map each submodule's id() to its qualified name for later lookup.
        self.id_to_name = {
            id(submodule): str(name) for name, submodule in module.named_modules()
        }
        self.submodules = submodules_to_log