from functools import partial
from datetime import datetime
from typing import Dict, List, Tuple, Type, Union
import time

import torch
import torch.distributed

# Start timestamps (microseconds) of hooked module calls still in flight,
# keyed by (id(module), phase) where phase is "forward" or "backward".
# Each value is a stack so that nested / re-entrant calls of the same module
# instance are each matched with their own start time.
_module_time: Dict[Tuple[int, str], List[int]] = {}


def print_module_time(
    model: torch.nn.Module,
    module: Union[Tuple[Type[torch.nn.Module], ...], Type[torch.nn.Module]],
) -> "List[torch.utils.hooks.RemovableHandle]":
    """Print per-call forward and backward timings for selected submodules.

    Walks every submodule of ``model`` (including ``model`` itself). For each
    one that is an instance of ``module``, it registers:

    * a forward pre-hook and a forward hook, and
    * a full backward pre-hook and a full backward hook.

    On every forward / backward pass, each hooked module prints a line of the
    form ``"<ClassName>.forward: <N> us"`` or ``"<ClassName>.backward: <N> us"``.

    NOTE(review): on CUDA, kernels are launched asynchronously, so without an
    explicit synchronize these numbers presumably reflect host-side time
    rather than device execution time.

    Args:
        model: Root module whose submodules are inspected.
        module: A module class, or a tuple of classes, as accepted by
            ``isinstance``.

    Returns:
        The hook handles that were registered. Call ``.remove()`` on each one
        to stop timing.
    """

    def now_as_us() -> int:
        # Monotonic clock: unlike wall-clock datetime.now(), elapsed values
        # cannot jump or go negative if the system time is adjusted.
        return time.perf_counter_ns() // 1000  # in us

    def _pre_forward(suffix, m, inputs):
        # Also used as a full backward pre-hook: (module, grad_output).
        _module_time.setdefault((id(m), suffix), []).append(now_as_us())

    def _post_forward(suffix, m, inputs, outputs):
        # Also used as a full backward hook: (module, grad_input, grad_output).
        end_time = now_as_us()
        key = (id(m), suffix)
        starts = _module_time.get(key)
        if not starts:
            # No matching start, e.g. the hooks were registered while this
            # call was already running. There is nothing meaningful to report.
            return
        start_time = starts.pop()
        if not starts:
            del _module_time[key]
        print(f"{type(m).__name__}.{suffix}: {end_time - start_time} us")

    handles = []
    for m in model.modules():
        if isinstance(m, module):
            handles.append(m.register_forward_pre_hook(partial(_pre_forward, "forward")))
            handles.append(m.register_forward_hook(partial(_post_forward, "forward")))
            handles.append(m.register_full_backward_pre_hook(partial(_pre_forward, "backward")))
            handles.append(m.register_full_backward_hook(partial(_post_forward, "backward")))
    return handles