Files
2026-04-02 04:55:00 +00:00

32 lines
1.0 KiB
Python

import time
from datetime import datetime
from functools import partial
from typing import Tuple, Type, Union

import torch
import torch.distributed
# Start timestamps (in microseconds) of hooked module calls still in progress,
# keyed by (id(module), phase) where phase is "forward" or "backward". Each
# value is a stack, so nested or re-entrant calls of modules of the same class
# do not overwrite each other's start time.
_module_time = {}


def print_module_time(
    model: torch.nn.Module,
    module: Union[Type[torch.nn.Module], Tuple[Type[torch.nn.Module], ...]],
):
    """Print how long the forward and backward passes of selected submodules take.

    Hooks are registered on every module yielded by ``model.modules()`` (which
    includes ``model`` itself) that is an instance of ``module``. Each time such
    a module finishes a forward or backward pass, a line of the form
    ``"<ClassName>.forward: <N> us"`` or ``"<ClassName>.backward: <N> us"`` is
    printed.

    NOTE: elapsed time is measured on the host with a monotonic clock and no
    device synchronization is performed, so for asynchronously executing
    devices (e.g. CUDA) the numbers may reflect launch overhead rather than
    actual execution time.

    Args:
        model: Root module whose submodules are inspected.
        module: A module class, or a tuple of classes, passed to ``isinstance``
            to choose which submodules are timed.

    Returns:
        The list of hook handles that were registered (four per matched
        module). Call ``.remove()`` on each handle to stop timing.
    """

    def _now_us():
        # perf_counter is monotonic, unlike datetime.now(), so durations stay
        # meaningful even if the system clock is adjusted mid-run.
        return time.perf_counter_ns() // 1000

    def _pre_hook(suffix, m, *_):
        # Push a start time; returning None leaves inputs / grads untouched.
        _module_time.setdefault((id(m), suffix), []).append(_now_us())

    def _post_hook(suffix, m, *_):
        key = (id(m), suffix)
        starts = _module_time.get(key)
        if not starts:
            # Unmatched post-hook (no recorded start): skip instead of raising
            # from inside the forward pass or autograd.
            return
        start_time = starts.pop()
        if not starts:
            # Drop empty stacks so the dict does not keep stale keys around.
            del _module_time[key]
        print(f"{type(m).__name__}.{suffix}: {_now_us() - start_time} us")

    handles = []
    for m in model.modules():
        if not isinstance(m, module):
            continue
        handles.append(m.register_forward_pre_hook(partial(_pre_hook, "forward")))
        handles.append(m.register_forward_hook(partial(_post_hook, "forward")))
        handles.append(
            m.register_full_backward_pre_hook(partial(_pre_hook, "backward"))
        )
        handles.append(m.register_full_backward_hook(partial(_post_hook, "backward")))
    return handles