"""
This module provides mechanisms for tracing torch's module and function's execution,
and output the trace into a json file.
User needs to set environmental variable `LOG_TRAIN_SCHEDULE=1` to enable tracing.
If not, no trace will be applied.
Inside your module, create your module's tracer functions by using `get_trace_api`.
You will get three functions:
* `@trace_time(name)`: decorator to trace the execution of a function.
```python
@trace_time("my_func")
def my_func(x):
...
```
* `@trace_autograd_function()`: decorator to trace the execution of forward
and backward of a user defined `torch.autograd.Function` operator.
```python
@trace_autograd_function()
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
...
@staticmethod
def backward(ctx, grad_output):
...
```
* `register_module_trace()`: function to register tracing on a model
(`nn.Module`). It applies traces recursively by enumerating every sub-`nn.Module`
and registering tracers on their forward and backward passes. Applying it only
to the top-level `nn.Module` is recommended.
```python
model = Model()
register_module_trace(model)
```
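* `register_optimizer_trace()`: function to register tracing on an optimizer's
`step` call (and on `reduce_gradients`, when the optimizer defines it). Built-in
`torch.optim.Optimizer` instances are traced through step hooks; custom
optimizers have their `step` method wrapped instead.
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
register_optimizer_trace(optimizer)
```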
"""
import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial
import torch
import torch.distributed
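# Trace entries follow the Chrome Trace Event Format ("B"/"E" duration events,
# "C" counter events). MODULE_TID assigns each trace category a tid "lane" so a
# trace viewer groups its events; "ram" is reserved for the memory counter.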
MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}
# pylint: disable=missing-docstring
@dataclass
class TraceEntry:
name: str
cat: str
pid: int
tid: int
ts: int
ph: str
    args: dict = None
def to_json_str(self):
d = asdict(self)
if self.args is None:
d.pop("args")
return json.dumps(d, separators=(",", ": "))
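# Each event is emitted as one "LOG_TRACE:"-prefixed, comma-terminated JSON
# object per line, e.g. (with the separators used above):
# LOG_TRACE:{"name": "Linear.forward","cat": "nn.Module","pid": 123,"tid": 3,"ts": 1712030100000000,"ph": "B"},
# presumably so a post-processing step can collect the lines into a JSON array.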
class LogFiles:
def __init__(self) -> None:
self.loggers = {}
def get(self, file_prefix, rank, pid):
os.makedirs("log", exist_ok=True)
fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
        if fpath not in self.loggers:
self.loggers[fpath] = open(fpath, "w")
return self.loggers[fpath]
def close(self):
for f in self.loggers.values():
f.close()
def __del__(self):
self.close()
def trace_logger_enabled() -> bool:
    return os.environ.get("LOG_TRAIN_SCHEDULE") == "1"
class TraceLogger:
_log_files = LogFiles()
    def __init__(self, category, tid=None, file_prefix=None) -> None:
        self.enabled = trace_logger_enabled()
        # always define attributes so flush() and friends are safe to call
        # even when tracing is disabled
        self.pid = os.getpid()
        self.logger = None
        self.cat = category
        self._traces = {}
        self.global_rank = 0
        self.tid = tid if tid is not None else MODULE_TID.get(category, 1000)
        self.file_prefix = file_prefix if file_prefix is not None else self.cat
        self.registered_modules = []
    def _create_logger(self) -> None:
# delay creating logger file until first log call,
# since torch.distributed may not be ready yet
if torch.distributed.is_initialized():
self.global_rank = torch.distributed.get_rank()
self.logger = TraceLogger._log_files.get(
self.file_prefix, self.global_rank, self.pid
)
def begin_trace(self, name, memory=False) -> None:
if not self.enabled:
return
if self.logger is None:
            self._create_logger()
assert self.logger is not None
name = f"{name}" # convert it to str to ensure json serializable
start_time = int(datetime.now().timestamp() * 1e6) # in us
trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")
mem_trace = self._get_memory(start_time) if memory else None
        # keep a stack per name, in case calls to the function are nested
        self._traces.setdefault(name, []).append((trace, mem_trace))
def end_trace(self, name, flush=False, memory=False) -> None:
if not self.enabled:
return
name = f"{name}" # convert it to str to ensure json serializable
assert self.logger is not None, "begin_trace should be called before end_trace"
assert name in self._traces, "begin_trace should be called before end_trace"
start_trace, start_mem = self._traces[name].pop()
if start_mem is not None:
self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")
end_time = int(datetime.now().timestamp() * 1e6) # in us
args = {"value(us)": end_time - start_trace.ts}
trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")
if memory:
mem_trace = self._get_memory(end_time)
self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")
if flush:
self.flush()
def flush(self) -> None:
if self.logger is not None:
self.logger.flush()
def _get_memory(self, timestamp):
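        # NOTE: `torch.vacc` appears to be a vendor-specific accelerator
        # backend; on stock PyTorch the analogous query would be
        # `torch.cuda.memory_allocated`.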
args = {"value": torch.vacc.memory_allocated(self.global_rank)}
mem_trace = TraceEntry(
"memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
)
return mem_trace
@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    if not logger_inst.enabled:
        yield
        return
    logger_inst.begin_trace(name, memory=memory)
    try:
        yield
    finally:
        # close the trace even if the wrapped code raises
        logger_inst.end_trace(name, flush=flush, memory=memory)
SKIPPED_MODULES = []
def _register_module_trace(
module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
if not logger_inst.enabled:
return
if not isinstance(module, torch.nn.Module):
return
def _register(m):
module_name = f"{type(m).__name__}"
if module_name == "WrapName":
module_name = f"{type(m.forward_func.__self__).__name__}"
        if module_name in SKIPPED_MODULES:
return
forward_name = module_name + ".forward"
m.register_forward_pre_hook(
lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
)
m.register_forward_hook(
lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
)
if not forward_only:
backward_name = module_name + ".backward"
m.register_full_backward_pre_hook(
lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
)
m.register_full_backward_hook(
lambda m, grad_in, grad_out: logger_inst.end_trace(
backward_name, memory=True, flush=flush
)
)
    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered; skipping duplicate trace registration."
            )
            continue
        _register(m)
        logger_inst.registered_modules.append(m)
def _trace_autograd_function(logger_inst):
def decorator(cls):
if not issubclass(cls, torch.autograd.Function):
return cls
def _apply(name, method):
def wrapper(*args, **kwargs):
with _trace_time(name, logger_inst=logger_inst, memory=True):
result = method(*args, **kwargs)
return result
return wrapper
for attr in ["forward", "backward"]:
setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
return cls
return decorator
def _register_optimizer_trace(
optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
if not logger_inst.enabled:
return
trace_name = f"{type(optimizer).__name__}.step"
if isinstance(optimizer, torch.optim.Optimizer):
optimizer.register_step_pre_hook(
lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
)
optimizer.register_step_post_hook(
lambda m, *args, **kwargs: logger_inst.end_trace(
trace_name, memory=True, flush=flush
)
)
elif hasattr(optimizer, "step") and callable(optimizer.step):
        # customized optimizers do not have step hooks
original_step = optimizer.step
def traced_step(*args, **kwargs):
logger_inst.begin_trace(trace_name, memory=True)
result = original_step(*args, **kwargs)
logger_inst.end_trace(trace_name, memory=True, flush=flush)
return result
# Replace the step method with the new function
optimizer.step = traced_step
    else:
        # unknown optimizer (or a non-optimizer instance) was passed; nothing to trace.
        pass
if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
trace_name = f"{type(optimizer).__name__}.reduce_gradients"
original_reduce = optimizer.reduce_gradients
def traced_reduce(*args, **kwargs):
logger_inst.begin_trace(trace_name, memory=True)
result = original_reduce(*args, **kwargs)
logger_inst.end_trace(trace_name, memory=True, flush=flush)
return result
        # Replace the reduce_gradients method with the traced wrapper
optimizer.reduce_gradients = traced_reduce
def get_trace_api(name="nn.Module"):
"""generate module execution trace APIs for a given module name
Args:
name (str): module name
Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
            register_optimizer_trace)
    Usage of these four functions is described in the docstring of this module
"""
_trace_logger = TraceLogger(name)
return (
partial(_trace_time, logger_inst=_trace_logger),
partial(_register_module_trace, logger_inst=_trace_logger),
partial(_trace_autograd_function, logger_inst=_trace_logger),
partial(_register_optimizer_trace, logger_inst=_trace_logger),
)
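
# A minimal usage sketch (the `TinyNet` model below is hypothetical).
# Run with LOG_TRAIN_SCHEDULE=1 to write log/demo-rank-<rank>_<pid>.txt.
# Caveat: module/optimizer traces record memory via `torch.vacc`, so on stock
# PyTorch only the plain `trace_time` decorator is expected to work as-is.
if __name__ == "__main__":
    trace_time, register_module_trace, _, register_optimizer_trace = get_trace_api("demo")

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    model = TinyNet()
    register_module_trace(model)  # traces TinyNet.forward / TinyNet.backward
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    register_optimizer_trace(optimizer)  # traces SGD.step

    @trace_time("train_step")
    def train_step(batch):
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        return loss

    train_step(torch.randn(8, 4))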