Sync from v0.13
This commit is contained in:
174
tools/profiler/nsys_profile_tools/README.md
Normal file
174
tools/profiler/nsys_profile_tools/README.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# gputrc2graph.py
|
||||
|
||||
This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files
|
||||
(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level
|
||||
summaries and visualizations of GPU and non-GPU time. It is useful for
|
||||
profiling and analyzing nsys profile output.
|
||||
|
||||
## Usage
|
||||
|
||||
### Command-line Arguments
|
||||
|
||||
- `--in_file`
|
||||
**(required)**
|
||||
List of input files and their metadata. Each entry should be in the format:
|
||||
`<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>`
|
||||
- `nsys-rep`: Path to the `.nsys-rep` file.
|
||||
- `engine`: Engine name (e.g., `vllm`).
|
||||
- `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`).
|
||||
- `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without
|
||||
profiling. Specify `0` to use the elapsed time from the nsys-rep file
|
||||
(this may inflate non-GPU time if actual runtime without profiling is
|
||||
less). Multiple entries can be provided, separated by spaces.
|
||||
|
||||
- `--out_dir`
|
||||
Output directory for the generated CSV and HTML files.
|
||||
If not specified, results are saved in the current directory.
|
||||
|
||||
- `--title`
|
||||
Title for the HTML chart/visualization.
|
||||
|
||||
- `--nsys_cmd`
|
||||
Path to the `nsys` command.
|
||||
Default: `nsys` (assumes it is in your PATH).
|
||||
Use this if `nsys` is not in your system PATH.
|
||||
|
||||
## Notes
|
||||
|
||||
- Make sure you have pandas installed.
|
||||
- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH.
|
||||
- For more details on available engines and models, see the help string in
|
||||
the script or run:
|
||||
|
||||
```bash
|
||||
python3 gputrc2graph.py --help
|
||||
```
|
||||
|
||||
## Example 1: analyze a single profile
|
||||
|
||||
To analyze the GPU cycles for say, gpt-oss model with vLLM engine:
|
||||
|
||||
1. Run the following command to collect nsys profile, for vllm serve config.
|
||||
|
||||
```bash
|
||||
nsys profile -t cuda -o run1 -f true --trace-fork-before-exec=true \
|
||||
--cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \
|
||||
vllm serve openai/gpt-oss-120b ...
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
- DELAY: how many seconds to delay nsys from collecting profiles, needed so
|
||||
that profiles aren't captured till vllm server has come up and load
|
||||
generation starts.
|
||||
- DURATION: how many seconds for nsys profile to run before generating the
|
||||
profile. This should be > the duration of the run.
|
||||
|
||||
2. Run again, this time without collecting the profile, and get the total run
|
||||
time in seconds. This value will be used by the script to calculate the
|
||||
CPU(non-GPU) seconds for the analysis.
|
||||
|
||||
3. Say the run elapsed time is 306 seconds, from step #2. Run script to
|
||||
analyze:
|
||||
|
||||
```bash
|
||||
python3 gputrc2graph.py \
|
||||
--in_file run1.nsys-rep,vllm,gpt-oss,306 \
|
||||
--title "vLLM-gpt-oss profile"
|
||||
```
|
||||
|
||||
The command will produce 2 files for analysis:
|
||||
|
||||
- result.html: this categorizes kernel names into different categories in a
|
||||
stacked bar chart.
|
||||
- result.csv: shows how the kernel names are mapped to the different
|
||||
categories.
|
||||
|
||||
### HTML visualization with result.html
|
||||
|
||||
The html file shows the number of elapsed seconds due to different GPU
|
||||
Substages or categories, which consist of moe_gemm (Mixture of Experts GEMM)
|
||||
kernels the biggest category, at 148 seconds, followed by "attn" or attention
|
||||
kernels. This lets the user prioritize the kernels to focus on for performance
|
||||
optimizations.
|
||||
|
||||

|
||||
|
||||
There's also an appended data table underneath the bar chart for copying out to other post-processing tools.
|
||||
|
||||

|
||||
|
||||
### Kernel to category mapping with result.csv
|
||||
|
||||
Suppose the user would like to focus on improving triton kernels. It's not the
|
||||
biggest consumer of cycles at 9.74 sec but perhaps it hasn't been optimized.
|
||||
The next step is to use the result.csv to dive into what the kernels are which
|
||||
compose the triton kernel GPU cycles. The following image shows that
|
||||
triton_poi_fused__to_copy_add_addmm_cat_.. kernel to be the biggest
|
||||
contributor to GPU cycles.
|
||||
|
||||

|
||||
|
||||
## Example 2: analyze multiple profiles
|
||||
|
||||
Suppose the user has multiple nsys trace files, captured for different models,
|
||||
say llama and gpt-oss in this case, and wish to compare their GPU/non-GPU
|
||||
time, something like the following command can be used.
|
||||
|
||||
```bash
|
||||
python3 gputrc2graph.py \
|
||||
--in_file run1.nsys-rep,vllm,llama,100 run2.nsys-rep,vllm,gpt-oss,102 \
|
||||
--out_dir results \
|
||||
--title "Comparison of vLLM Models"
|
||||
```
|
||||
|
||||
The analysis process is similar to example 1 but now there will be multiple
|
||||
stack bar charts that can be compared. The categories for the different
|
||||
kernels will remain the same, so that it's easy to compare the GPU cycles for
|
||||
the same categories.
|
||||
|
||||
Once a category is shown to have more cycles for one configuration than
|
||||
another, the next step would be to use the csv file to see what kernels are
|
||||
mapped into that category, and which kernels are taking the largest amount of
|
||||
time which would cause a difference for the overall category.
|
||||
|
||||
## Example 3: add new classification for a new model
|
||||
|
||||
To create a new engine DEF with model ABC, just add another json file in the same directory as
|
||||
gputrc2graph.py with the same format as the other json files. The script will automatically pick up all the json files in the same directory as engine/model specifications.
|
||||
|
||||
Then, for this new model, suppose there are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels
|
||||
have names with "*H*" or "*I*" in them, and attn kernels have names with "*J*"
|
||||
or "*K*" in them, just add another .json file in the same directory as
|
||||
gputrc2graph.py with the same format as the other json files, like the following:
|
||||
|
||||
```json
|
||||
{
|
||||
"DEF": {
|
||||
"ABC": {
|
||||
"H|I": "gemm",
|
||||
"J|K": "attn",
|
||||
"CUDA mem": "non-gpu-H_D_memops",
|
||||
".*": "misc"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Each entry in the dictionary consists of:
|
||||
|
||||
- key: a regex used to classify the kernels
|
||||
- value: the category to classify the kernels into.
|
||||
|
||||
The last 2 entries are common for all engine/models, consisting of CUDA memory
|
||||
operations and a 'misc' for anything that's leftover and can't be classified.
|
||||
|
||||
When invoking gputrc2graph.py, specify a trace file with this new model/engine
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
--infile new.nsys-rep,DEF,ABC,<runtime>
|
||||
```
|
||||
|
||||
If the engine_DEF.json file already exists, just add the model as a new node in
|
||||
the existing engine file, after the other models.
|
||||
344
tools/profiler/nsys_profile_tools/gputrc2graph.py
Executable file
344
tools/profiler/nsys_profile_tools/gputrc2graph.py
Executable file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This generates gpu kernel analysis output from nsys rep. Will call nsys
|
||||
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
|
||||
csv and html output for analysis
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import regex as re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# helper data class for annotating kernels
|
||||
def load_engine_model():
|
||||
"""returns engine_model built from all json files in the current dir"""
|
||||
import glob
|
||||
import json
|
||||
|
||||
engine_model = {}
|
||||
|
||||
json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json"))
|
||||
for fname in json_files:
|
||||
with open(fname, encoding="utf-8") as f:
|
||||
engine_model.update(json.load(f))
|
||||
return engine_model
|
||||
|
||||
|
||||
class GPUTrace2Graph:
|
||||
"""
|
||||
Parses output of nsys report, generates csv and bar chart output
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
import pandas as pd # avoid importing till needed
|
||||
|
||||
self.pd = pd
|
||||
self.pd.options.mode.copy_on_write = True
|
||||
|
||||
# helper functions for generating trace->summary csvs
|
||||
def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
|
||||
logger.info("loading %s", in_file)
|
||||
df = self.pd.read_csv(
|
||||
in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"]
|
||||
)
|
||||
df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
|
||||
df = self.sum_non_overlapping_intervals(df)
|
||||
# get ready to print table with elapsed times per kernel
|
||||
df["Instances"] = 1
|
||||
df_sum = df.groupby("Name", as_index=False).agg(
|
||||
{"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"}
|
||||
)
|
||||
|
||||
# generate csv
|
||||
df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
|
||||
df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
|
||||
df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
|
||||
df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv(
|
||||
out_file, index=False
|
||||
)
|
||||
|
||||
def sum_non_overlapping_intervals(self, df):
|
||||
"""
|
||||
returns new sorted df with Elapsed Time (ns) column using
|
||||
vectorized operations
|
||||
"""
|
||||
logger.info("sorting %s trace records by start time", str(df.shape))
|
||||
|
||||
# Sort by start time and reset index
|
||||
df = df.sort_values(by="Start (ns)").reset_index(drop=True)
|
||||
|
||||
# Initialize elapsed time as duration
|
||||
df["Elapsed Time (ns)"] = df["Duration (ns)"]
|
||||
|
||||
# Get numpy arrays for faster operations
|
||||
starts = df["Start (ns)"].values
|
||||
ends = df["End (ns)"].values
|
||||
|
||||
# Keep track of current interval end
|
||||
current_end = ends[0]
|
||||
display_units = int(len(df) / 100)
|
||||
# Update current_end for overlapping intervals
|
||||
for i in range(1, len(df)):
|
||||
if i % display_units == 0:
|
||||
print(f"processing trace: {int(i / len(df) * 100)} %", end="\r")
|
||||
if starts[i] <= current_end:
|
||||
if ends[i] > current_end:
|
||||
# Partial overlap
|
||||
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = (
|
||||
ends[i] - current_end
|
||||
)
|
||||
current_end = ends[i]
|
||||
else:
|
||||
# Complete overlap
|
||||
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0
|
||||
else:
|
||||
# No overlap
|
||||
current_end = ends[i]
|
||||
|
||||
return df
|
||||
|
||||
# functions for generating html files
|
||||
def make_html(self, df, output_dir, title):
|
||||
"""make html graph from df"""
|
||||
import plotly.express as px
|
||||
|
||||
if df.empty:
|
||||
return
|
||||
output_name = output_dir + "/result"
|
||||
if not title:
|
||||
title = "Model_Engine"
|
||||
x = "Model_Engine"
|
||||
y = "Elapsed Time (sec)"
|
||||
color = "Category"
|
||||
""" generate kernel mapping table """
|
||||
# Sort Model_Engine categories by last field after underscore
|
||||
df["Model_Engine"] = self.pd.Categorical(
|
||||
df["Model_Engine"],
|
||||
sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]),
|
||||
)
|
||||
df[["Model_Engine", color, "Instances", "Name", y]].sort_values(
|
||||
by=color
|
||||
).to_csv(f"{output_name}.csv", index=False)
|
||||
graph = px.histogram(
|
||||
df.round(2),
|
||||
x=x,
|
||||
y=y,
|
||||
title=(f"{y} for {title}"),
|
||||
color=color,
|
||||
text_auto=True,
|
||||
)
|
||||
# wrap x axis labels
|
||||
graph.update_xaxes(automargin=True)
|
||||
graph.write_html(f"{output_name}.html")
|
||||
"""
|
||||
Generate data table with columns per Model_Engine into result.html
|
||||
"""
|
||||
pivot_df = df.pivot_table(
|
||||
values="Elapsed Time (sec)",
|
||||
index="Category",
|
||||
columns="Model_Engine",
|
||||
aggfunc="sum",
|
||||
observed=False,
|
||||
).round(2)
|
||||
# Add sum row at bottom
|
||||
pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
|
||||
pivot_df.fillna("").to_html("temp.html")
|
||||
with (
|
||||
open(f"{output_name}.html", "a", encoding="utf-8") as outfile,
|
||||
open("temp.html", encoding="utf-8") as infile,
|
||||
):
|
||||
outfile.write(infile.read())
|
||||
os.remove("temp.html")
|
||||
|
||||
print(
|
||||
f"Finished generating: \n"
|
||||
f" {output_name}.html for stack bar chart \n"
|
||||
f" {output_name}.csv for Kernel-Category mapping"
|
||||
)
|
||||
|
||||
def anno_gpu_kernname(self, df, mapping):
|
||||
"""add "Category" column"""
|
||||
|
||||
def anno_gpu_kernname_helper(name):
|
||||
for kern_name, val in mapping.items():
|
||||
if re.search(kern_name, name):
|
||||
return val
|
||||
|
||||
df["Category"] = df["Name"].apply(anno_gpu_kernname_helper)
|
||||
|
||||
def make_nongpu_row(self, df, nongpu_sec):
|
||||
"""this will append non-gpu time entry at end of df"""
|
||||
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
|
||||
nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
|
||||
nongpu_row["Instances"] = 1
|
||||
nongpu_row["Elapsed Time (sec)"] = nongpu_sec
|
||||
return nongpu_row
|
||||
|
||||
def is_valid_file(self, base_file):
|
||||
"""asserts if base_file is non-existent or is empty"""
|
||||
assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, (
|
||||
f"{base_file} doesn't exist or is empty"
|
||||
)
|
||||
|
||||
def should_gen_file(self, new_file, base_file):
|
||||
"""figure out if new file should be generated from base_file"""
|
||||
self.is_valid_file(base_file)
|
||||
if (
|
||||
os.path.exists(new_file)
|
||||
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
|
||||
and (os.path.getsize(base_file) > 0)
|
||||
):
|
||||
logger.info("reusing %s", new_file)
|
||||
return False
|
||||
else:
|
||||
logger.info("generating %s", new_file)
|
||||
return True
|
||||
|
||||
def gen_sum_file(self, file, nsys_cmd):
|
||||
"""
|
||||
generates sum file from nsys trace with times per kernel and
|
||||
returns the name of the sum file
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
file_dir = os.path.dirname(file)
|
||||
file_name = os.path.basename(file)
|
||||
|
||||
if not file_dir:
|
||||
file_dir = "."
|
||||
# Walk through trace and get the total non-overlapped time
|
||||
nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv"
|
||||
sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv"
|
||||
if self.should_gen_file(nsys_stats_file, file):
|
||||
cmd = [
|
||||
nsys_cmd,
|
||||
"stats",
|
||||
"-r",
|
||||
"cuda_gpu_trace",
|
||||
file,
|
||||
"-o",
|
||||
f"{file_dir}/{file_name}",
|
||||
]
|
||||
cmd_str = " ".join(cmd)
|
||||
logger.info("+ %s", cmd_str)
|
||||
# estimate time based on calibrated 240M/min
|
||||
file_size_mb = os.path.getsize(file) / 1e6
|
||||
logger.info(
|
||||
"nsys stats for %.2f MB file expected to take %.2f min",
|
||||
file_size_mb,
|
||||
file_size_mb / 240,
|
||||
)
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception:
|
||||
logger.error("%s failed; Use --nsys_cmd to specify nsys path", cmd_str)
|
||||
exit(1)
|
||||
logger.info("generating non-overalapped sum %s", sum_file)
|
||||
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
|
||||
self.is_valid_file(sum_file)
|
||||
logger.info("Finished generating %s", sum_file)
|
||||
return sum_file
|
||||
|
||||
def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
|
||||
"""generates graph and csv file from in_file into out_dir"""
|
||||
# Initialize an empty DataFrame to store combined data
|
||||
combined_df = self.pd.DataFrame()
|
||||
for idx, (file, engine, model, total_sec) in enumerate(in_file):
|
||||
file_dir = os.path.dirname(file)
|
||||
file_name = os.path.basename(file)
|
||||
if not file_dir:
|
||||
file_dir = "."
|
||||
sum_file = self.gen_sum_file(file, nsys_cmd)
|
||||
# read kernel summary file
|
||||
df = self.pd.read_csv(sum_file)
|
||||
# annotate kernel to their categories
|
||||
assert engine_model.get(engine), f"engine {engine} unknown"
|
||||
assert engine_model[engine].get(model), f"model {model} unknown"
|
||||
# remove nsys-rep from file_name for shorter x-label
|
||||
file_name = file_name.replace(".nsys-rep", "")
|
||||
df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
|
||||
self.anno_gpu_kernname(df, engine_model[engine][model])
|
||||
# patch in non-gpu time
|
||||
gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
|
||||
total_sec = round(float(total_sec), 1)
|
||||
if total_sec < gpu_sec:
|
||||
logger.warning(
|
||||
"Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ",
|
||||
total_sec,
|
||||
gpu_sec,
|
||||
)
|
||||
total_sec = gpu_sec
|
||||
nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec)
|
||||
df = self.pd.concat([df, nongpu_row], ignore_index=True)
|
||||
combined_df = self.pd.concat([combined_df, df], ignore_index=True)
|
||||
if out_dir is None:
|
||||
out_dir = "."
|
||||
else:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
# generate html file
|
||||
self.make_html(combined_df, out_dir, title)
|
||||
|
||||
|
||||
def parse_tuple(s):
|
||||
return tuple(s.split(","))
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(
|
||||
format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO
|
||||
)
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Process nsys rep and generate kernel non-overlapped cycles. \n"
|
||||
"Example:\n"
|
||||
"gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n"
|
||||
"d2.nsys-rep,vllm,gpt-oss,102 "
|
||||
'--out_dir results/ --title "Model=gpt-oss vLLM chart"'
|
||||
),
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
# load supported engine_model
|
||||
engine_model_supported = load_engine_model()
|
||||
# Get a string representation of supported engine/model combinations
|
||||
engine_model_supported_str = ", ".join(
|
||||
f"{engine}:[{', '.join(models.keys())}]"
|
||||
for engine, models in engine_model_supported.items()
|
||||
)
|
||||
parser.add_argument(
|
||||
"--in_file",
|
||||
type=parse_tuple,
|
||||
nargs="+",
|
||||
help=(
|
||||
"list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
|
||||
"separated by space. Elapsed_nonprofiled_sec is runtime without "
|
||||
"profiling used to calculate non-gpu time. Specify 0 to use "
|
||||
"elapsed time from nsys-rep but that might inflate non-gpu time. "
|
||||
f"Available engine:[model] are: {engine_model_supported_str} "
|
||||
f"Example: --infile d1.nsys-rep,vllm,llama,100 "
|
||||
"d2.nsys-rep,vllm,gpt-oss,102"
|
||||
),
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
|
||||
parser.add_argument("--title", help=("title for html chart"))
|
||||
parser.add_argument(
|
||||
"--nsys_cmd",
|
||||
help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
|
||||
default="nsys",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
gputrace = GPUTrace2Graph()
|
||||
gputrace.gen_graph(
|
||||
args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
tools/profiler/nsys_profile_tools/images/csv1.png
Normal file
BIN
tools/profiler/nsys_profile_tools/images/csv1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 145 KiB |
BIN
tools/profiler/nsys_profile_tools/images/html.png
Normal file
BIN
tools/profiler/nsys_profile_tools/images/html.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 70 KiB |
BIN
tools/profiler/nsys_profile_tools/images/html_tbl.png
Normal file
BIN
tools/profiler/nsys_profile_tools/images/html_tbl.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
63
tools/profiler/nsys_profile_tools/vllm_engine_model.json
Normal file
63
tools/profiler/nsys_profile_tools/vllm_engine_model.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"vllm": {
|
||||
"llama": {
|
||||
"fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm",
|
||||
"gemm|nvjet": "gemm",
|
||||
"moe|sigmoid": "moe",
|
||||
"CatArrayBatched|prepare_inputs": "prepare_next",
|
||||
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
|
||||
"_norm_|Norm": "norm",
|
||||
"act_and_mul_": "activation",
|
||||
"Rotary": "rope",
|
||||
"SoftMax": "softmax",
|
||||
"flash|fmha": "attn",
|
||||
"elementwise": "elementwise",
|
||||
"fp8_quant|cvt_": "quantize",
|
||||
"reduce_kernel": "reduce",
|
||||
"triton": "triton_kernel",
|
||||
"CUDA mem": "non-gpu-H_D_memops",
|
||||
".*": "misc"
|
||||
},
|
||||
"ds": {
|
||||
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
|
||||
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm",
|
||||
"gemm|matmul|nvjet": "gemm",
|
||||
"moe|sigmoid|expert": "moe",
|
||||
"CatArrayBatched": "prepare_next",
|
||||
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
|
||||
"Norm|_norm_": "norm",
|
||||
"sbtopk": "topk",
|
||||
"act_and_mul_": "activation",
|
||||
"compute_position_kernel": "rope",
|
||||
"elementwise": "elementwise",
|
||||
"fp8_quant|quant_fp8|cvt_": "quantize",
|
||||
"reduce": "reduce",
|
||||
"SoftMax": "softmax",
|
||||
"_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn",
|
||||
"triton": "triton_kernel",
|
||||
"topk": "topk",
|
||||
"CUDA mem": "non-gpu-H_D_memops",
|
||||
".*": "misc"
|
||||
},
|
||||
"gpt-oss": {
|
||||
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
|
||||
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
|
||||
"gemm|matmul|nvjet": "gemm",
|
||||
"moe|sigmoid|expert|splitKreduce": "moe",
|
||||
"CatArrayBatched": "prepare_next",
|
||||
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
|
||||
"Norm|_norm_": "norm",
|
||||
"topk": "topk",
|
||||
"act_and_mul_": "activation",
|
||||
"compute_position_kernel": "rope",
|
||||
"elementwise": "elementwise",
|
||||
"fp8_quant|quant_fp8|cvt_|quantize": "quantize",
|
||||
"reduce": "reduce",
|
||||
"SoftMax": "softmax",
|
||||
"_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn",
|
||||
"triton": "triton_kernel",
|
||||
"CUDA mem": "non-gpu-H_D_memops",
|
||||
".*": "misc"
|
||||
}
|
||||
}
|
||||
}
|
||||
87
tools/profiler/print_layerwise_table.py
Normal file
87
tools/profiler/print_layerwise_table.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
|
||||
from vllm.profiler.utils import TablePrinter, indent_string
|
||||
|
||||
|
||||
def flatten_entries(entry_cls, profile_dict: dict):
|
||||
entries_and_depth = []
|
||||
|
||||
def get_entries(node, curr_depth=0):
|
||||
entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
|
||||
|
||||
for child in node["children"]:
|
||||
get_entries(
|
||||
child,
|
||||
curr_depth=curr_depth + 1,
|
||||
)
|
||||
|
||||
for root in profile_dict:
|
||||
get_entries(root)
|
||||
|
||||
return entries_and_depth
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--json-trace",
|
||||
type=str,
|
||||
required=True,
|
||||
help="json trace file output by examples/offline_inference/profiling.py",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--phase",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The phase to print the table for. This is either"
|
||||
"prefill or decode_n, where n is the decode step "
|
||||
"number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--table",
|
||||
type=str,
|
||||
choices=["summary", "model"],
|
||||
default="summary",
|
||||
help="Which table to print, the summary table or the layerwise model table",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.json_trace) as f:
|
||||
profile_data = json.load(f)
|
||||
|
||||
assert args.phase in profile_data, (
|
||||
f"Cannot find phase {args.phase} in profile data. Choose one among"
|
||||
f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}"
|
||||
) # noqa
|
||||
|
||||
if args.table == "summary":
|
||||
entries_and_depths = flatten_entries(
|
||||
SummaryStatsEntry, profile_data[args.phase]["summary_stats"]
|
||||
)
|
||||
column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15)
|
||||
elif args.table == "model":
|
||||
entries_and_depths = flatten_entries(
|
||||
ModelStatsEntry, profile_data[args.phase]["model_stats"]
|
||||
)
|
||||
column_widths = dict(
|
||||
name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
|
||||
)
|
||||
|
||||
# indent entry names based on the depth
|
||||
entries = []
|
||||
for entry, depth in entries_and_depths:
|
||||
entry.name = indent_string(
|
||||
entry.name,
|
||||
indent=depth,
|
||||
indent_style=lambda indent: "|" + "-" * indent + " ",
|
||||
)
|
||||
entries.append(entry)
|
||||
|
||||
TablePrinter(type(entries[0]), column_widths).print_table(entries)
|
||||
631
tools/profiler/visualize_layerwise_profile.py
Normal file
631
tools/profiler/visualize_layerwise_profile.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
## JSON parsing utils ####
|
||||
|
||||
|
||||
def largest_dist_from_leaf(node: dict, depth: int = 0):
|
||||
if len(node["children"]) == 0:
|
||||
return depth
|
||||
return max(
|
||||
[largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]]
|
||||
)
|
||||
|
||||
|
||||
def get_entries_at_depth(
|
||||
depth: int,
|
||||
entries_and_traces: list[tuple[Any, Any]],
|
||||
node: dict,
|
||||
curr_depth: int = 0,
|
||||
trace=(),
|
||||
):
|
||||
# assert that the query is at kernel or module level
|
||||
assert depth == -1 or depth == -2
|
||||
|
||||
if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1):
|
||||
# The tree is not tall enough!
|
||||
entries_and_traces.append((node["entry"], trace))
|
||||
return
|
||||
|
||||
if largest_dist_from_leaf(node) == (abs(depth) - 1):
|
||||
entries_and_traces.append((node["entry"], trace))
|
||||
|
||||
trace = (node["entry"]["name"],) + trace
|
||||
for child in node["children"]:
|
||||
get_entries_at_depth(
|
||||
depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace
|
||||
)
|
||||
|
||||
|
||||
def fold_nodes(root: dict, nodes_to_fold: list[str]):
|
||||
stack: list[dict] = [root]
|
||||
while len(stack) != 0:
|
||||
node = stack.pop()
|
||||
if node["entry"]["name"] in nodes_to_fold:
|
||||
node["children"] = []
|
||||
continue
|
||||
for child in node["children"]:
|
||||
stack.append(child)
|
||||
return root
|
||||
|
||||
|
||||
## Operation name cleanup utils ####
|
||||
|
||||
|
||||
def trim_string_back(string: str, width: int) -> str:
|
||||
if len(string) > width:
|
||||
offset = len(string) - width + 3
|
||||
string = string[:-offset]
|
||||
if len(string) > 3:
|
||||
string = string + "..."
|
||||
return string
|
||||
|
||||
|
||||
def shorten_plot_legend_strings(legend, max_char_len: int):
|
||||
for t in legend.get_texts():
|
||||
t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len))
|
||||
|
||||
|
||||
def abbreviate_known_names(name: str) -> str:
|
||||
abbreviations = {
|
||||
"MergedColumnParallelLinear": "MCPLinear",
|
||||
"QKVParallelLinear": "QKVPLinear",
|
||||
"RowParallelLinear": "RPLinear",
|
||||
"weight=": "w=",
|
||||
"bfloat16": "bf16",
|
||||
"float16": "f16",
|
||||
}
|
||||
for key, value in abbreviations.items():
|
||||
name = name.replace(key, value)
|
||||
return name
|
||||
|
||||
|
||||
def attempt_to_make_names_unique(entries_and_traces):
|
||||
names, non_unique_names = (set(), set())
|
||||
|
||||
def all_the_same(items) -> bool:
|
||||
return all(i == items[0] for i in items)
|
||||
|
||||
for entry, _ in entries_and_traces:
|
||||
if entry["name"] in names:
|
||||
non_unique_names.add(entry["name"])
|
||||
else:
|
||||
names.add(entry["name"])
|
||||
|
||||
for name in non_unique_names:
|
||||
entries_and_traces_with_name = [
|
||||
(entry, trace)
|
||||
for entry, trace in entries_and_traces
|
||||
if entry["name"] == name
|
||||
]
|
||||
|
||||
zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name]))
|
||||
first_trace_difference = next(
|
||||
(
|
||||
i
|
||||
for i, trace_eles in enumerate(zipped_traces)
|
||||
if not all_the_same(trace_eles)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if first_trace_difference is None:
|
||||
# can't create a unique name, leave the names as they
|
||||
# are they will get aggregated by the pivot_table call
|
||||
continue
|
||||
|
||||
for entry, trace in entries_and_traces_with_name:
|
||||
entry["name"] = " <- ".join(
|
||||
(entry["name"],) + trace[: first_trace_difference + 1]
|
||||
)
|
||||
|
||||
|
||||
## Operation grouping utils ####
|
||||
"""
|
||||
Group operations in the given dataframe by some high-level ops like,
|
||||
- gemms
|
||||
- attention
|
||||
- rms_norm
|
||||
etc.
|
||||
"""
|
||||
|
||||
|
||||
def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
def is_rms_norm(op_name: str):
|
||||
if "rms_norm_kernel" in op_name:
|
||||
return True
|
||||
|
||||
def is_attention_block(op_name: str):
|
||||
if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name:
|
||||
return True
|
||||
|
||||
def is_quant(op_name: str):
|
||||
if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name:
|
||||
return True
|
||||
|
||||
# LoRA ops
|
||||
def is_sgmv_shrink(op_name: str):
|
||||
return "sgmv_shrink" in op_name
|
||||
|
||||
def is_sgmv_expand(op_name: str):
|
||||
return "sgmv_expand" in op_name
|
||||
|
||||
def is_bgmv_shrink(op_name: str):
|
||||
return "bgmv_shrink" in op_name
|
||||
|
||||
def is_bgmv_expand(op_name: str):
|
||||
return "bgmv_expand" in op_name
|
||||
|
||||
def is_cutlass_gemm_op(op_name: str):
|
||||
return (
|
||||
"void cutlass::Kernel" in op_name
|
||||
or "void cutlass::device_kernel" in op_name
|
||||
)
|
||||
|
||||
def is_gemm_op(op_name: str):
|
||||
if is_quant(op_name):
|
||||
return False
|
||||
return (
|
||||
is_cutlass_gemm_op(op_name)
|
||||
or "xmma_gemm" in op_name
|
||||
or "gemv2T_kernel" in op_name
|
||||
or "splitKreduce" in op_name
|
||||
or "s16816gemm" in op_name
|
||||
)
|
||||
|
||||
def is_elementwise_op(op_name: str):
|
||||
return "elementwise_kernel" in op_name
|
||||
|
||||
def is_mem_op(op_name: str):
|
||||
return "memcpy" in op_name.lower() or "memset" in op_name.lower()
|
||||
|
||||
def is_vocab_embedding_op(op_name: str):
|
||||
return "vocabparallelembed" in op_name.lower()
|
||||
|
||||
# nccl ops
|
||||
def is_nccl_op(op_name: str):
|
||||
return "nccl" in op_name.lower()
|
||||
|
||||
def is_nccl_all_reduce(op_name: str):
|
||||
return is_nccl_op(op_name) and (
|
||||
"all_reduce" in op_name.lower() or "allreduce" in op_name.lower()
|
||||
)
|
||||
|
||||
def is_nccl_gather(op_name: str):
|
||||
return is_nccl_op(op_name) and "gather" in op_name.lower()
|
||||
|
||||
def is_nccl_broadcast(op_name: str):
|
||||
return is_nccl_op(op_name) and "broadcast" in op_name.lower()
|
||||
|
||||
# Reduce ops types
|
||||
def is_cross_device_reduce_1stage(op_name: str):
|
||||
return "cross_device_reduce_1stage" in op_name
|
||||
|
||||
def is_cross_device_reduce_2stage(op_name: str):
|
||||
return "cross_device_reduce_2stage" in op_name
|
||||
|
||||
def is_custom_ar_all_reduce(op_name: str):
|
||||
return "_C_custom_ar::all_reduce" in op_name
|
||||
|
||||
def is_reduce_kernel(op_name: str):
|
||||
return "reduce_kernel" in op_name
|
||||
|
||||
headers = list(trace_df)
|
||||
ops = copy.deepcopy(headers)
|
||||
|
||||
attention_ops = list(filter(lambda x: is_attention_block(x), ops))
|
||||
ops = list(filter(lambda x: x not in attention_ops, ops))
|
||||
|
||||
quant_ops = list(filter(lambda x: is_quant(x), ops))
|
||||
ops = list(filter(lambda x: x not in quant_ops, ops))
|
||||
|
||||
sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops))
|
||||
ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops))
|
||||
sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops))
|
||||
ops = list(filter(lambda x: x not in sgmv_expand_ops, ops))
|
||||
bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops))
|
||||
ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops))
|
||||
bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops))
|
||||
ops = list(filter(lambda x: x not in bgmv_expand_ops, ops))
|
||||
|
||||
cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops))
|
||||
|
||||
gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in gemm_ops, ops))
|
||||
|
||||
rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops))
|
||||
ops = list(filter(lambda x: x not in rms_norm_ops, ops))
|
||||
|
||||
vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in vocab_embed_ops, ops))
|
||||
|
||||
mem_ops = list(filter(lambda x: is_mem_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in mem_ops, ops))
|
||||
|
||||
elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in elementwise_ops, ops))
|
||||
|
||||
nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops))
|
||||
ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops))
|
||||
|
||||
nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops))
|
||||
ops = list(filter(lambda x: x not in nccl_gather_ops, ops))
|
||||
|
||||
nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops))
|
||||
ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops))
|
||||
|
||||
nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in nccl_other_ops, ops))
|
||||
|
||||
cross_device_reduce_1stage_ops = list(
|
||||
filter(lambda x: is_cross_device_reduce_1stage(x), ops)
|
||||
)
|
||||
ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
|
||||
|
||||
cross_device_reduce_2stage_ops = list(
|
||||
filter(lambda x: is_cross_device_reduce_2stage(x), ops)
|
||||
)
|
||||
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
|
||||
|
||||
custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops))
|
||||
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
|
||||
|
||||
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
|
||||
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
|
||||
|
||||
if len(attention_ops):
|
||||
trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1)
|
||||
if len(quant_ops):
|
||||
trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1)
|
||||
|
||||
if len(sgmv_shrink_ops):
|
||||
trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1)
|
||||
if len(sgmv_expand_ops):
|
||||
trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1)
|
||||
if len(bgmv_shrink_ops):
|
||||
trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1)
|
||||
if len(bgmv_expand_ops):
|
||||
trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1)
|
||||
|
||||
if len(cutlass_gemm_ops):
|
||||
trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1)
|
||||
|
||||
if len(gemm_ops):
|
||||
trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1)
|
||||
if len(rms_norm_ops):
|
||||
trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1)
|
||||
if len(vocab_embed_ops):
|
||||
trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1)
|
||||
if len(mem_ops):
|
||||
trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1)
|
||||
if len(elementwise_ops):
|
||||
trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1)
|
||||
|
||||
if len(nccl_all_reduce_ops):
|
||||
trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg(
|
||||
"sum", axis=1
|
||||
)
|
||||
if len(nccl_gather_ops):
|
||||
trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1)
|
||||
if len(nccl_broadcast_ops):
|
||||
trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1)
|
||||
if len(nccl_other_ops):
|
||||
trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1)
|
||||
|
||||
if len(cross_device_reduce_1stage_ops):
|
||||
trace_df["cross_device_reduce_1stage_ops"] = trace_df[
|
||||
cross_device_reduce_1stage_ops
|
||||
].agg("sum", axis=1)
|
||||
if len(cross_device_reduce_2stage_ops):
|
||||
trace_df["cross_device_reduce_2stage_ops"] = trace_df[
|
||||
cross_device_reduce_2stage_ops
|
||||
].agg("sum", axis=1)
|
||||
if len(custom_ar_all_reduce_ops):
|
||||
trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg(
|
||||
"sum", axis=1
|
||||
)
|
||||
if len(reduce_kernel_ops):
|
||||
trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1)
|
||||
|
||||
trace_df.drop(
|
||||
attention_ops
|
||||
+ quant_ops
|
||||
+ sgmv_shrink_ops
|
||||
+ sgmv_expand_ops
|
||||
+ bgmv_shrink_ops
|
||||
+ bgmv_expand_ops
|
||||
+ cutlass_gemm_ops
|
||||
+ gemm_ops
|
||||
+ rms_norm_ops
|
||||
+ vocab_embed_ops
|
||||
+ mem_ops
|
||||
+ elementwise_ops
|
||||
+ nccl_all_reduce_ops
|
||||
+ nccl_gather_ops
|
||||
+ nccl_broadcast_ops
|
||||
+ nccl_other_ops
|
||||
+ cross_device_reduce_1stage_ops
|
||||
+ cross_device_reduce_2stage_ops
|
||||
+ custom_ar_all_reduce_ops
|
||||
+ reduce_kernel_ops,
|
||||
axis=1,
|
||||
inplace=True,
|
||||
)
|
||||
return trace_df
|
||||
|
||||
|
||||
## Data plotting utils ####
|
||||
|
||||
|
||||
def plot_trace_df(
|
||||
traces_df: "pd.DataFrame",
|
||||
plot_metric: str,
|
||||
plot_title: str,
|
||||
output: Path | None = None,
|
||||
):
|
||||
def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str:
|
||||
phase_df = traces_df.query(f'phase == "{phase}"')
|
||||
descs = phase_df["phase_desc"].to_list()
|
||||
assert all([desc == descs[0] for desc in descs])
|
||||
return descs[0]
|
||||
|
||||
phases = traces_df["phase"].unique()
|
||||
phase_descs = [get_phase_description(traces_df, p) for p in phases]
|
||||
traces_df = traces_df.pivot_table(
|
||||
index="phase", columns="name", values=plot_metric, aggfunc="sum"
|
||||
)
|
||||
|
||||
traces_df = group_trace_by_operations(traces_df)
|
||||
|
||||
# Make the figure
|
||||
fig_size_x = max(5, len(phases))
|
||||
fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)
|
||||
|
||||
# Draw the stacked bars
|
||||
ops = list(traces_df)
|
||||
bottom = [0] * len(phases)
|
||||
for op in ops:
|
||||
values = [traces_df[op][phase] for phase in phases]
|
||||
values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
|
||||
ax.bar(phase_descs, values, label=op, bottom=bottom)
|
||||
bottom = [bottom[j] + values[j] for j in range(len(phases))]
|
||||
|
||||
# Write the values as text on the bars
|
||||
for bar in ax.patches:
|
||||
if bar.get_height() != 0:
|
||||
ax.text(
|
||||
bar.get_x() + bar.get_width() / 2,
|
||||
bar.get_height() / 2 + bar.get_y(),
|
||||
f"{round(bar.get_height(), 2)}",
|
||||
ha="center",
|
||||
color="w",
|
||||
weight="bold",
|
||||
size=5,
|
||||
)
|
||||
|
||||
# Setup legend
|
||||
handles, labels = plt.gca().get_legend_handles_labels()
|
||||
legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1))
|
||||
shorten_plot_legend_strings(legend, 50)
|
||||
|
||||
# Setup labels and title
|
||||
plt.setp(ax.get_xticklabels(), rotation=90)
|
||||
ax.set_ylabel(plot_metric)
|
||||
plt.suptitle(plot_title)
|
||||
|
||||
plt.savefig(output, bbox_inches="tight")
|
||||
print("Created: ", output)
|
||||
|
||||
|
||||
def main(
|
||||
json_trace: Path,
|
||||
output_directory: Path,
|
||||
depth: int, # Fetch/Plot operations at this depth of the Json tree
|
||||
plot_metric: str,
|
||||
make_names_unique: bool,
|
||||
top_k: int,
|
||||
json_nodes_to_fold: list[str],
|
||||
):
|
||||
def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame":
|
||||
def get_entries_and_traces(key: str):
|
||||
entries_and_traces: list[tuple[Any, Any]] = []
|
||||
for root in profile_json[key]["summary_stats"]:
|
||||
# Fold nodes in the traces as per user request. i.e. simply
|
||||
# make the requested nodes leaf-nodes.
|
||||
root = fold_nodes(root, json_nodes_to_fold)
|
||||
get_entries_at_depth(depth, entries_and_traces, root)
|
||||
return entries_and_traces
|
||||
|
||||
def keep_only_top_entries(
|
||||
df: "pd.DataFrame", metric: str, top_k: int = 9
|
||||
) -> "pd.DataFrame":
|
||||
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
|
||||
return df
|
||||
|
||||
def get_phase_description(key: str) -> str:
|
||||
num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"]
|
||||
if num_running_seqs is not None:
|
||||
return f"{key}-seqs-{num_running_seqs}"
|
||||
else:
|
||||
return key
|
||||
|
||||
# Get data for each key
|
||||
traces = list(map(lambda x: get_entries_and_traces(x), step_keys))
|
||||
|
||||
# Attempt some cleanup
|
||||
if make_names_unique:
|
||||
for trace in traces:
|
||||
attempt_to_make_names_unique(trace)
|
||||
|
||||
# To pandas dataframe
|
||||
trace_dfs = list(
|
||||
map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces)
|
||||
)
|
||||
|
||||
# Respect top_k
|
||||
if top_k:
|
||||
trace_dfs = list(
|
||||
map(
|
||||
lambda trace_df: keep_only_top_entries(
|
||||
trace_df, "cuda_time_us", top_k
|
||||
),
|
||||
trace_dfs,
|
||||
)
|
||||
)
|
||||
|
||||
# Fill in information about the step-keys
|
||||
for trace_df, step_key in zip(trace_dfs, step_keys):
|
||||
trace_df["phase"] = step_key
|
||||
trace_df["phase_desc"] = get_phase_description(step_key)
|
||||
|
||||
# Combine all data frames so they can be put in a single plot
|
||||
traces_df = pd.concat(trace_dfs)
|
||||
|
||||
# Add a derived metric `cuda_time_ms`
|
||||
traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
|
||||
traces_df = traces_df.fillna(0)
|
||||
|
||||
return traces_df
|
||||
|
||||
def make_plot_title_suffix(profile_json: dict) -> str:
|
||||
context = profile_json["context"]
|
||||
sparsity = context.get("sparsity", None)
|
||||
run_type = (
|
||||
f"Run {context['num_steps']} steps"
|
||||
if context["num_steps"]
|
||||
else (
|
||||
f"Complete {context['complete_num_requests_per_step']} per "
|
||||
f"step; Run till completion"
|
||||
)
|
||||
)
|
||||
return (
|
||||
f"{context['engine_args']['model']}\n"
|
||||
f"Batch={context['batch_size']}, "
|
||||
f"PromptLen={context['prompt_len']}, "
|
||||
f"NumGpus={context['engine_args']['tensor_parallel_size']}"
|
||||
f"{', Sparsity ' + sparsity if sparsity else ''}\n"
|
||||
f"Run Type: {run_type}"
|
||||
)
|
||||
|
||||
profile_json = None
|
||||
with open(json_trace) as f:
|
||||
profile_json = json.load(f)
|
||||
assert profile_json is not None
|
||||
|
||||
# Get all `llm.generate.step()` profile
|
||||
step_traces = list(profile_json.keys())
|
||||
assert step_traces[0] == "context"
|
||||
step_traces = step_traces[1:] # have only prefill and decodes
|
||||
prefills = list(filter(lambda x: "prefill" in x, step_traces))
|
||||
all_decodes = list(filter(lambda x: "decode" in x, step_traces))
|
||||
assert len(prefills) + len(all_decodes) == len(step_traces)
|
||||
assert len(prefills) == 1
|
||||
|
||||
decodes = all_decodes[:: args.step_plot_interval]
|
||||
if decodes[-1] != all_decodes[-1]:
|
||||
# Always have the last decode
|
||||
decodes.append(all_decodes[-1])
|
||||
|
||||
prefill_traces = prepare_data(profile_json, prefills)
|
||||
decode_traces = prepare_data(profile_json, decodes)
|
||||
|
||||
plot_title_suffix = make_plot_title_suffix(profile_json)
|
||||
|
||||
plot_trace_df(
|
||||
prefill_traces,
|
||||
plot_metric,
|
||||
"prefill " + plot_title_suffix,
|
||||
output_directory / Path("prefill.png"),
|
||||
)
|
||||
plot_trace_df(
|
||||
decode_traces,
|
||||
plot_metric,
|
||||
"decodes " + plot_title_suffix,
|
||||
output_directory / Path("decode_steps.png"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--json-trace",
|
||||
type=str,
|
||||
required=True,
|
||||
help="json trace file output by \
|
||||
examples/offline_inference/profiling.py",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-directory", type=str, required=False, help="Directory to output plots"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--level", type=str, default="module", choices=["module", "kernel"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Only graph the top `top_k` entries by time.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fold-json-node",
|
||||
nargs="+",
|
||||
default=["Sampler", "LogitsProcessor"],
|
||||
help="Do not plot the children of these nodes. Let, \
|
||||
the node represent the aggregate of all its \
|
||||
children",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-metric",
|
||||
type=str,
|
||||
default="cuda_time_ms",
|
||||
help="Metric to plot. some options are cuda_time_ms, \
|
||||
pct_cuda_time",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step-plot-interval",
|
||||
type=int,
|
||||
default=4,
|
||||
help="For every `step_plot_interval` steps, plot 1 step",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare/Extract relevant args
|
||||
make_names_unique = False
|
||||
if args.level == "module":
|
||||
depth = -2
|
||||
make_names_unique = True
|
||||
elif args.level == "kernel":
|
||||
depth = -1
|
||||
else:
|
||||
raise Exception(f"Unexpected level value ({args.level})")
|
||||
|
||||
output_directory = (
|
||||
args.output_directory if args.output_directory else Path(args.json_trace).parent
|
||||
)
|
||||
|
||||
if not os.path.exists(output_directory):
|
||||
os.makedirs(output_directory)
|
||||
|
||||
main(
|
||||
Path(args.json_trace),
|
||||
output_directory,
|
||||
depth,
|
||||
args.plot_metric,
|
||||
make_names_unique,
|
||||
args.top_k,
|
||||
args.fold_json_node,
|
||||
)
|
||||
Reference in New Issue
Block a user