@kabouzeid
Created August 16, 2024 16:48
# Copyright (c) Karim Abou Zeid
from typing import Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override


class TimeMonitor(Callback):
    """Logs per-batch timing (batch time and dataloader time) from the SimpleProfiler."""

    def __init__(self, prog_bar: bool = True) -> None:
        super().__init__()
        self.prog_bar = prog_bar

    @override
    def setup(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        stage: str,
    ) -> None:
        if stage != "fit":
            return
        # The callback reads timings out of the SimpleProfiler's records,
        # so it cannot work with any other (or no) profiler.
        if not isinstance(trainer.profiler, pl.profilers.SimpleProfiler):
            raise MisconfigurationException(
                "Cannot use `TimeMonitor` callback without `Trainer(profiler='simple')`."
            )

    @override
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        # The SimpleProfiler appends one duration per profiled action;
        # the last entry corresponds to the batch that just finished.
        pl_module.log_dict(
            {
                "batch_time": trainer.profiler.recorded_durations[
                    "run_training_batch"
                ][-1],
                "data_time": trainer.profiler.recorded_durations[
                    "[_TrainingEpochLoop].train_dataloader_next"
                ][-1],
            },
            on_step=True,
            on_epoch=False,
            prog_bar=self.prog_bar,
        )
kabouzeid commented Aug 16, 2024

Add TimeMonitor to your PyTorch Lightning callbacks and add profiler='simple' to your Trainer init args.
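For reference, a minimal usage sketch along those lines (my_lightning_module and my_train_loader are hypothetical placeholders, not part of the gist):

    import lightning.pytorch as pl

    trainer = pl.Trainer(
        profiler="simple",          # required: TimeMonitor reads the SimpleProfiler's records
        callbacks=[TimeMonitor()],  # logs batch_time and data_time each training step
    )
    # my_lightning_module / my_train_loader: substitute your own module and dataloader
    trainer.fit(my_lightning_module, train_dataloaders=my_train_loader)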
