init
31  vacc_tools/__init__.py  Normal file
@@ -0,0 +1,31 @@
from functools import partial
from datetime import datetime
from typing import Union, Tuple

import torch


_module_time = {}


def print_module_time(
    model: torch.nn.Module, module: Union[type, Tuple[type, ...]]
):
    """Print the forward and backward wall time of every submodule of `model`
    that is an instance of `module` (an nn.Module subclass, or a tuple of
    such classes)."""

    def now_as_us():
        return int(datetime.now().timestamp() * 1e6)  # in us

    def _pre_forward(suffix, m, inputs):
        name = f"{type(m).__name__}.{suffix}"
        _module_time[name] = now_as_us()

    def _post_forward(suffix, m, inputs, outputs):
        name = f"{type(m).__name__}.{suffix}"
        start_time = _module_time.pop(name)
        print(f"{name}: {now_as_us() - start_time} us")

    for _, m in model.named_modules():
        if isinstance(m, module):
            m.register_forward_pre_hook(partial(_pre_forward, "forward"))
            m.register_forward_hook(partial(_post_forward, "forward"))
            m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
            m.register_full_backward_hook(partial(_post_forward, "backward"))
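
A minimal usage sketch for `print_module_time` (not part of the commit; the toy model and tensor shapes are illustrative assumptions):

# Hypothetical usage: time every nn.Linear in a small model.
# Note: timings are keyed by class name, so overlapping executions of the
# same module class would collide; sequential execution is fine.
import torch
from vacc_tools import print_module_time

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 1))
print_module_time(model, torch.nn.Linear)

out = model(torch.randn(4, 8)).sum()
out.backward()
# prints lines like "Linear.forward: ... us" and "Linear.backward: ... us"
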
BIN  vacc_tools/__pycache__/__init__.cpython-312.pyc  Normal file
Binary file not shown.
BIN  vacc_tools/__pycache__/generate_trace.cpython-312.pyc  Normal file
Binary file not shown.
BIN  vacc_tools/__pycache__/memory_analyzer.cpython-312.pyc  Normal file
Binary file not shown.
BIN  vacc_tools/__pycache__/trace_logger.cpython-312.pyc  Normal file
Binary file not shown.
214  vacc_tools/generate_trace.py  Normal file
@@ -0,0 +1,214 @@
"""Generate tracing json files from log files.

Usage:
    python -m vacc_tools.generate_trace --log-dir <directory of log files> --out-file-prefix <prefix of output file>
"""

import argparse
import json
import os
import re
from collections import defaultdict
from glob import glob
from multiprocessing import Pool

import numpy as np
import tabulate


def run_stats_on_traces(timelines):
    op_cat_list = ["ODSP", "DLC", "VCCL", "CPU", "CPU_OP"]
    op_stats = {op: {} for op in op_cat_list}
    for line in timelines:
        if '"E"' not in line:  # optim 3: skip the line early if it has no `"E"` event
            continue

        # optim 2: use `[:-2]` instead of replace()
        line = line[:-2]  # remove ',\n'
        try:
            values = json.loads(line)
        except json.decoder.JSONDecodeError:
            # some logs may not end properly; just skip them
            continue

        if values["ph"] == "E" and values["cat"] in op_cat_list:
            cat = values["cat"]
            if values["name"] not in op_stats[cat]:
                op_stats[cat][values["name"]] = []
            if "dur" in values["args"]:
                # optim 1: use `[:-2]` instead of replace()
                op_stats[cat][values["name"]].append(
                    int(values["args"]["dur"][:-2])  # strip `us`
                )
            elif "value(us)" in values["args"]:
                op_stats[cat][values["name"]].append(values["args"]["value(us)"])

    op_tables = {}
    for cat, stats in op_stats.items():
        table = []
        for name, dur in stats.items():
            dur = np.array(dur)
            t = [
                name,
                np.min(dur),
                np.max(dur),
                np.sum(dur),
                np.mean(dur),
                np.percentile(dur, 90),
                len(dur),
            ]
            table.append(t)

        table = sorted(table, key=lambda x: x[-1], reverse=True)  # sort by count
        op_tables[cat] = tabulate.tabulate(
            table,
            headers=["op", "min", "max", "sum", "avg", "p90", "count"],
            tablefmt="plain",
        )

        if cat in ["VCCL", "ODSP", "DLC"]:
            op_tables["VACC-ALL"] = op_tables.get("VACC-ALL", []) + [
                t + [cat] for t in table
            ]

    total = sum(x[3] for x in op_tables["VACC-ALL"])
    op_tables["VACC-ALL"] = [t + [t[3] / total * 100] for t in op_tables["VACC-ALL"]]

    op_tables["VACC-ALL"] = tabulate.tabulate(
        sorted(op_tables["VACC-ALL"], key=lambda x: x[-1], reverse=True),
        headers=["op", "min", "max", "sum", "avg", "p90", "count", "cat", "percent(%)"],
        tablefmt="plain",
    )

    return op_tables
def get_rank_info(files):
    # use the pattern rank-<rank> in the file name to get the rank
    for fpath in files:
        rank = re.findall(r"rank-(\d+)", fpath)
        if rank:
            return int(rank[0])
    return 0
def extract_traces(arg):
    files, target_file_path, group_name, trace_token = arg

    entries = [
        (0, "scheduler"),
        (1, "megatron"),
        (2, "deepspeed"),
        (3, "nn.Module"),
        (10, "vacc-odsp"),
        (11, "vacc-dlc"),
        (12, "vacc-vccl"),
        (13, "vacc-cpu"),
        (14, "vacc-fallback"),
        (15, "vacc-ddr"),
        (20, "lib-vccl"),
    ]

    with open(target_file_path, "w", encoding="utf-8") as trace_file:
        trace_file.write("[")
        for tid, thread_name in entries:
            line = f'{{"cat":"__metadata","pid":{group_name},"tid":{tid},"ts":0,"ph":"M","name":"thread_name","args":{{"name":"{thread_name}"}}}},\n'
            trace_file.write(line)

        timelines = []
        for fpath in files:
            with open(fpath, "r", encoding="utf-8") as file:
                for line in file:
                    if trace_token in line:
                        # trace token found: keep the content after it
                        timelines.append(line.split(trace_token)[1])
                        try:
                            json.loads(timelines[-1][:-2])  # remove ',\n'
                        except json.decoder.JSONDecodeError:
                            # some logs may not end properly; just skip them.
                            # chrome://tracing stops reading at the first error,
                            # so broken lines must be removed.
                            timelines.pop()

        for line in timelines[:-1]:
            trace_file.write(line)
        # fix the JSON format by removing the comma after the last list entry
        trace_file.write(timelines[-1].replace(",\n", "\n"))
        trace_file.write("]")

    op_stats = run_stats_on_traces(timelines)
    with open(
        target_file_path.replace(".json", ".txt"), "w", encoding="utf-8"
    ) as op_stats_file:
        for cat, tables in op_stats.items():
            op_stats_file.write(f"{cat}".center(80, "-") + "\n")
            op_stats_file.write(tables + "\n\n")
def merge_schedule(out_file_prefix):
    scheduler_data = []
    for file in glob(f"{out_file_prefix}*.json"):
        if file.endswith("schedule.json"):
            continue
        assert "rank" in file
        rank = file.split("rank_")[-1].split("_")[0]
        pid = None
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                # set every schedule entry's pid to 0 and its tid to the rank id
                if '"tid":0,' in line and "__metadata" not in line:
                    if pid is None:
                        pid = line.split('"pid":')[1].split(",")[0]

                    line = line.replace(f'"pid":{pid}', '"pid":0')
                    line = line.replace('"tid":0,', f'"tid":{rank},')
                    scheduler_data.append(line)

    out_file = f"{out_file_prefix}schedule.json"
    with open(out_file, "w", encoding="utf-8") as f:
        f.write("[\n")
        f.writelines(scheduler_data[:-1])
        f.write(scheduler_data[-1].replace(",\n", "\n"))
        f.write("]\n")
def scan_and_generate_trace(args, trace_token):
    grouped_files = defaultdict(list)
    for root, dirs, files in os.walk(args.log_dir):
        for filename in files:
            fpath = os.path.join(root, filename)
            file_size = os.path.getsize(fpath)
            if file_size != 0:
                group_name = filename.rsplit("_", 1)[1].split(".")[0]
                grouped_files[group_name].append(fpath)

    pool_args = []
    for group_name, files in grouped_files.items():
        rank = get_rank_info(files)
        out_file = f"{args.out_file_prefix}rank_{rank}_{group_name}.json"
        pool_args.append((files, out_file, group_name, trace_token))

    with Pool(len(grouped_files)) as p:
        p.map(extract_traces, pool_args)

    if args.merge_schedule:
        merge_schedule(args.out_file_prefix)


if __name__ == "__main__":
    TRACE_TOKEN = "LOG_TRACE:"

    current_file_path = os.path.abspath(__file__)
    parent_directory = os.path.dirname(os.path.dirname(current_file_path))
    find_directory = os.path.join(parent_directory, "log")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-dir", default=find_directory, type=str, help="directory of log files"
    )
    parser.add_argument("--out-file-prefix", default="timeline_", type=str)
    parser.add_argument("--merge-schedule", action="store_true")

    args = parser.parse_args()

    scan_and_generate_trace(args, TRACE_TOKEN)
    print("Scan and trace generation done!")
151  vacc_tools/memory_analyzer.py  Normal file
@@ -0,0 +1,151 @@
from contextlib import contextmanager
from typing import Tuple, Optional

import torch


NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3


class MemoryAnalyzer:
    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """This memory usage analyzer is mostly accurate only if you initialize
        it at the beginning and call `get_memory_usage_in_gb` at the end of your
        forward pass (a usage sketch follows this file).

        NOTE: It can have a negative impact if not used properly, as it tracks
        the activations of every nn.Module's forward function and relies on the
        user to reset it every time the forward pass ends.

        Limitations:
        1. does not work with customized operators
        2. does not work with functional operators
        3. it approximates an activation as an nn.Module.forward output (if it
           is inside the graph and requires gradients), so it may not be
           exactly accurate.
        """
        self.model = model
        self.optimizer = optimizer

        self.activ_addrs = set()
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x):
        return torch.is_tensor(x) and x.requires_grad and x.device != "cpu"
    def _get_weight_grads_addrs(self):
        weights = set(p.untyped_storage().data_ptr() for p in self.model.parameters())
        grads = set(
            p.grad.untyped_storage().data_ptr()
            for p in self.model.parameters()
            if p.grad is not None
        )
        return weights.union(grads)

    def pack_hook(self):
        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage size is more accurate than x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook
    @contextmanager
    def record_activation(self):
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module):
        weights = [
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device != "cpu"
        ]
        return sum(weights)

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module):
        grads = [
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device != "cpu"
        ]
        return sum(grads)

    def _sum_activation_memory(self):
        return self.activ_memory

    def get_optimizer_state_memory(self):
        if isinstance(self.optimizer, torch.optim.AdamW):
            params = sum(
                p.nelement() * p.element_size()
                for pg in self.optimizer.param_groups
                for p in pg["params"]
                if torch.is_tensor(p) and p.device != "cpu"
            )
            for state in self.optimizer.state.values():
                params += sum(
                    v.nelement() * v.element_size()
                    for k, v in state.items()
                    if torch.is_tensor(v) and v.device != "cpu"
                )
            return params
        return 0

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )
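
A hedged usage sketch for `MemoryAnalyzer` (not part of the commit; the model, optimizer, device handling, and the manual reset are illustrative assumptions, since the class has no dedicated reset method):

# Hypothetical usage: record activation memory for one forward pass and
# print the breakdown. Activations are only counted on non-CPU devices.
import torch
from vacc_tools.memory_analyzer import MemoryAnalyzer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.AdamW(model.parameters())
analyzer = MemoryAnalyzer(model, optimizer)

with analyzer.record_activation():
    loss = model(torch.randn(16, 1024, device=device)).sum()
print(analyzer.get_memory_usage_in_mb())

loss.backward()
optimizer.step()

# clear the recorded activations manually before the next forward pass
analyzer.activ_addrs.clear()
analyzer.activ_memory = 0
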
65  vacc_tools/parse_vacc_log_for_tracing.py  Normal file
@@ -0,0 +1,65 @@
import argparse
import os
from collections import defaultdict
from multiprocessing import Pool


log_tag = "LOG_TRACE:"
tid_names = [
    (0, "module"),
    (1, "megatron"),
    (2, "deepspeed"),
    (10, "vacc-odsp"),
    (11, "vacc-dlc"),
    (12, "vacc-vccl"),
    (13, "vacc-cpu"),
    (14, "vacc-cpu_fallback"),
    (15, "vacc-ddr"),
    (20, "lib-vccl"),
]


def parse_files_of_process(args):
    pid, in_files = args
    out_file = "trace_" + pid + ".json"
    with open(out_file, "w", encoding="utf-8") as new_file:
        metadata_lines = [
            f'{{"name": "thread_name","ph": "M","pid": {pid},"tid": {tid},"args": {{"name": "{name}"}}}},'
            for tid, name in tid_names
        ]
        new_file.write("[\n")
        new_file.write("\n".join(metadata_lines))
        new_file.write("\n")
        for file_path in in_files:
            with open(file_path, "r", encoding="utf-8") as file:
                for line in file:
                    if log_tag in line:
                        new_line = line.split(log_tag, 1)[1].strip()
                        new_file.write(new_line + "\n")
        new_file.write("]")


def parse_directory(directory):
    pro_files = defaultdict(list)
    for dirpath, dirnames, filenames in os.walk(directory):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            if filename.startswith("vacc") and os.path.getsize(file_path) != 0:
                pid = filename.rsplit("_", 1)[1].split(".")[0]
                pro_files[pid].append(file_path)

    args = []
    for pid, in_files in pro_files.items():
        args.append((pid, in_files))

    with Pool() as p:
        p.map(parse_files_of_process, args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="parse vacc log files and generate trace files"
    )
    parser.add_argument("directory", type=str, help="log directory to parse")
    args = parser.parse_args()
    parse_directory(args.directory)
329  vacc_tools/trace_logger.py  Normal file
@@ -0,0 +1,329 @@
"""
This module provides mechanisms for tracing the execution of torch modules and
functions, and writing the trace to a json file.

The user needs to set the environment variable `LOG_TRAIN_SCHEDULE=1` to enable
tracing. Otherwise, no trace will be applied.

Inside your module, create your module's tracer functions by using `get_trace_api`.
You will get four functions:
* `@trace_time(name)`: decorator to trace the execution of a function.
    ```python
    @trace_time("my_func")
    def my_func(x):
        ...
    ```
* `@trace_autograd_function()`: decorator to trace the execution of the forward
  and backward of a user-defined `torch.autograd.Function` operator.
    ```python
    @trace_autograd_function()
    class MyAutogradFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ...
        @staticmethod
        def backward(ctx, grad_output):
            ...
    ```
* `register_module_trace()`: function to register tracing on a model
  (`nn.Module`). It applies traces recursively to a torch model by enumerating
  all nn.Modules and registering tracers on their forward and backward
  functions. Applying it only to the top-level nn.Module is recommended.
    ```python
    model = Model()
    register_module_trace(model)
    ```
* `register_optimizer_trace()`: function to register tracing on an optimizer's
  `step` method (and its `reduce_gradients` method, if present).
    ```python
    register_optimizer_trace(optimizer)
    ```
"""
import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial
from typing import Optional

import torch
import torch.distributed


MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}

# pylint: disable=missing-docstring


@dataclass
class TraceEntry:
    name: str
    cat: str
    pid: int
    tid: int
    ts: int
    ph: str
    args: Optional[dict] = None

    def to_json_str(self):
        d = asdict(self)
        if self.args is None:
            d.pop("args")
        return json.dumps(d, separators=(",", ": "))


class LogFiles:
    def __init__(self) -> None:
        self.loggers = {}

    def get(self, file_prefix, rank, pid):
        os.makedirs("log", exist_ok=True)
        fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
        if fpath not in self.loggers:
            self.loggers[fpath] = open(fpath, "w", encoding="utf-8")
        return self.loggers[fpath]

    def close(self):
        for f in self.loggers.values():
            f.close()

    def __del__(self):
        self.close()


def trace_logger_enabled() -> bool:
    return os.environ.get("LOG_TRAIN_SCHEDULE") == "1"
class TraceLogger:
    _log_files = LogFiles()

    def __init__(self, category, tid=None, file_prefix=None) -> None:
        self.enabled = trace_logger_enabled()

        if self.enabled:
            self.pid = os.getpid()
            self.logger = None
            self.cat = category
            self._traces = {}
            self.global_rank = 0
            if tid is None:
                self.tid = MODULE_TID.get(category, 1000)
            else:
                self.tid = tid
            self.file_prefix = file_prefix if file_prefix is not None else self.cat

        self.registered_modules = []

    def _create_logger(self) -> None:
        # delay creating the logger file until the first log call,
        # since torch.distributed may not be ready yet
        if torch.distributed.is_initialized():
            self.global_rank = torch.distributed.get_rank()
        self.logger = TraceLogger._log_files.get(
            self.file_prefix, self.global_rank, self.pid
        )

    def begin_trace(self, name, memory=False) -> None:
        if not self.enabled:
            return

        if self.logger is None:
            self._create_logger()
            assert self.logger is not None

        name = f"{name}"  # convert it to str to ensure it is json serializable

        start_time = int(datetime.now().timestamp() * 1e6)  # in us
        trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")

        mem_trace = self._get_memory(start_time) if memory else None

        if name not in self._traces:
            self._traces[name] = [(trace, mem_trace)]
        else:  # in case the call to the function is nested
            self._traces[name].append((trace, mem_trace))

    def end_trace(self, name, flush=False, memory=False) -> None:
        if not self.enabled:
            return

        name = f"{name}"  # convert it to str to ensure it is json serializable

        assert self.logger is not None, "begin_trace should be called before end_trace"
        assert name in self._traces, "begin_trace should be called before end_trace"

        start_trace, start_mem = self._traces[name].pop()
        if start_mem is not None:
            self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
        self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")

        end_time = int(datetime.now().timestamp() * 1e6)  # in us
        args = {"value(us)": end_time - start_trace.ts}
        trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
        self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")

        if memory:
            mem_trace = self._get_memory(end_time)
            self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")

        if flush:
            self.flush()

    def flush(self) -> None:
        if self.logger is not None:
            self.logger.flush()

    def _get_memory(self, timestamp):
        # torch.vacc is the accelerator namespace of the vendor extension
        args = {"value": torch.vacc.memory_allocated(self.global_rank)}
        mem_trace = TraceEntry(
            "memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
        )
        return mem_trace
@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    if not logger_inst.enabled:
        yield
        return

    logger_inst.begin_trace(name, memory=memory)
    try:
        yield
    finally:
        # always close the trace, even if the traced body raises
        logger_inst.end_trace(name, flush=flush, memory=memory)
SKIPPED_MODULES = []


def _register_module_trace(
    module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
    if not logger_inst.enabled:
        return

    if not isinstance(module, torch.nn.Module):
        return

    def _register(m):
        module_name = f"{type(m).__name__}"
        if module_name == "WrapName":
            module_name = f"{type(m.forward_func.__self__).__name__}"

        if module_name in SKIPPED_MODULES:
            return

        forward_name = module_name + ".forward"

        m.register_forward_pre_hook(
            lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
        )
        m.register_forward_hook(
            lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
        )

        if not forward_only:
            backward_name = module_name + ".backward"

            m.register_full_backward_pre_hook(
                lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
            )
            m.register_full_backward_hook(
                lambda m, grad_in, grad_out: logger_inst.end_trace(
                    backward_name, memory=True, flush=flush
                )
            )

    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered, skip applying trace on the same module multiple times."
            )
            continue
        _register(m)
        # remember the module so a second registration is skipped
        logger_inst.registered_modules.append(m)
def _trace_autograd_function(logger_inst):
    def decorator(cls):
        if not issubclass(cls, torch.autograd.Function):
            return cls

        def _apply(name, method):
            def wrapper(*args, **kwargs):
                with _trace_time(name, logger_inst=logger_inst, memory=True):
                    result = method(*args, **kwargs)
                return result

            return wrapper

        for attr in ["forward", "backward"]:
            setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
        return cls

    return decorator
def _register_optimizer_trace(
    optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
    if not logger_inst.enabled:
        return

    # bind the step trace name to its own variable: the lambdas and wrappers
    # below close over it, so it must not be reused for reduce_gradients
    step_name = f"{type(optimizer).__name__}.step"

    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.register_step_pre_hook(
            lambda m, *args, **kwargs: logger_inst.begin_trace(step_name, memory=True)
        )
        optimizer.register_step_post_hook(
            lambda m, *args, **kwargs: logger_inst.end_trace(
                step_name, memory=True, flush=flush
            )
        )
    elif hasattr(optimizer, "step") and callable(optimizer.step):
        # a customized optimizer does not have step hooks, so wrap step instead
        original_step = optimizer.step

        def traced_step(*args, **kwargs):
            logger_inst.begin_trace(step_name, memory=True)
            result = original_step(*args, **kwargs)
            logger_inst.end_trace(step_name, memory=True, flush=flush)
            return result

        # replace the step method with the traced function
        optimizer.step = traced_step
    else:
        # an unknown optimizer or a wrong instance was passed to this function
        pass

    if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
        reduce_name = f"{type(optimizer).__name__}.reduce_gradients"
        original_reduce = optimizer.reduce_gradients

        def traced_reduce(*args, **kwargs):
            logger_inst.begin_trace(reduce_name, memory=True)
            result = original_reduce(*args, **kwargs)
            logger_inst.end_trace(reduce_name, memory=True, flush=flush)
            return result

        # replace the reduce_gradients method with the traced function
        optimizer.reduce_gradients = traced_reduce
def get_trace_api(name="nn.Module"):
    """Generate module execution trace APIs for a given module name.

    Args:
        name (str): module name

    Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
        register_optimizer_trace).
        Usage of these functions is described in the docstring of this module;
        a short usage sketch also follows this file.
    """
    _trace_logger = TraceLogger(name)

    return (
        partial(_trace_time, logger_inst=_trace_logger),
        partial(_register_module_trace, logger_inst=_trace_logger),
        partial(_trace_autograd_function, logger_inst=_trace_logger),
        partial(_register_optimizer_trace, logger_inst=_trace_logger),
    )
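
A minimal usage sketch for `get_trace_api` (not part of the commit; the model, optimizer, and training step are illustrative assumptions, and the memory counters require the vendor extension that provides `torch.vacc.memory_allocated`):

# Hypothetical usage: obtain the four tracer functions and apply them.
# Tracing is a no-op unless LOG_TRAIN_SCHEDULE=1 is set in the environment.
import os
os.environ["LOG_TRAIN_SCHEDULE"] = "1"

import torch
from vacc_tools.trace_logger import get_trace_api

(
    trace_time,
    register_module_trace,
    trace_autograd_function,
    register_optimizer_trace,
) = get_trace_api("nn.Module")

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
register_module_trace(model)
register_optimizer_trace(optimizer)

@trace_time("train_step")
def train_step(x):
    optimizer.zero_grad()
    model(x).sum().backward()
    optimizer.step()

train_step(torch.randn(4, 8))
# trace lines are written to log/nn.Module-rank-<rank>_<pid>.txt, prefixed
# with "LOG_TRACE:", ready to be consumed by generate_trace.py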