""" This module provides mechanisms for tracing torch's module and function's execution, and output the trace into a json file. User needs to set environmental variable `LOG_TRAIN_SCHEDULE=1` to enable tracing. If not, no trace will be applied. Inside your module, create your module's tracer functions by using `get_trace_api`. You will get three functions: * `@trace_time(name)`: decorator to trace the execution of a function. ```python @trace_time("my_func") def my_func(x): ... ``` * `@trace_autograd_function()`: decorator to trace the execution of forward and backward of a user defined `torch.autograd.Function` operator. ```python @trace_autograd_function() class MyAutogradFunction(torch.autograd.Function): @staticmethod def forward(ctx, x): ... @staticmethod def backward(ctx, grad_output): ... ``` * `register_module_trace()`: function to register trace a model (`nn.Module`), it applies traces recursively to a torch model by enumerating all nn.Module and register tracer to their forward and backward function. Only applying to top level nn.Module is recommended. ```python model = Model() register_module_trace(model) ``` """ import os import json from contextlib import contextmanager from dataclasses import dataclass, asdict from datetime import datetime from functools import partial import torch import torch.distributed MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100} # pylint: disable=missing-docstring @dataclass class TraceEntry: name: str cat: str pid: int tid: int ts: int ph: str args: str = None def to_json_str(self): d = asdict(self) if self.args is None: d.pop("args") return json.dumps(d, separators=(",", ": ")) class LogFiles: def __init__(self) -> None: self.loggers = {} def get(self, file_prefix, rank, pid): os.makedirs("log", exist_ok=True) fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt" if not fpath in self.loggers: self.loggers[fpath] = open(fpath, "w") return self.loggers[fpath] def close(self): for f in self.loggers.values(): f.close() def __del__(self): self.close() def trace_logger_enabled() -> bool: return ( "LOG_TRAIN_SCHEDULE" in os.environ and os.environ["LOG_TRAIN_SCHEDULE"] == "1" ) class TraceLogger: _log_files = LogFiles() def __init__(self, category, tid=None, file_prefix=None) -> None: self.enabled = trace_logger_enabled() if self.enabled: self.pid = os.getpid() self.logger = None self.cat = category self._traces = {} self.global_rank = 0 if tid is None: self.tid = MODULE_TID.get(category, 1000) else: self.tid = tid self.file_prefix = file_prefix if file_prefix is not None else self.cat self.registered_modules = [] def _creat_logger(self) -> None: # delay creating logger file until first log call, # since torch.distributed may not be ready yet if torch.distributed.is_initialized(): self.global_rank = torch.distributed.get_rank() self.logger = TraceLogger._log_files.get( self.file_prefix, self.global_rank, self.pid ) def begin_trace(self, name, memory=False) -> None: if not self.enabled: return if self.logger is None: self._creat_logger() assert self.logger is not None name = f"{name}" # convert it to str to ensure json serializable start_time = int(datetime.now().timestamp() * 1e6) # in us trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B") mem_trace = self._get_memory(start_time) if memory else None if name not in self._traces: self._traces[name] = [(trace, mem_trace)] else: # in case call to the function is nested self._traces[name].append((trace, mem_trace)) def end_trace(self, name, flush=False, memory=False) -> 
    def end_trace(self, name, flush=False, memory=False) -> None:
        if not self.enabled:
            return
        name = f"{name}"  # convert to str to ensure it is json serializable
        assert self.logger is not None, "begin_trace should be called before end_trace"
        assert self._traces.get(name), "begin_trace should be called before end_trace"
        start_trace, start_mem = self._traces[name].pop()
        if start_mem is not None:
            self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
        self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")
        end_time = int(datetime.now().timestamp() * 1e6)  # in us
        args = {"value(us)": end_time - start_trace.ts}
        trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
        self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")
        if memory:
            mem_trace = self._get_memory(end_time)
            self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")
        if flush:
            self.flush()

    def flush(self) -> None:
        if self.logger is not None:
            self.logger.flush()

    def _get_memory(self, timestamp):
        args = {"value": torch.vacc.memory_allocated(self.global_rank)}
        mem_trace = TraceEntry(
            "memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
        )
        return mem_trace


@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    if not logger_inst.enabled:
        yield
        return
    logger_inst.begin_trace(name, memory=memory)
    try:
        yield
    finally:
        logger_inst.end_trace(name, memory=memory, flush=flush)


SKIPED_MODULES = []


def _register_module_trace(
    module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
    if not logger_inst.enabled:
        return
    if not isinstance(module, torch.nn.Module):
        return

    def _register(m):
        module_name = f"{type(m).__name__}"
        if module_name == "WrapName":
            module_name = f"{type(m.forward_func.__self__).__name__}"
        if module_name in SKIPED_MODULES:
            return
        forward_name = module_name + ".forward"
        m.register_forward_pre_hook(
            lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
        )
        m.register_forward_hook(
            lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
        )
        if not forward_only:
            backward_name = module_name + ".backward"
            m.register_full_backward_pre_hook(
                lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
            )
            m.register_full_backward_hook(
                lambda m, grad_in, grad_out: logger_inst.end_trace(
                    backward_name, memory=True, flush=flush
                )
            )

    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered, "
                "skip applying trace on the same module multiple times."
            )
            continue
        _register(m)
        # remember the module so repeated calls do not double-register hooks
        logger_inst.registered_modules.append(m)


def _trace_autograd_function(logger_inst):
    def decorator(cls):
        if not issubclass(cls, torch.autograd.Function):
            return cls

        def _apply(name, method):
            def wrapper(*args, **kwargs):
                with _trace_time(name, logger_inst=logger_inst, memory=True):
                    result = method(*args, **kwargs)
                return result

            return wrapper

        for attr in ["forward", "backward"]:
            setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
        return cls

    return decorator
def _register_optimizer_trace(
    optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
    if not logger_inst.enabled:
        return
    trace_name = f"{type(optimizer).__name__}.step"
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.register_step_pre_hook(
            lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
        )
        optimizer.register_step_post_hook(
            lambda m, *args, **kwargs: logger_inst.end_trace(
                trace_name, memory=True, flush=flush
            )
        )
    elif hasattr(optimizer, "step") and callable(optimizer.step):
        # customized optimizers do not have step hooks, so wrap step() directly
        original_step = optimizer.step

        def traced_step(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_step(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # replace the step method with the traced wrapper
        optimizer.step = traced_step
    else:
        # unknown optimizer or wrong instance passed to this function
        pass

    if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
        # use a distinct name so the step closures above keep tracing ".step"
        reduce_trace_name = f"{type(optimizer).__name__}.reduce_gradients"
        original_reduce = optimizer.reduce_gradients

        def traced_reduce(*args, **kwargs):
            logger_inst.begin_trace(reduce_trace_name, memory=True)
            result = original_reduce(*args, **kwargs)
            logger_inst.end_trace(reduce_trace_name, memory=True, flush=flush)
            return result

        # replace the reduce_gradients method with the traced wrapper
        optimizer.reduce_gradients = traced_reduce


def get_trace_api(name="nn.Module"):
    """Generate module execution trace APIs for a given module name.

    Args:
        name (str): module name

    Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
            register_optimizer_trace)

    Usage of these functions is described in the docstring of this module.
    """
    _trace_logger = TraceLogger(name)
    return (
        partial(_trace_time, logger_inst=_trace_logger),
        partial(_register_module_trace, logger_inst=_trace_logger),
        partial(_trace_autograd_function, logger_inst=_trace_logger),
        partial(_register_optimizer_trace, logger_inst=_trace_logger),
    )
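

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the public API): how a consumer
# typically wires up the four functions returned by `get_trace_api`. The tiny
# model, optimizer, and loop below are placeholder assumptions for demonstration.
# Traces are only emitted when LOG_TRAIN_SCHEDULE=1 is set, and the memory
# counters rely on the `torch.vacc` backend used by `_get_memory` above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    (
        trace_time,
        register_module_trace,
        trace_autograd_function,
        register_optimizer_trace,
    ) = get_trace_api("nn.Module")

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    register_module_trace(model)         # hooks forward/backward of every sub-module
    register_optimizer_trace(optimizer)  # hooks optimizer.step()

    @trace_time("train_step")
    def train_step(batch):
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        return loss

    for _ in range(2):
        train_step(torch.randn(2, 4))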