From ebf69964cd7e5af9a079eb430ecdd7f67e8566d0 Mon Sep 17 00:00:00 2001
From: min-xu-et <168487304+min-xu-et@users.noreply.github.com>
Date: Sun, 4 Aug 2024 18:15:23 -0700
Subject: [PATCH] latency test enhancement - final part (#921)

---
 python/pyproject.toml          |   4 +-
 python/sglang/bench_latency.py | 178 ++++++++++++++++++++++++++-------
 2 files changed, 146 insertions(+), 36 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index f5cfe3388..fa444ea98 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,14 +20,16 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
        "packaging", "pillow", "psutil", "pydantic", "python-multipart", "torch", "uvicorn", "uvloop", "zmq",
        "vllm==0.5.3.post1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
+test = ["jsonlines", "matplotlib", "pandas"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+dev = ["sglang[all]", "sglang[test]"]

 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index ee35960ba..3000b0bb9 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -1,13 +1,21 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

-# Usage (latency test) with dummy weights:
+# Usage (latency test)
+## with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+## sweep through multiple data points and store (append) the results in a jsonl file:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
+## do some changes, and store the results under a different run_name:
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
+## plot the results in series of lines:
+python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"
+

 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

-### Reference output (of the correctness test above, can be gpu dependent):
+## Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -28,13 +36,16 @@ I'm going to the park

 import argparse
 import dataclasses
+import itertools
 import logging
 import multiprocessing
+import os
+import sqlite3
 import time
 from typing import Tuple

-import jsonlines
 import numpy as np
+import pandas as pd
 import torch
 import torch.distributed as dist

@@ -49,26 +60,42 @@ from sglang.srt.utils import suppress_other_loggers

 @dataclasses.dataclass
 class BenchArgs:
+    run_name: str = "before"
     batch_size: Tuple[int] = (1,)
-    input_len: int = 1024
-    output_len: int = 4
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (4,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    # Plotting args
+    graph_sql: str = (
+        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
+    )
+    graph_filename: str = "out.png"

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
-        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
-        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        # graphing
+        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
+        parser.add_argument(
+            "--graph-filename", type=str, default=BenchArgs.graph_filename
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
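The hunk above changes --input-len and --output-len from single values to nargs="+" lists, so one invocation can describe a whole grid of (batch_size, input_len, output_len) points. Below is a minimal standalone sketch of that argparse pattern; it is not part of the patch, the flag names simply mirror BenchArgs, and the defaults and sample arguments are illustrative.

import argparse
import itertools

# Toy parser mirroring the flags added to BenchArgs; defaults are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=[1])
parser.add_argument("--input-len", type=int, nargs="+", default=[1024])
parser.add_argument("--output-len", type=int, nargs="+", default=[4])
args = parser.parse_args(["--batch-size", "1", "12", "--input-len", "256", "512"])

# Every (batch_size, input_len, output_len) combination becomes one benchmark run,
# which is how the itertools.product sweep in latency_test consumes these lists below.
for bs, il, ol in itertools.product(args.batch_size, args.input_len, args.output_len):
    print(f"would benchmark batch_size={bs} input_len={il} output_len={ol}")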
@@ -222,15 +249,21 @@ def correctness_test(


 @torch.inference_mode()
 def latency_test_run_once(
-    model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
 ):
+    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
+    if batch_size > max_batch_size:
+        rank_print(
+            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
+        )
+        return
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool.clear()

     measurement_results = {
-        "run_name": "before",
+        "run_name": run_name,
         "batch_size": batch_size,
         "input_len": input_len,
         "output_len": output_len,
@@ -291,49 +324,119 @@ def latency_test(

     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    rank_print(
-        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
-    )
-    # To make this PR easier to review, for now, only do the first element in batch_size tuple.
-    bench_args.batch_size = bench_args.batch_size[0]
-
-    # Prepare inputs
+    # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
-        bench_args.batch_size, bench_args.input_len
+        bench_args.batch_size[0], bench_args.input_len[0]
     )

     # Warm up
     latency_test_run_once(
-        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+        bench_args.run_name,
+        model_runner,
+        rank_print,
+        reqs,
+        bench_args.batch_size[0],
+        bench_args.input_len[0],
+        4,  # shorter decoding to speed up the warmup
     )

-    # Run again
+    # Run the sweep
     result_list = []
-    result_list.append(
-        latency_test_run_once(
-            model_runner,
-            rank_print,
-            reqs,
-            bench_args.batch_size,
-            bench_args.input_len,
-            bench_args.output_len,
+    for bs, il, ol in itertools.product(
+        bench_args.batch_size, bench_args.input_len, bench_args.output_len
+    ):
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        ret = latency_test_run_once(
+            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
         )
-    )
+        if ret is not None:
+            result_list.append(ret)
+
+    # Write results in jsonlines format on rank 0.
+    if tp_rank == 0 and bench_args.result_filename:
+        import jsonlines

-    # Write results in jsonlines format.
-    if bench_args.result_filename:
         with jsonlines.open(bench_args.result_filename, "a") as f:
             f.write_all(result_list)


-def main(server_args, bench_args):
-    print(bench_args)
+def plot_latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    assert tp_rank == 0

-    if bench_args.correctness_test:
-        work_func = correctness_test
+    # read the jsonl file and put in sqlite
+    df = pd.read_json(bench_args.result_filename, lines=True)
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+
+    # get the columns and their types
+    column_names = list(df.iloc[0].keys())
+    type_dict = {
+        str: "TEXT",
+        np.int64: "INTEGER",
+        np.float64: "FLOAT",
+    }
+    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]
+
+    # create the table
+    cur.execute(
+        f"""
+        CREATE TABLE IF NOT EXISTS results (
+            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
+        )
+        """
+    )
+    conn.commit()
+
+    # write the results to DB
+    df.to_sql("results", conn, if_exists="replace", index=False)
+    conn.commit()
+
+    # read it back using sql
+    df = pd.read_sql_query(bench_args.graph_sql, conn)
+    conn.close()
+
+    # plot it and save to a file
+    import matplotlib.pyplot as plt
+
+    assert (
+        len(df.columns) == 3
+    ), f"The sql should have fetched 3 columns, not {df.columns}"
+    for label in df[df.columns[0]].unique():
+        q = f"{df.columns[0]}=='{label}'"
+        series = df.query(q)
+        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
+    plt.xlabel(df.columns[1])
+    plt.ylabel(df.columns[2])
+    plt.legend()
+    plt.savefig(bench_args.graph_filename, dpi=300)
+
+    # if in kitty, just dump it to the terminal
+    if os.environ["TERM"] == "xterm-kitty":
+        os.system(
+            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
+        )
+
+
+def main(server_args, bench_args):
+
+    if server_args.model_path:
+        if bench_args.correctness_test:
+            work_func = correctness_test
+        else:
+            work_func = latency_test
+    elif os.path.isfile(bench_args.result_filename):
+        assert bench_args.graph_filename, "please provide a filename for the graph"
+        work_func = plot_latency_test
     else:
-        work_func = latency_test
+        raise ValueError(
+            "Provide --model-path for running the tests or "
+            "provide --result-filename for plotting the results"
+        )

     if server_args.tp_size == 1:
         work_func(server_args, bench_args, 0)
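The plot_latency_test added above pushes the jsonl rows through an in-memory SQLite table so --graph-sql can slice them, then draws one line per value of the first selected column. The following is a cut-down sketch of the same flow, not the patch itself: it assumes an out.jsonl produced by the sweep commands in the docstring, and it lets pandas.to_sql create the table instead of issuing CREATE TABLE explicitly.

import sqlite3

import matplotlib.pyplot as plt
import pandas as pd

# Load the appended benchmark records and expose them as a "results" table.
df = pd.read_json("out.jsonl", lines=True)
conn = sqlite3.connect(":memory:")
df.to_sql("results", conn, if_exists="replace", index=False)

# Same label/x/y shape as the default --graph-sql.
df = pd.read_sql_query(
    "select run_name, batch_size, prefill_throughput from results", conn
)
conn.close()

# One line per run_name, batch_size on the x axis, prefill throughput on the y axis.
for label in df["run_name"].unique():
    series = df[df["run_name"] == label]
    plt.plot(series["batch_size"], series["prefill_throughput"], label=label, marker="o")
plt.xlabel("batch_size")
plt.ylabel("prefill_throughput")
plt.legend()
plt.savefig("out.png", dpi=300)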
@@ -361,6 +464,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code needs to be updated"
+    parser._actions[1].required = False
     args = parser.parse_args()

     server_args = ServerArgs.from_cli_args(args)
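For reference, the sweep appends one JSON object per line to --result-filename and the plotting path reads the same file back with pandas; both libraries ship in the new test extra. Below is a small round-trip sketch with made-up values; real records carry the full set of measurements produced by latency_test_run_once, and only the keys referenced elsewhere in this patch are shown here.

import jsonlines
import pandas as pd

# One record per (batch_size, input_len, output_len) point; the values are placeholders.
record = {
    "run_name": "before",
    "batch_size": 1,
    "input_len": 256,
    "output_len": 32,
    "prefill_throughput": 0.0,
}
with jsonlines.open("out.jsonl", "a") as f:  # same append mode the sweep uses
    f.write_all([record])

df = pd.read_json("out.jsonl", lines=True)  # same call plot_latency_test uses
print(df[["run_name", "batch_size", "prefill_throughput"]])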