latency test enhancement - final part (#921)
@@ -20,14 +20,16 @@ dependencies = [
]

[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
       "torch", "uvicorn", "uvloop", "zmq",
       "vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]

[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
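Note (not part of the diff): the new test extra groups exactly the packages the benchmark's sweep and plotting paths import (jsonlines, matplotlib, pandas), and dev bundles it with all. Assuming the sglang repository layout, where this pyproject.toml lives under python/, the extras can be pulled in with, for example:

    pip install -e "python[test]"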
@@ -1,13 +1,21 @@
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

# Usage (latency test) with dummy weights:
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results in series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"


# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

### Reference output (of the correctness test above, can be gpu dependent):
## Reference output (of the correctness test above, can be gpu dependent):
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
        [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
        [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,13 +36,16 @@ I'm going to the park

import argparse
import dataclasses
import itertools
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple

import jsonlines
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

@@ -49,26 +60,42 @@ from sglang.srt.utils import suppress_other_loggers

@dataclasses.dataclass
class BenchArgs:
    run_name: str = "before"
    batch_size: Tuple[int] = (1,)
    input_len: int = 1024
    output_len: int = 4
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (4,)
    result_filename: str = ""
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    # Plotting args
    graph_sql: str = (
        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
    )
    graph_filename: str = "out.png"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        # graphing
        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
        parser.add_argument(
            "--graph-filename", type=str, default=BenchArgs.graph_filename
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
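A minimal standalone sketch (not from the diff) of how the nargs="+" flags above behave: each flag collects a list of integers, and the benchmark later sweeps the Cartesian product of the three lists, so two batch sizes times two input lengths times one output length yields four measurement points:

    import argparse
    import itertools

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, nargs="+", default=(1,))
    parser.add_argument("--input-len", type=int, nargs="+", default=(1024,))
    parser.add_argument("--output-len", type=int, nargs="+", default=(4,))
    args = parser.parse_args("--batch-size 1 12 --input-len 256 512 --output-len 32".split())
    # args.batch_size == [1, 12], args.input_len == [256, 512], args.output_len == [32]
    print(list(itertools.product(args.batch_size, args.input_len, args.output_len)))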
@@ -222,15 +249,21 @@ def correctness_test(

@torch.inference_mode()
def latency_test_run_once(
    model_runner, rank_print, reqs, batch_size, input_len, output_len
    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": "before",
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
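To make the skip condition above concrete, a small illustration with an assumed pool size (max_total_num_tokens is model- and GPU-dependent; 200_000 is purely hypothetical):

    max_total_num_tokens = 200_000  # assumed KV-cache capacity, for illustration only
    input_len, output_len = 512, 256
    max_batch_size = max_total_num_tokens // (input_len + output_len)
    print(max_batch_size)  # 260, so a requested batch_size above 260 would be skipped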
@@ -291,49 +324,119 @@ def latency_test(

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)
    rank_print(
        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
    )

    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
    bench_args.batch_size = bench_args.batch_size[0]

    # Prepare inputs
    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size, bench_args.input_len
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    latency_test_run_once(
        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        4,  # shorter decoding to speed up the warmup
    )

    # Run again
    # Run the sweep
    result_list = []
    result_list.append(
        latency_test_run_once(
            model_runner,
            rank_print,
            reqs,
            bench_args.batch_size,
            bench_args.input_len,
            bench_args.output_len,
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
        ret = latency_test_run_once(
            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
        )
    )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        import jsonlines

    # Write results in jsonlines format.
    if bench_args.result_filename:
        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)

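Each appended line is a single JSON record built from the measurement dict assembled in latency_test_run_once, roughly of this shape (values are illustrative placeholders; the measured throughput/latency fields are omitted here):

    {"run_name": "before", "batch_size": 1, "input_len": 256, "output_len": 32, "prefill_throughput": ...}

Because the file is opened with "a", repeated invocations (for example a second sweep with --run-name after) keep accumulating records in the same file, which is what enables the before/after comparison plots.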
def main(server_args, bench_args):
    print(bench_args)
def plot_latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    assert tp_rank == 0

    if bench_args.correctness_test:
        work_func = correctness_test
    # read the jsonl file and put in sqlite
    df = pd.read_json(bench_args.result_filename, lines=True)
    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # get the columns and their types
    column_names = list(df.iloc[0].keys())
    type_dict = {
        str: "TEXT",
        np.int64: "INTEGER",
        np.float64: "FLOAT",
    }
    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]

    # create the table
    cur.execute(
        f"""
        CREATE TABLE IF NOT EXISTS results (
            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
        )
        """
    )
    conn.commit()

    # write the results to DB
    df.to_sql("results", conn, if_exists="replace", index=False)
    conn.commit()

    # read it back using sql
    df = pd.read_sql_query(bench_args.graph_sql, conn)
    conn.close()

    # plot it and save to a file
    import matplotlib.pyplot as plt

    assert (
        len(df.columns) == 3
    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
    for label in df[df.columns[0]].unique():
        q = f"{df.columns[0]}=='{label}'"
        series = df.query(q)
        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
    plt.xlabel(df.columns[1])
    plt.ylabel(df.columns[2])
    plt.legend()
    plt.savefig(bench_args.graph_filename, dpi=300)

    # if in kitty, just dump it to the terminal
    if os.environ["TERM"] == "xterm-kitty":
        os.system(
            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
        )

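A small sketch (not from the commit) for sanity-checking a --graph-sql query against an existing out.jsonl without launching a model; it mirrors the pandas/sqlite3 round trip in plot_latency_test, and the query must return exactly three columns, interpreted as <series, x, y>:

    import sqlite3

    import pandas as pd

    df = pd.read_json("out.jsonl", lines=True)  # records appended by the sweep
    conn = sqlite3.connect(":memory:")
    df.to_sql("results", conn, index=False)  # same table name the script queries
    query = "select run_name, batch_size, prefill_throughput from results"
    print(pd.read_sql_query(query, conn))  # three columns: series, x, y
    conn.close()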
def main(server_args, bench_args):

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    elif os.path.isfile(bench_args.result_filename):
        assert bench_args.graph_filename, "please provide a filename for the graph"
        work_func = plot_latency_test
    else:
        work_func = latency_test
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    if server_args.tp_size == 1:
        work_func(server_args, bench_args, 0)
@@ -361,6 +464,11 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    # For this script, model-path is not required
    assert (
        parser._actions[1].option_strings[0] == "--model-path"
    ), "options changed, this code needs to be updated"
    parser._actions[1].required = False
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
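The assert above ties --model-path to a fixed position in parser._actions. A position-independent variant (a sketch only, not what the commit does) would look the action up by its option string, making the guard unnecessary:

    # Hypothetical alternative: find the --model-path action by name instead of index.
    for action in parser._actions:
        if "--model-path" in action.option_strings:
            action.required = False
            break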