32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
from functools import partial
|
|
from datetime import datetime
|
|
from typing import Union, Tuple
|
|
|
|
import torch
|
|
import torch.distributed
|
|
|
|
_module_time = {}
|
|
|
|
|
|
def print_module_time(
    model: torch.nn.Module,
    module: Union[type, Tuple[type, ...]],
) -> list:
    """Install hooks that print the forward/backward time of selected submodules.

    Every module yielded by ``model.modules()`` (the model itself and all of
    its descendants) that is an instance of ``module`` gets four hooks:
    forward pre/post and full-backward pre/post. Each pre-hook records a start
    time and the matching post-hook prints
    ``"<ClassName>.<forward|backward>: <elapsed> us"``.

    Args:
        model: Root module whose module tree is scanned.
        module: A class or tuple of classes, passed to ``isinstance`` to pick
            which modules are timed.

    Returns:
        The list of registered hook handles; call ``.remove()`` on each to
        stop timing. Callers that ignore the return value are unaffected.

    Notes:
        Times are measured on the host with a monotonic clock. Work queued on
        asynchronous devices (e.g. CUDA) may still be running when a post-hook
        fires, so such timings may reflect launch cost rather than execution
        time unless the device is synchronized.
    """
    import time  # local import keeps the file-level import block unchanged

    def now_as_us():
        # perf_counter_ns is monotonic, unlike datetime.now(), so intervals are
        # not distorted by system clock adjustments.
        return time.perf_counter_ns() // 1000  # in us

    def _pre_forward(suffix, m, inputs):
        # Key by module *instance* and keep a stack of start times, so nested
        # or repeatedly-hooked modules of the same class cannot overwrite each
        # other's start time (which previously caused a KeyError on pop).
        _module_time.setdefault((id(m), suffix), []).append(now_as_us())

    def _post_forward(suffix, m, inputs, outputs):
        key = (id(m), suffix)
        starts = _module_time.get(key)
        if not starts:
            # Post-hook with no matching pre-hook (e.g. hooks installed midway
            # through a pass): nothing to time, so don't crash the model.
            return
        start_time = starts.pop()
        if not starts:
            # Drop empty stacks so the shared dict does not grow without bound.
            del _module_time[key]
        print(f"{type(m).__name__}.{suffix}: {now_as_us() - start_time} us")

    handles = []
    for m in model.modules():
        if not isinstance(m, module):
            continue
        handles.append(m.register_forward_pre_hook(partial(_pre_forward, "forward")))
        handles.append(m.register_forward_hook(partial(_post_forward, "forward")))
        handles.append(
            m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
        )
        handles.append(
            m.register_full_backward_hook(partial(_post_forward, "backward"))
        )
    return handles
|