"""
|
|
This module provides mechanisms for tracing torch's module and function's execution,
|
|
and output the trace into a json file.
|
|
|
|
User needs to set environmental variable `LOG_TRAIN_SCHEDULE=1` to enable tracing.
|
|
If not, no trace will be applied.
|
|
|
|
Inside your module, create your module's tracer functions by using `get_trace_api`.
|
|
You will get three functions:
|
|

* `@trace_time(name)`: decorator to trace the execution of a function.

```python
@trace_time("my_func")
def my_func(x):
    ...
```

* `@trace_autograd_function()`: decorator to trace the execution of the forward
  and backward passes of a user-defined `torch.autograd.Function` operator.

```python
@trace_autograd_function()
class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ...

    @staticmethod
    def backward(ctx, grad_output):
        ...
```

* `register_module_trace()`: function to register tracing on a model (`nn.Module`).
  It applies traces recursively to a torch model by enumerating all of its
  `nn.Module` children and registering tracers on their forward and backward
  functions. Applying it only to the top-level `nn.Module` is recommended.

```python
model = Model()
register_module_trace(model)
```
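
* `register_optimizer_trace()`: function to register tracing on an optimizer's
  `step` method (and on its `reduce_gradients` method, if it has one), so that
  each optimizer step is traced. A minimal sketch, assuming `model` is defined
  as above and using `torch.optim.AdamW` purely as an example:

```python
optimizer = torch.optim.AdamW(model.parameters())
register_optimizer_trace(optimizer)
```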
"""

import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial

import torch
import torch.distributed


MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}

# pylint: disable=missing-docstring


@dataclass
class TraceEntry:
    # Fields follow the trace-event layout this module emits: "ph" is the event
    # phase ("B" begin, "E" end, "C" counter), "ts" is a timestamp in microseconds.
    name: str
    cat: str
    pid: int
    tid: int
    ts: int
    ph: str
    args: dict = None

    def to_json_str(self):
        d = asdict(self)
        if self.args is None:
            d.pop("args")
        return json.dumps(d, separators=(",", ": "))


class LogFiles:
    def __init__(self) -> None:
        self.loggers = {}

    def get(self, file_prefix, rank, pid):
        os.makedirs("log", exist_ok=True)
        fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
        if fpath not in self.loggers:
            self.loggers[fpath] = open(fpath, "w")
        return self.loggers[fpath]

    def close(self):
        for f in self.loggers.values():
            f.close()

    def __del__(self):
        self.close()


def trace_logger_enabled() -> bool:
    return os.environ.get("LOG_TRAIN_SCHEDULE") == "1"


class TraceLogger:
    _log_files = LogFiles()

    def __init__(self, category, tid=None, file_prefix=None) -> None:
        self.enabled = trace_logger_enabled()

        if self.enabled:
            self.pid = os.getpid()
            self.logger = None
            self.cat = category
            self._traces = {}
            self.global_rank = 0
            if tid is None:
                self.tid = MODULE_TID.get(category, 1000)
            else:
                self.tid = tid
            self.file_prefix = file_prefix if file_prefix is not None else self.cat

        self.registered_modules = []

    def _create_logger(self) -> None:
        # delay creating the logger file until the first log call,
        # since torch.distributed may not be ready yet at construction time;
        # fall back to rank 0 if torch.distributed is not initialized
        if torch.distributed.is_initialized():
            self.global_rank = torch.distributed.get_rank()
        self.logger = TraceLogger._log_files.get(
            self.file_prefix, self.global_rank, self.pid
        )

    def begin_trace(self, name, memory=False) -> None:
        if not self.enabled:
            return

        if self.logger is None:
            self._create_logger()
        assert self.logger is not None

        name = f"{name}"  # convert it to str to ensure json serializable

        start_time = int(datetime.now().timestamp() * 1e6)  # in us
        trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")

        mem_trace = self._get_memory(start_time) if memory else None

        if name not in self._traces:
            self._traces[name] = [(trace, mem_trace)]
        else:  # in case calls to the function are nested
            self._traces[name].append((trace, mem_trace))

    def end_trace(self, name, flush=False, memory=False) -> None:
        if not self.enabled:
            return

        name = f"{name}"  # convert it to str to ensure json serializable

        assert self.logger is not None, "begin_trace should be called before end_trace"
        assert name in self._traces, "begin_trace should be called before end_trace"

        start_trace, start_mem = self._traces[name].pop()
        if start_mem is not None:
            self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
        self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")

        end_time = int(datetime.now().timestamp() * 1e6)  # in us
        args = {"value(us)": end_time - start_trace.ts}
        trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
        self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")

        if memory:
            mem_trace = self._get_memory(end_time)
            self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")

        if flush:
            self.flush()

    def flush(self) -> None:
        if self.logger is not None:
            self.logger.flush()

    def _get_memory(self, timestamp):
        args = {"value": torch.vacc.memory_allocated(self.global_rank)}
        mem_trace = TraceEntry(
            "memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
        )
        return mem_trace


@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    if not logger_inst.enabled:
        yield
        return

    logger_inst.begin_trace(name, memory=memory)
    yield
    logger_inst.end_trace(name, flush=flush, memory=memory)


SKIPPED_MODULES = []


def _register_module_trace(
    module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
    if not logger_inst.enabled:
        return

    if not isinstance(module, torch.nn.Module):
        return

    def _register(m):
        module_name = f"{type(m).__name__}"
        if module_name == "WrapName":
            module_name = f"{type(m.forward_func.__self__).__name__}"

        if module_name in SKIPPED_MODULES:
            return

        forward_name = module_name + ".forward"

        m.register_forward_pre_hook(
            lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
        )
        m.register_forward_hook(
            lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
        )

        if not forward_only:
            backward_name = module_name + ".backward"

            m.register_full_backward_pre_hook(
                lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
            )
            m.register_full_backward_hook(
                lambda m, grad_in, grad_out: logger_inst.end_trace(
                    backward_name, memory=True, flush=flush
                )
            )

    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered, skipping applying the trace to the same module multiple times."
            )
            continue
        _register(m)
        logger_inst.registered_modules.append(m)


def _trace_autograd_function(logger_inst):
    def decorator(cls):
        if not issubclass(cls, torch.autograd.Function):
            return cls

        def _apply(name, method):
            def wrapper(*args, **kwargs):
                with _trace_time(name, logger_inst=logger_inst, memory=True):
                    result = method(*args, **kwargs)
                return result

            return wrapper

        for attr in ["forward", "backward"]:
            setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
        return cls

    return decorator


def _register_optimizer_trace(
    optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
    if not logger_inst.enabled:
        return

    trace_name = f"{type(optimizer).__name__}.step"

    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.register_step_pre_hook(
            lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
        )
        optimizer.register_step_post_hook(
            lambda m, *args, **kwargs: logger_inst.end_trace(
                trace_name, memory=True, flush=flush
            )
        )
    elif hasattr(optimizer, "step") and callable(optimizer.step):
        # a customized optimizer does not have step hooks,
        # so wrap its step method directly

        original_step = optimizer.step

        def traced_step(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_step(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # Replace the step method with the traced wrapper
        optimizer.step = traced_step
    else:
        # unknown optimizer or wrong instance passed to this function
        pass

    if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
        # use a separate name so the step hooks above keep logging "<Optimizer>.step"
        reduce_trace_name = f"{type(optimizer).__name__}.reduce_gradients"
        original_reduce = optimizer.reduce_gradients

        def traced_reduce(*args, **kwargs):
            logger_inst.begin_trace(reduce_trace_name, memory=True)
            result = original_reduce(*args, **kwargs)
            logger_inst.end_trace(reduce_trace_name, memory=True, flush=flush)
            return result

        # Replace the reduce_gradients method with the traced wrapper
        optimizer.reduce_gradients = traced_reduce


def get_trace_api(name="nn.Module"):
    """Generate module execution trace APIs for a given module name.

    Args:
        name (str): module name

    Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
            register_optimizer_trace)
        Usage of these four functions is described in the docstring of this module.
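
    Example (a minimal sketch; `model`, `optimizer`, and `batch` are placeholders):

        (
            trace_time,
            register_module_trace,
            trace_autograd_function,
            register_optimizer_trace,
        ) = get_trace_api("my_module")

        register_module_trace(model)
        register_optimizer_trace(optimizer)

        # trace_time works both as a decorator and as a context manager
        with trace_time("train_step"):
            loss = model(batch)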
    """
    _trace_logger = TraceLogger(name)

    return (
        partial(_trace_time, logger_inst=_trace_logger),
        partial(_register_module_trace, logger_inst=_trace_logger),
        partial(_trace_autograd_function, logger_inst=_trace_logger),
        partial(_register_optimizer_trace, logger_inst=_trace_logger),
    )