enginex-vastai-va16-vllm/vacc_tools/generate_trace.py

"""Generating tracing json files from log files.
Usage:
python -m vacc_tools.generate_trace --log-dir <directory of log files> --out-file-prefix <prefix of output file>
"""
import argparse
import json
import os
import re
from collections import defaultdict
from glob import glob
from multiprocessing import Pool

import numpy as np
import tabulate


def run_stats_on_traces(timelines):
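    """Aggregate per-op duration statistics from the extracted trace lines.

    Returns a dict that maps each category (plus a combined "VACC-ALL" view over
    VCCL/ODSP/DLC) to a pre-formatted `tabulate` table string.
    """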
op_cat_list = ["ODSP", "DLC", "VCCL", "CPU", "CPU_OP"]
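    # VCCL, ODSP and DLC rows are additionally combined into the "VACC-ALL" summary below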
op_stats = {op: {} for op in op_cat_list}
for line in timelines:
if '"E"' not in line: # optim 3, skip everything if not `"E"`
continue
# optim 2: using `[:-2]` instead of replace()
line = line[:-2] # remove ',\n'
try:
values = json.loads(line)
except json.decoder.JSONDecodeError:
            # some log lines may not end properly; just skip them
continue
if values["ph"] == "E" and values["cat"] in op_cat_list:
cat = values["cat"]
if values["name"] not in op_stats[cat]:
op_stats[cat][values["name"]] = []
if "dur" in values["args"]:
# optim 1: using `[:-2]` instead of replace()
op_stats[cat][values["name"]].append(
int(values["args"]["dur"][:-2]) # strip `us`
)
elif "values(us)" in values["args"]:
op_stats[cat][values["name"]].append(values["args"]["value(us)"])
op_tables = {}
for cat, stats in op_stats.items():
        # build one summary row (min/max/sum/avg/p90/count) per op name
        table = []
for name, dur in stats.items():
dur = np.array(dur)
t = [
name,
np.min(dur),
np.max(dur),
np.sum(dur),
np.mean(dur),
np.percentile(dur, 90),
len(dur),
]
table.append(t)
table = sorted(table, key=lambda x: x[-1], reverse=True)
op_tables[cat] = tabulate.tabulate(
table,
headers=["op", "min", "max", "sum", "avg", "p90", "count"],
tablefmt="plain",
)
if cat in ["VCCL", "ODSP", "DLC"]:
op_tables["VACC-ALL"] = op_tables.get("VACC-ALL", []) + [
t + [cat] for t in table
]
    # `or 1` avoids division by zero if every op reports a zero total duration
    total = sum(x[3] for x in op_tables["VACC-ALL"]) or 1
op_tables["VACC-ALL"] = [t + [t[3] / total * 100] for t in op_tables["VACC-ALL"]]
op_tables["VACC-ALL"] = tabulate.tabulate(
sorted(op_tables["VACC-ALL"], key=lambda x: x[-1], reverse=True),
headers=["op", "min", "max", "sum", "avg", "p90", "count", "cat", "percent(%)"],
tablefmt="plain",
)
return op_tables


def get_rank_info(files):
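    """Return the rank parsed from the first file path matching `rank-<N>`, or 0 if none matches."""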
    # the rank is encoded as `rank-<N>` in the log file path
for fpath in files:
rank = re.findall(r"rank-(\d+)", fpath)
if rank:
return int(rank[0])
return 0


def extract_traces(arg):
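    """Extract trace lines from one group of log files into a Chrome-trace JSON file.

    `arg` is a (files, target_file_path, group_name, trace_token) tuple so this
    function can be dispatched via multiprocessing.Pool.map. Per-op statistics are
    written next to the trace with a `.txt` extension.
    """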
files, target_file_path, group_name, trace_token = arg
entries = [
(0, "scheduler"),
(1, "megatron"),
(2, "deepspeed"),
(3, "nn.Module"),
(10, "vacc-odsp"),
(11, "vacc-dlc"),
(12, "vacc-vccl"),
(13, "vacc-cpu"),
(14, "vacc-fallback"),
(15, "vacc-ddr"),
(20, "lib-vccl"),
]
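    # (tid, thread name) pairs written as chrome://tracing metadata events so each
    # event stream appears under a readable lane name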
with open(target_file_path, "w", encoding="utf-8") as trace_file:
trace_file.write("[")
for tid, thread_name in entries:
line = f'{{"cat":"__metadata","pid":{group_name},"tid":{tid},"ts":0,"ph":"M","name":"thread_name","args":{{"name":"{thread_name}"}}}},\n'
trace_file.write(line)
timelines = []
for fpath in files:
with open(fpath, "r", encoding="utf-8") as file:
# timelines += [line.split(trace_token)[1] for line in file if trace_token in line]
for line in file:
if trace_token in line:
                        # the line contains the trace token; keep the JSON payload after it
timelines.append(line.split(trace_token)[1])
try:
json.loads(timelines[-1][:-2]) # remove ',\n'
except json.decoder.JSONDecodeError:
                            # some log lines may not end properly, so skip them;
                            # chrome://tracing stops reading once it hits a malformed line,
                            # so broken entries must be removed here
timelines.pop()
        for line in timelines[:-1]:
            trace_file.write(line)
        if timelines:
            # drop the trailing comma on the last entry to keep the JSON array valid
            trace_file.write(timelines[-1].replace(",\n", "\n"))
trace_file.write("]")
op_stats = run_stats_on_traces(timelines)
with open(
target_file_path.replace(".json", ".txt"), "w", encoding="utf-8"
) as op_stats_file:
for cat, tables in op_stats.items():
op_stats_file.write(f"{cat}".center(80, "-") + "\n")
op_stats_file.write(tables + "\n\n")


def merge_schedule(out_file_prefix):
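    """Merge scheduler events (tid 0) from every per-rank trace into `<prefix>schedule.json`.

    Each scheduler event is remapped to pid 0 with its tid set to the rank id, so
    all ranks show up as threads of one process in chrome://tracing.
    """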
scheduler_data = []
for file in glob(f"{out_file_prefix}*.json"):
if file.endswith("schedule.json"):
continue
assert "rank" in file
rank = file.split("rank_")[-1].split("_")[0]
pid = None
with open(file, "r", encoding="utf-8") as f:
for line in f:
                # remap scheduler events: pid -> 0, tid -> rank id
if '"tid":0,' in line and "__metadata" not in line:
if pid is None:
pid = line.split('"pid":')[1].split(",")[0]
line = line.replace(f'"pid":{pid}', f'"pid":0')
line = line.replace('"tid":0,', f'"tid":{rank},')
scheduler_data.append(line)
out_file = f"{out_file_prefix}schedule.json"
with open(out_file, "w", encoding="utf-8") as f:
f.write("[\n")
        f.writelines(scheduler_data[:-1])
        if scheduler_data:
            # drop the trailing comma on the last entry to keep the JSON array valid
            f.write(scheduler_data[-1].replace(",\n", "\n"))
f.write("]\n")


def scan_and_generate_trace(args, trace_token):
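    """Group non-empty log files by the token after the last `_` in their name and
    extract one trace file per group in parallel."""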
grouped_files = defaultdict(list)
for root, dirs, files in os.walk(args.log_dir):
for filename in files:
fpath = os.path.join(root, filename)
file_size = os.path.getsize(fpath)
if file_size != 0:
group_name = filename.rsplit("_", 1)[1].split(".")[0]
grouped_files[group_name].append(fpath)
pool_args = []
for group_name, files in grouped_files.items():
rank = get_rank_info(files)
out_file = f"{args.out_file_prefix}rank_{rank}_{group_name}.json"
pool_args.append((files, out_file, group_name, trace_token))
    # one worker per file group; max(..., 1) guards against an empty log directory
    with Pool(max(len(grouped_files), 1)) as p:
p.map(extract_traces, pool_args)
if args.merge_schedule:
merge_schedule(args.out_file_prefix)


if __name__ == "__main__":
TRACE_TOKEN = "LOG_TRACE:"
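    # only log lines containing this token are treated as trace entries; the JSON
    # payload is expected right after the token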
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(os.path.dirname(current_file_path))
find_directory = os.path.join(parent_directory, "log")
parser = argparse.ArgumentParser()
parser.add_argument(
"--log-dir", default=find_directory, type=str, help="directory of log files"
)
parser.add_argument("--out-file-prefix", default="timeline_", type=str)
parser.add_argument("--merge-schedule", action="store_true")
args = parser.parse_args()
scan_and_generate_trace(args, TRACE_TOKEN)
print("Scan and trace generation done!")