This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

31
vacc_tools/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
from functools import partial
from datetime import datetime
from typing import Union, Tuple
import torch
import torch.distributed
# Maps "<ClassName>.<forward|backward>" -> stack of start timestamps (us).
# A stack (rather than one scalar) keeps timings correct when modules of the
# same class are nested, which would otherwise clobber the start time and
# make the second pop fail.
_module_time = {}


def print_module_time(
    model: torch.nn.Module, module: Union[Tuple[torch.nn.Module], torch.nn.Module]
):
    """Print wall-clock duration (in microseconds) of the forward and backward
    passes of every submodule of ``model`` matching ``module``.

    Args:
        model: root module whose submodules are scanned.
        module: a module class (or tuple of classes) matched with isinstance.
    """

    def now_as_us():
        # wall clock with microsecond resolution
        return int(datetime.now().timestamp() * 1e6)

    def _pre_forward(suffix, m, inputs):
        name = f"{type(m).__name__}.{suffix}"
        # push onto a per-name stack so nested calls do not overwrite each other
        _module_time.setdefault(name, []).append(now_as_us())

    def _post_forward(suffix, m, inputs, outputs):
        name = f"{type(m).__name__}.{suffix}"
        start_time = _module_time[name].pop()
        print(f"{name}: {now_as_us() - start_time} us")

    for name, m in model.named_modules():
        if isinstance(m, module):
            m.register_forward_pre_hook(partial(_pre_forward, "forward"))
            m.register_forward_hook(partial(_post_forward, "forward"))
            m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
            m.register_full_backward_hook(partial(_post_forward, "backward"))

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,214 @@
"""Generating tracing json files from log files.
Usage:
python -m vacc_tools.generate_trace --log-dir <directory of log files> --out-file-prefix <prefix of output file>
"""
import argparse
import json
import os
import re
import numpy as np
import tabulate
from glob import glob
from collections import defaultdict
from multiprocessing import Pool
def run_stats_on_traces(timelines):
    """Aggregate per-op duration statistics from raw trace lines.

    Args:
        timelines: iterable of JSON event lines, each still carrying the
            trailing comma plus newline from the log file.

    Returns:
        dict mapping category name to a formatted ``tabulate`` table with
        min/max/sum/avg/p90/count per op; a synthetic "VACC-ALL" table merges
        the ODSP/DLC/VCCL categories and adds a percent column.
    """
    op_cat_list = ["ODSP", "DLC", "VCCL", "CPU", "CPU_OP"]
    op_stats = {op: {} for op in op_cat_list}
    for line in timelines:
        if '"E"' not in line:  # fast path: only "E" (end) events carry durations
            continue
        line = line[:-2]  # strip trailing ',\n' (cheaper than replace())
        try:
            values = json.loads(line)
        except json.decoder.JSONDecodeError:
            # some logs may not end properly; skip malformed lines
            continue
        if values["ph"] == "E" and values["cat"] in op_cat_list:
            cat = values["cat"]
            name = values["name"]
            if name not in op_stats[cat]:
                op_stats[cat][name] = []
            if "dur" in values["args"]:
                # slice off the trailing "us" unit instead of replace()
                op_stats[cat][name].append(int(values["args"]["dur"][:-2]))
            elif "value(us)" in values["args"]:
                # BUG FIX: the trace logger writes the key "value(us)"; the
                # old check for "values(us)" never matched (and would have
                # raised KeyError on the lookup if it had).
                op_stats[cat][name].append(values["args"]["value(us)"])
    op_tables = {}
    for cat, stats in op_stats.items():
        table = []
        for name, dur in stats.items():
            dur = np.array(dur)
            t = [
                name,
                np.min(dur),
                np.max(dur),
                np.sum(dur),
                np.mean(dur),
                np.percentile(dur, 90),
                len(dur),
            ]
            table.append(t)
        # most frequently executed ops first
        table = sorted(table, key=lambda x: x[-1], reverse=True)
        op_tables[cat] = tabulate.tabulate(
            table,
            headers=["op", "min", "max", "sum", "avg", "p90", "count"],
            tablefmt="plain",
        )
        if cat in ["VCCL", "ODSP", "DLC"]:
            # accumulate the device-side categories into one combined table
            op_tables["VACC-ALL"] = op_tables.get("VACC-ALL", []) + [
                t + [cat] for t in table
            ]
    total = sum([x[3] for x in op_tables["VACC-ALL"]])
    if total == 0:
        total = 1  # BUG FIX: avoid ZeroDivisionError when all durations are 0
    op_tables["VACC-ALL"] = [t + [t[3] / total * 100] for t in op_tables["VACC-ALL"]]
    op_tables["VACC-ALL"] = tabulate.tabulate(
        sorted(op_tables["VACC-ALL"], key=lambda x: x[-1], reverse=True),
        headers=["op", "min", "max", "sum", "avg", "p90", "count", "cat", "percent(%)"],
        tablefmt="plain",
    )
    return op_tables
def get_rank_info(files):
    """Return the rank parsed from the first file name containing ``rank-<N>``.

    Falls back to 0 when no file carries rank information.
    """
    pattern = re.compile(r"rank-(\d+)")
    for path in files:
        match = pattern.search(path)
        if match:
            return int(match.group(1))
    return 0
def extract_traces(arg):
    """Worker: merge the trace lines of one group of log files into one
    chrome://tracing JSON file plus a sibling ``.txt`` statistics file.

    Args:
        arg: tuple ``(files, target_file_path, group_name, trace_token)`` —
            packed into a single argument so it can be used with ``Pool.map``.
    """
    files, target_file_path, group_name, trace_token = arg
    # (tid, lane name) metadata entries shown as separate rows in the viewer
    entries = [
        (0, "scheduler"),
        (1, "megatron"),
        (2, "deepspeed"),
        (3, "nn.Module"),
        (10, "vacc-odsp"),
        (11, "vacc-dlc"),
        (12, "vacc-vccl"),
        (13, "vacc-cpu"),
        (14, "vacc-fallback"),
        (15, "vacc-ddr"),
        (20, "lib-vccl"),
    ]
    with open(target_file_path, "w", encoding="utf-8") as trace_file:
        trace_file.write("[")
        for tid, thread_name in entries:
            line = f'{{"cat":"__metadata","pid":{group_name},"tid":{tid},"ts":0,"ph":"M","name":"thread_name","args":{{"name":"{thread_name}"}}}},\n'
            trace_file.write(line)
        timelines = []
        for fpath in files:
            with open(fpath, "r", encoding="utf-8") as file:
                for line in file:
                    if trace_token in line:
                        # keep the content after the trace token
                        timelines.append(line.split(trace_token)[1])
                        try:
                            json.loads(timelines[-1][:-2])  # remove ',\n'
                        except json.decoder.JSONDecodeError:
                            # some logs may not end properly; chrome://tracing
                            # stops reading at the first malformed line, so
                            # any line that does not parse must be dropped
                            timelines.pop()
        if timelines:
            # BUG FIX: guard against empty input; timelines[-1] below raised
            # IndexError when no trace line was found in any file.
            for line in timelines[:-1]:
                trace_file.write(line)
            # fix JSON format by stripping the trailing comma of the last entry
            trace_file.write(timelines[-1].replace(",\n", "\n"))
        trace_file.write("]")
    op_stats = run_stats_on_traces(timelines)
    with open(
        target_file_path.replace(".json", ".txt"), "w", encoding="utf-8"
    ) as op_stats_file:
        for cat, tables in op_stats.items():
            op_stats_file.write(f"{cat}".center(80, "-") + "\n")
            op_stats_file.write(tables + "\n\n")
def merge_schedule(out_file_prefix):
    """Merge the scheduler events of all per-rank trace files into one
    ``<prefix>schedule.json``.

    Scheduler events (tid 0) are remapped so every rank shares pid 0 and each
    rank becomes its own tid, making ranks appear as parallel threads of a
    single process in chrome://tracing.
    """
    scheduler_data = []
    for file in glob(f"{out_file_prefix}*.json"):
        if file.endswith("schedule.json"):
            continue  # never merge a previous merge result into itself
        assert "rank" in file
        rank = file.split("rank_")[-1].split("_")[0]
        pid = None
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                # keep only scheduler events (tid 0), excluding metadata lines
                if '"tid":0,' in line and "__metadata" not in line:
                    if pid is None:
                        # original pid is constant per file; read it once
                        pid = line.split('"pid":')[1].split(",")[0]
                    line = line.replace(f'"pid":{pid}', '"pid":0')
                    line = line.replace('"tid":0,', f'"tid":{rank},')
                    scheduler_data.append(line)
    if not scheduler_data:
        # BUG FIX: scheduler_data[-1] below raised IndexError when no
        # scheduler events were found; emit nothing in that case.
        return
    out_file = f"{out_file_prefix}schedule.json"
    with open(out_file, "w", encoding="utf-8") as f:
        f.write("[\n")
        f.writelines(scheduler_data[:-1])
        # strip the trailing comma of the last entry to keep the JSON valid
        f.write(scheduler_data[-1].replace(",\n", "\n"))
        f.write("]\n")
def scan_and_generate_trace(args, trace_token):
    """Walk ``args.log_dir``, group non-empty log files by the trailing
    ``_<group>`` token in their name, and extract one trace file per group
    using a process pool.
    """
    grouped_files = defaultdict(list)
    for root, _dirs, files in os.walk(args.log_dir):
        for filename in files:
            fpath = os.path.join(root, filename)
            if os.path.getsize(fpath) == 0:
                continue  # skip empty logs
            # group name = text after the last '_', extension stripped;
            # [-1] (not [1]) avoids IndexError for names without '_'
            group_name = filename.rsplit("_", 1)[-1].split(".")[0]
            grouped_files[group_name].append(fpath)
    if not grouped_files:
        return  # BUG FIX: Pool(0) raises ValueError on an empty log dir
    pool_args = []
    for group_name, files in grouped_files.items():
        rank = get_rank_info(files)
        out_file = f"{args.out_file_prefix}rank_{rank}_{group_name}.json"
        pool_args.append((files, out_file, group_name, trace_token))
    with Pool(len(grouped_files)) as p:
        p.map(extract_traces, pool_args)
    if args.merge_schedule:
        merge_schedule(args.out_file_prefix)
if __name__ == "__main__":
    # token that marks trace lines inside raw log files
    TRACE_TOKEN = "LOG_TRACE:"
    # default log directory: <parent of this package>/log
    current_file_path = os.path.abspath(__file__)
    parent_directory = os.path.dirname(os.path.dirname(current_file_path))
    find_directory = os.path.join(parent_directory, "log")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-dir", default=find_directory, type=str, help="directory of log files"
    )
    # output files are named <prefix>rank_<rank>_<group>.json
    parser.add_argument("--out-file-prefix", default="timeline_", type=str)
    # additionally merge per-rank scheduler events into <prefix>schedule.json
    parser.add_argument("--merge-schedule", action="store_true")
    args = parser.parse_args()
    scan_and_generate_trace(args, TRACE_TOKEN)
    print("Scan and trace generation done!")

View File

@@ -0,0 +1,151 @@
from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, Tuple, List, Optional
import torch
NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3
class MemoryAnalyzer:
    """Approximate device-memory breakdown (weights / gradients / activations /
    optimizer state) for a torch model.

    Activations are tracked through torch's saved-tensors hooks while inside
    the ``record_activation`` context manager.
    """

    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """This memory usage analyzer will be mostly accurate only if you
        initialize it at the beginning and insert `get_memory_usage_in_gb` at
        the end of your forward pass.
        NOTE: It will have a negative impact if not properly used, as it stores
        activations of every nn.Module's forward function and relies on the
        user to reset it every time the forward pass ends.
        Limitations:
        1. does not work with customized operators
        2. does not work with functional operators
        3. it approximates activation as nn.Module.forward's output (if it's
        inside the graph and requires gradients), so it may not be exactly
        accurate.
        """
        self.model = model
        self.optimizer = optimizer
        # storage data pointers currently counted as activations
        self.activ_addrs = set()
        # running total of activation bytes (untyped-storage sizes)
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x):
        # An "activation" here is any non-CPU tensor that requires grad.
        # NOTE(review): this compares torch.device to the string "cpu" and
        # relies on torch.device.__eq__ accepting strings; x.device.type
        # would be the stricter form — confirm intended.
        return torch.is_tensor(x) and x.requires_grad and x.device != "cpu"

    def _get_weight_grads_addrs(self):
        # Storage addresses of all parameters and their existing gradients,
        # used to avoid double-counting weights/grads as activations.
        weights = set([p.untyped_storage().data_ptr() for p in self.model.parameters()])
        grads = set(
            [
                p.grad.untyped_storage().data_ptr()
                for p in self.model.parameters()
                if p.grad is not None
            ]
        )
        return weights.union(grads)

    def pack_hook(self):
        """Build the saved-tensors pack hook that counts newly saved activations."""

        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage is more accurate than using x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        """Build the saved-tensors unpack hook that releases counted activations."""

        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        """Context manager: count tensors saved for backward while inside it."""
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module):
        """Return total bytes of non-CPU parameters."""
        weights = [
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device != "cpu"
        ]
        return sum(weights)

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module):
        """Return total bytes of existing non-CPU gradients."""
        grads = [
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device != "cpu"
        ]
        return sum(grads)

    def _sum_activation_memory(self):
        # current activation byte count collected by the pack/unpack hooks
        return self.activ_memory

    def get_optimizer_state_memory(self):
        """Return bytes held by optimizer params and state.

        Only torch.optim.AdamW is supported; any other optimizer yields 0.
        """
        if isinstance(self.optimizer, torch.optim.AdamW):
            params = sum(
                [
                    p.nelement() * p.element_size()
                    for pg in self.optimizer.param_groups
                    for p in pg["params"]
                    if torch.is_tensor(p) and p.device != "cpu"
                ]
            )
            for state in self.optimizer.state.values():
                # per-parameter state tensors (e.g. exp_avg / exp_avg_sq)
                params += sum(
                    [
                        v.nelement() * v.element_size()
                        for k, v in state.items()
                        if torch.is_tensor(v) and v.device != "cpu"
                    ]
                )
            return params
        return 0

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        # (weight, gradient, activation, optimizer-state) in bytes
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        """Return a human-readable breakdown in GB."""
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        """Return a human-readable breakdown in MB."""
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )

View File

@@ -0,0 +1,65 @@
import argparse
import os
from collections import defaultdict
from multiprocessing import Pool
# Prefix that marks trace lines inside raw log files.
log_tag = "LOG_TRACE:"

# (tid, lane name) pairs emitted as chrome://tracing thread metadata.
tid_names = [
    (0, "module"),
    (1, "megatron"),
    (2, "deepspeed"),
    (10, "vacc-odsp"),
    (11, "vacc-dlc"),
    (12, "vacc-vccl"),
    (13, "vacc-cpu"),
    (14, "vacc-cpu_fallback"),
    (15, "vacc-ddr"),
    (20, "lib-vccl"),
]


def parse_files_of_process(args):
    """Extract all trace lines of one process into ``trace_<pid>.json``.

    Args:
        args: tuple ``(pid, in_files)`` — the process id string and the list
            of log files belonging to that process.
    """
    pid, in_files = args
    out_file = "trace_" + pid + ".json"
    with open(out_file, "w", encoding="utf-8") as sink:
        sink.write("[\n")
        # write one thread-name metadata line per lane first
        metadata = "\n".join(
            f'{{"name": "thread_name","ph": "M","pid": {pid},"tid": {tid},"args": {{"name": "{name}"}}}},'
            for tid, name in tid_names
        )
        sink.write(metadata)
        sink.write("\n")
        for path in in_files:
            with open(path, "r", encoding="utf-8") as source:
                for raw in source:
                    if log_tag not in raw:
                        continue
                    # keep only the payload after the trace token
                    sink.write(raw.split(log_tag, 1)[1].strip() + "\n")
        sink.write("]")
def parse_directory(directory):
    """Scan ``directory`` recursively for non-empty ``vacc*`` log files,
    group them by the trailing ``_<pid>`` token in their name, and parse each
    group in a separate worker process.
    """
    pro_files = defaultdict(list)
    for dirpath, _dirnames, filenames in os.walk(directory):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            if filename.startswith("vacc") and os.path.getsize(file_path) != 0:
                # pid token = text after the last '_', extension stripped;
                # [-1] (not [1]) avoids IndexError when the name has no '_'
                pid = filename.rsplit("_", 1)[-1].split(".")[0]
                pro_files[pid].append(file_path)
    if not pro_files:
        return  # nothing to parse
    with Pool() as p:
        p.map(parse_files_of_process, list(pro_files.items()))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="parse vacc log files and generate trace files"
    )
    # positional argument: root directory, scanned recursively
    parser.add_argument("directory", type=str, help="log directory to parse")
    args = parser.parse_args()
    parse_directory(args.directory)

329
vacc_tools/trace_logger.py Normal file
View File

@@ -0,0 +1,329 @@
"""
This module provides mechanisms for tracing torch's module and function's execution,
and output the trace into a json file.
User needs to set environmental variable `LOG_TRAIN_SCHEDULE=1` to enable tracing.
If not, no trace will be applied.
Inside your module, create your module's tracer functions by using `get_trace_api`.
You will get four functions (the fourth, an optimizer tracer, wraps `optimizer.step`):
* `@trace_time(name)`: decorator to trace the execution of a function.
```python
@trace_time("my_func")
def my_func(x):
...
```
* `@trace_autograd_function()`: decorator to trace the execution of forward
and backward of a user defined `torch.autograd.Function` operator.
```python
@trace_autograd_function()
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
...
@staticmethod
def backward(ctx, grad_output):
...
```
* `register_module_trace()`: function to register trace a model (`nn.Module`),
it applies traces recursively to a torch model by enumerating all nn.Module
and register tracer to their forward and backward function. Only applying to
top level nn.Module is recommended.
```python
model = Model()
register_module_trace(model)
```
"""
import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial
import torch
import torch.distributed
MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}
# pylint: disable=missing-docstring
@dataclass
class TraceEntry:
    """One chrome://tracing event.

    Fields mirror the Trace Event Format: name, cat(egory), pid, tid,
    ts (timestamp in microseconds), ph(ase: "B"/"E"/"M"/"C") and optional
    args payload.
    """

    name: str
    cat: str
    pid: int
    tid: int
    ts: int
    ph: str
    # extra event payload; omitted from the JSON when None
    args: "dict | None" = None

    def to_json_str(self):
        """Serialize to a compact single-line JSON object."""
        d = asdict(self)
        if self.args is None:
            d.pop("args")
        # BUG FIX: use fully compact separators. The previous
        # separators=(",", ": ") emitted '"tid": 0', which the compact
        # string matching in merge_schedule ('"tid":0,') could never find.
        return json.dumps(d, separators=(",", ":"))
class LogFiles:
    """Cache of open per-rank/per-pid trace log file handles, keyed by path."""

    def __init__(self) -> None:
        self.loggers = {}

    def get(self, file_prefix, rank, pid):
        """Return the (cached) writable handle for this prefix/rank/pid."""
        os.makedirs("log", exist_ok=True)
        fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
        if fpath not in self.loggers:
            # explicit encoding avoids locale-dependent output
            self.loggers[fpath] = open(fpath, "w", encoding="utf-8")
        return self.loggers[fpath]

    def close(self):
        """Close every cached handle."""
        for f in self.loggers.values():
            f.close()

    def __del__(self):
        # best-effort cleanup when the cache itself is collected
        self.close()
def trace_logger_enabled() -> bool:
    """Return True when the LOG_TRAIN_SCHEDULE environment variable is "1"."""
    return os.environ.get("LOG_TRAIN_SCHEDULE") == "1"
class TraceLogger:
    """Writes chrome://tracing events (prefixed with ``LOG_TRACE:``) to a
    per-rank log file.

    A logger is a no-op unless the LOG_TRAIN_SCHEDULE=1 environment variable
    is set when it is constructed.
    """

    # one shared file cache for all TraceLogger instances in this process
    _log_files = LogFiles()

    def __init__(self, category, tid=None, file_prefix=None) -> None:
        self.enabled = trace_logger_enabled()
        if self.enabled:
            self.pid = os.getpid()
            self.logger = None
            self.cat = category
            # name -> stack of (begin TraceEntry, memory TraceEntry or None);
            # a stack keeps nested calls to the same trace name balanced
            self._traces = {}
            self.global_rank = 0
            if tid is None:
                # unknown categories fall into a catch-all lane (tid 1000)
                self.tid = MODULE_TID.get(category, 1000)
            else:
                self.tid = tid
            self.file_prefix = file_prefix if file_prefix is not None else self.cat
            self.registered_modules = []

    def _creat_logger(self) -> None:
        # delay creating logger file until first log call,
        # since torch.distributed may not be ready yet
        if torch.distributed.is_initialized():
            self.global_rank = torch.distributed.get_rank()
        self.logger = TraceLogger._log_files.get(
            self.file_prefix, self.global_rank, self.pid
        )

    def begin_trace(self, name, memory=False) -> None:
        """Record the start of an event; pair with a later end_trace(name)."""
        if not self.enabled:
            return
        if self.logger is None:
            self._creat_logger()
        assert self.logger is not None
        name = f"{name}"  # convert it to str to ensure json serializable
        start_time = int(datetime.now().timestamp() * 1e6)  # in us
        trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")
        mem_trace = self._get_memory(start_time) if memory else None
        # the "B" entry is only buffered here and written by end_trace, so an
        # abandoned begin never leaves an unbalanced event in the log
        if name not in self._traces:
            self._traces[name] = [(trace, mem_trace)]
        else:  # in case call to the function is nested
            self._traces[name].append((trace, mem_trace))

    def end_trace(self, name, flush=False, memory=False) -> None:
        """Close the most recent begin_trace(name) and write both events."""
        if not self.enabled:
            return
        name = f"{name}"  # convert it to str to ensure json serializable
        assert self.logger is not None, "begin_trace should be called before end_trace"
        assert name in self._traces, "begin_trace should be called before end_trace"
        start_trace, start_mem = self._traces[name].pop()
        if start_mem is not None:
            self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
        self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")
        end_time = int(datetime.now().timestamp() * 1e6)  # in us
        # duration is attached to the "E" event for downstream statistics
        args = {"value(us)": end_time - start_trace.ts}
        trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
        self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")
        if memory:
            mem_trace = self._get_memory(end_time)
            self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")
        if flush:
            self.flush()

    def flush(self) -> None:
        # safe to call before the first trace created the file
        if self.logger is not None:
            self.logger.flush()

    def _get_memory(self, timestamp):
        # counter ("C") event carrying currently allocated device memory.
        # NOTE(review): torch.vacc is a vendor extension, not stock torch —
        # this only works on builds that provide it.
        args = {"value": torch.vacc.memory_allocated(self.global_rank)}
        mem_trace = TraceEntry(
            "memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
        )
        return mem_trace
@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    """Context manager that brackets the wrapped code with begin/end traces.

    Args:
        name: trace entry name.
        logger_inst: TraceLogger (or compatible) instance.
        memory: also record memory counters at both ends of the trace.
        flush: flush the log file when the trace ends.
    """
    if not logger_inst.enabled:
        yield
        return
    # BUG FIX: `memory` was previously ignored; forward it to both calls.
    logger_inst.begin_trace(name, memory=memory)
    try:
        yield
    finally:
        # BUG FIX: finally guarantees end_trace runs even when the body
        # raises, keeping begin/end pairs balanced in the output.
        logger_inst.end_trace(name, flush=flush, memory=memory)
# Module class names excluded from tracing ("SKIPED" spelling kept for
# backward compatibility with code that references this name).
SKIPED_MODULES = []


def _register_module_trace(
    module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
    """Attach begin/end trace hooks to ``module`` and all of its submodules.

    Args:
        module: root nn.Module; every submodule is hooked as well.
        logger_inst: TraceLogger (or compatible) instance.
        flush: flush the log after each backward trace ends.
        forward_only: skip registering backward hooks.
    """
    if not logger_inst.enabled:
        return
    if not isinstance(module, torch.nn.Module):
        return

    def _register(m):
        module_name = f"{type(m).__name__}"
        if module_name == "WrapName":
            # unwrap to the real module's class name for readable traces
            module_name = f"{type(m.forward_func.__self__).__name__}"
        if module_name in SKIPED_MODULES:
            return
        forward_name = module_name + ".forward"
        m.register_forward_pre_hook(
            lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
        )
        m.register_forward_hook(
            lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
        )
        if not forward_only:
            backward_name = module_name + ".backward"
            m.register_full_backward_pre_hook(
                lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
            )
            m.register_full_backward_hook(
                lambda m, grad_in, grad_out: logger_inst.end_trace(
                    backward_name, memory=True, flush=flush
                )
            )

    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered, skip applying trace on same module multiple times."
            )
            continue
        _register(m)
        # BUG FIX: record the module so the duplicate check above actually
        # prevents hooking the same module twice (it was never appended to).
        logger_inst.registered_modules.append(m)
def _trace_autograd_function(logger_inst):
    """Build a class decorator that traces the forward and backward methods
    of a user-defined ``torch.autograd.Function``.
    """

    def decorator(cls):
        # leave anything that is not an autograd.Function untouched
        if not issubclass(cls, torch.autograd.Function):
            return cls

        def _wrap(trace_name, fn):
            # wrap one static method in a timing scope
            def wrapper(*args, **kwargs):
                with _trace_time(trace_name, logger_inst=logger_inst, memory=True):
                    return fn(*args, **kwargs)

            return wrapper

        for method_name in ("forward", "backward"):
            traced = _wrap(f"{cls.__name__}.{method_name}", getattr(cls, method_name))
            setattr(cls, method_name, traced)
        return cls

    return decorator
def _register_optimizer_trace(
    optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
    """Trace ``optimizer.step`` (and ``reduce_gradients`` when present).

    Real torch optimizers are traced via the step pre/post hooks; any other
    object exposing a callable ``step`` has that method wrapped directly.
    ``reduce_gradients`` (if present and callable) is always wrapped directly.
    """
    if not logger_inst.enabled:
        return
    trace_name = f"{type(optimizer).__name__}.step"
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.register_step_pre_hook(
            lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
        )
        optimizer.register_step_post_hook(
            lambda m, *args, **kwargs: logger_inst.end_trace(
                trace_name, memory=True, flush=flush
            )
        )
    elif hasattr(optimizer, "step") and callable(optimizer.step):
        # customized optimizers do not have torch's step hooks
        original_step = optimizer.step

        def traced_step(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_step(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # Replace the step method with the traced wrapper
        optimizer.step = traced_step
    else:
        # unknown optimizer or wrong instance passed to this function
        pass
    if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
        trace_name = f"{type(optimizer).__name__}.reduce_gradients"
        original_reduce = optimizer.reduce_gradients

        def traced_reduce(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_reduce(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # Replace the reduce_gradients method with the traced wrapper
        optimizer.reduce_gradients = traced_reduce
def get_trace_api(name="nn.Module"):
    """generate module execution trace APIs for a given module name
    Args:
        name (str): module name
    Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
            register_optimizer_trace) — all bound to one shared TraceLogger
    Usage of the first three functions is described in the docstring of this
    module; the fourth traces ``optimizer.step`` / ``reduce_gradients``.
    """
    _trace_logger = TraceLogger(name)
    return (
        partial(_trace_time, logger_inst=_trace_logger),
        partial(_register_module_trace, logger_inst=_trace_logger),
        partial(_trace_autograd_function, logger_inst=_trace_logger),
        partial(_register_optimizer_trace, logger_inst=_trace_logger),
    )