init
This commit is contained in:
31
vacc_tools/__init__.py
Normal file
31
vacc_tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from functools import partial
|
||||
from datetime import datetime
|
||||
from typing import Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
_module_time = {}
|
||||
|
||||
|
||||
def print_module_time(
|
||||
model: torch.nn.Module, module: Union[Tuple[torch.nn.Module], torch.nn.Module]
|
||||
):
|
||||
def now_as_us():
|
||||
return int(datetime.now().timestamp() * 1e6) # in us
|
||||
|
||||
def _pre_forward(suffix, m, inputs):
|
||||
name = f"{type(m).__name__}.{suffix}"
|
||||
_module_time[name] = now_as_us()
|
||||
|
||||
def _post_forward(suffix, m, inputs, outputs):
|
||||
name = f"{type(m).__name__}.{suffix}"
|
||||
start_time = _module_time.pop(name)
|
||||
print(f"{name}: {now_as_us() - start_time} us")
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, module):
|
||||
m.register_forward_pre_hook(partial(_pre_forward, "forward"))
|
||||
m.register_forward_hook(partial(_post_forward, "forward"))
|
||||
m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
|
||||
m.register_full_backward_hook(partial(_post_forward, "backward"))
|
||||
Reference in New Issue
Block a user