diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index efcc3a3a4..df3f2c5ea 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -118,7 +118,7 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_latency.TestBenchLatency.test_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default - name: Benchmark online latency timeout-minutes: 10 @@ -194,7 +194,7 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_latency.TestBenchLatency.test_moe_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default accuracy-test-1-gpu: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index c0f54957d..fe8fc5260 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -1,11 +1,16 @@ # Benchmark and Profiling ## Benchmark -- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`. +- Benchmark the latency of running a single static batch without a server. The arguments are the same as for `launch_server.py`. + Note that this is a simplified test script without a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this simplified script does not. ``` - python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 + python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32 ``` -- Benchmark online serving. Launch a server first and run the following command. +- Benchmark offline processing. This script will start an offline engine and run the benchmark. + ``` + python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10 + ``` +- Benchmark online serving. Please use `sglang.launch_server` to launch a server first and run the following command. ``` python3 -m sglang.bench_serving --backend sglang --num-prompt 10 ``` @@ -23,7 +28,7 @@ apt update apt install nsight-systems-cli ``` -1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512` +1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512` 2. To profile a server, e.g. @@ -33,7 +38,7 @@ apt install nsight-systems-cli nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache # client -python3 -m sglang.bench_serving --backend sglang --num-prompts 6000 --dataset-name random --random-input 4096 --random-output 2048 +python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512 ``` 3. Use NVTX, e.g. diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 9f9a16332..dbf4f71a0 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -59,7 +59,7 @@ For interactive debugging, you can compare the outputs of huggingface/transforme The following two commands should give the same text output and very similar prefill logits. - Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]` -- Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]` +- Get the SGLang output by `python3 -m sglang.bench_one_batch --correct --model [new model]` #### Add the model to the test suite To make sure the new model is well maintained in the future, it is better to add it to the test suite. diff --git a/docs/start/install.md b/docs/start/install.md index fd3863305..ee93ded40 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -59,7 +59,7 @@ drun -p 30000:30000 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.3.5.post2-rocm620 python3 -m sglang.bench_latency --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.3.5.post2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 0fe373106..1267ac7b3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -16,10 +16,13 @@ classifiers = [ dependencies = ["requests", "tqdm", "numpy", "IPython"] [project.optional-dependencies] -runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", - "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", - "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2", - "outlines>=0.0.44,<0.1.0", "modelscope"] +runtime_common = ["aiohttp", "decord", "fastapi", + "hf_transfer", "huggingface_hub", "interegular", + "orjson", "outlines>=0.0.44,<0.1.0", + "packaging", "pillow", "prometheus-client>=0.20.0", + "psutil", "pydantic", "python-multipart", + "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", + "modelscope"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/README.md b/python/sglang/README.md index 8b59fc106..29a7149de 100644 --- a/python/sglang/README.md +++ b/python/sglang/README.md @@ -4,9 +4,11 @@ - `srt`: The backend engine for running local models. (SRT = SGLang Runtime). - `test`: The test utilities. - `api.py`: The public APIs. -- `bench_latency.py`: Benchmark the latency of running a single static batch. -- `bench_server_latency.py`: Benchmark the latency of serving a single batch with a real server. +- `bench_offline_throughput.py`: Benchmark the throughput in the offline mode. +- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server. +- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server. - `bench_serving.py`: Benchmark online serving with dynamic requests. +- `check_env.py`: Check the environment variables. - `global_config.py`: The global configs and constants. - `launch_server.py`: The entry point for launching the local server. - `utils.py`: Common utilities. diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 13bc113e6..f936ecc5b 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -1,553 +1 @@ -""" -Benchmark the latency of running a single static batch. -This script does not launch a server and uses the low-level APIs. -It accepts arguments similar to those of launch_server.py. - -# Usage (latency test) -## with dummy weights: -python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy -## sweep through multiple data points and store (append) the results in a jsonl file: -python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl -## do some changes, and store the results under a different run_name: -python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after -## plot the results in series of lines: -python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results" - -# Usage (correctness test): -python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct - -## Reference output (of the correctness test above, can be gpu dependent): -input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] - -prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], - [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], - [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]], - device='cuda:0') - -prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141], - [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781], - [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]], - device='cuda:0') - -========== Prompt 0 ========== - The capital of France is Paris. -The capital of the United States is Washington, D.C. - - -========== Prompt 1 ========== - The capital of the United Kindom is London. -The capital of the United Kingdom is London. -The capital of the - -========== Prompt 2 ========== - Today is a sunny day and I like to go for a walk in the park. -I'm going to the park -""" - -import argparse -import dataclasses -import itertools -import json -import logging -import multiprocessing -import os -import sqlite3 -import time -from typing import Tuple - -import numpy as np -import pandas as pd -import torch -import torch.distributed as dist - -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server import _set_envs_and_config -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - configure_logger, - kill_child_process, - suppress_other_loggers, -) - - -@dataclasses.dataclass -class BenchArgs: - run_name: str = "before" - batch_size: Tuple[int] = (1,) - input_len: Tuple[int] = (1024,) - output_len: Tuple[int] = (16,) - result_filename: str = "" - correctness_test: bool = False - # This is only used for correctness test - cut_len: int = 4 - # Plotting args - graph_sql: str = ( - "select run_name, batch_size, prefill_throughput from results where run_name='before'" - ) - graph_filename: str = "out.png" - - @staticmethod - def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) - parser.add_argument( - "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size - ) - parser.add_argument( - "--input-len", type=int, nargs="+", default=BenchArgs.input_len - ) - parser.add_argument( - "--output-len", type=int, nargs="+", default=BenchArgs.output_len - ) - parser.add_argument( - "--result-filename", type=str, default=BenchArgs.result_filename - ) - parser.add_argument("--correctness-test", action="store_true") - parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) - # graphing - parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql) - parser.add_argument( - "--graph-filename", type=str, default=BenchArgs.graph_filename - ) - - @classmethod - def from_cli_args(cls, args: argparse.Namespace): - # use the default value's type to case the args into correct types. - attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] - return cls( - **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} - ) - - -def load_model(server_args, port_args, tp_rank): - suppress_other_loggers() - rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - ) - model_runner = ModelRunner( - model_config=model_config, - mem_fraction_static=server_args.mem_fraction_static, - gpu_id=tp_rank, - tp_rank=tp_rank, - tp_size=server_args.tp_size, - nccl_port=port_args.nccl_port, - server_args=server_args, - ) - rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") - tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - if server_args.tp_size > 1: - dist.barrier() - return model_runner, tokenizer - - -def prepare_inputs_for_correctness_test(bench_args, tokenizer): - prompts = [ - "The capital of France is", - "The capital of the United Kindom is", - "Today is a sunny day and I like", - ] - input_ids = [tokenizer.encode(p) for p in prompts] - sampling_params = SamplingParams( - temperature=0, - max_new_tokens=BenchArgs.output_len, - ) - - reqs = [] - for i in range(len(prompts)): - assert len(input_ids[i]) > bench_args.cut_len - - tmp_input_ids = input_ids[i][: bench_args.cut_len] - req = Req( - rid=i, - origin_input_text=prompts[i], - origin_input_ids=tmp_input_ids, - sampling_params=sampling_params, - ) - req.prefix_indices = [] - req.fill_ids = req.origin_input_ids - req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) - reqs.append(req) - - return input_ids, reqs - - -def prepare_extend_inputs_for_correctness_test( - bench_args, input_ids, reqs, model_runner -): - for i in range(len(reqs)): - req = reqs[i] - req.fill_ids += input_ids[i][bench_args.cut_len :] - req.prefix_indices = model_runner.req_to_token_pool.req_to_token[ - i, : bench_args.cut_len - ] - req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) - return reqs - - -def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): - input_ids = np.ones((batch_size, input_len), dtype=np.int32) - sampling_params = SamplingParams( - temperature=0, - max_new_tokens=BenchArgs.output_len, - ) - - reqs = [] - for i in range(len(input_ids)): - req = Req( - rid=i, - origin_input_text="", - origin_input_ids=list(input_ids[i]), - sampling_params=sampling_params, - ) - req.prefix_indices = [] - req.fill_ids = req.origin_input_ids - req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) - reqs.append(req) - - return reqs - - -@torch.no_grad -def extend(reqs, model_runner): - batch = ScheduleBatch.init_new( - reqs=reqs, - req_to_token_pool=model_runner.req_to_token_pool, - token_to_kv_pool=model_runner.token_to_kv_pool, - tree_cache=None, - model_config=model_runner.model_config, - ) - batch.prepare_for_extend() - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - logits_output = model_runner.forward(forward_batch) - next_token_ids = model_runner.sample(logits_output, forward_batch) - return next_token_ids, logits_output.next_token_logits, batch - - -@torch.no_grad -def decode(input_token_ids, batch, model_runner): - batch.output_ids = input_token_ids - batch.prepare_for_decode() - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) - logits_output = model_runner.forward(forward_batch) - next_token_ids = model_runner.sample(logits_output, forward_batch) - return next_token_ids, logits_output.next_token_logits - - -def correctness_test( - server_args, - port_args, - bench_args, - tp_rank, -): - configure_logger(server_args, prefix=f" TP{tp_rank}") - rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - - # Load the model - model_runner, tokenizer = load_model(server_args, port_args, tp_rank) - - # Prepare inputs - input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) - rank_print(f"\n{input_ids=}\n") - - if bench_args.cut_len > 0: - # Prefill - next_token_ids, next_token_logits, batch = extend(reqs, model_runner) - rank_print(f"prefill logits (first half): {next_token_logits} \n") - - # Prepare extend inputs - reqs = prepare_extend_inputs_for_correctness_test( - bench_args, input_ids, reqs, model_runner - ) - - # Extend - next_token_ids, next_token_logits, batch = extend(reqs, model_runner) - rank_print(f"prefill logits (final): {next_token_logits} \n") - - # Decode - output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] - for _ in range(bench_args.output_len[0] - 1): - next_token_ids, _ = decode(next_token_ids, batch, model_runner) - next_token_ids_list = next_token_ids.tolist() - for i in range(len(reqs)): - output_ids[i].append(next_token_ids_list[i]) - - # Print - for i in range(len(reqs)): - rank_print(f"========== Prompt {i} ==========") - rank_print(tokenizer.decode(output_ids[i]), "\n") - - -def synchronize(device): - if device == "cuda": - torch.cuda.synchronize() - elif device == "xpu": - torch.xpu.synchronize() - - -def latency_test_run_once( - run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device -): - max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len) - if batch_size > max_batch_size: - rank_print( - f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit" - ) - return - - # Clear the pools. - model_runner.req_to_token_pool.clear() - model_runner.token_to_kv_pool.clear() - - measurement_results = { - "run_name": run_name, - "batch_size": batch_size, - "input_len": input_len, - "output_len": output_len, - } - - tot_latency = 0 - - # Prefill - synchronize(device) - tic = time.time() - next_token_ids, _, batch = extend(reqs, model_runner) - synchronize(device) - prefill_latency = time.time() - tic - tot_latency += prefill_latency - throughput = input_len * batch_size / prefill_latency - rank_print( - f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" - ) - measurement_results["prefill_latency"] = prefill_latency - measurement_results["prefill_throughput"] = throughput - - # Decode - decode_latencies = [] - for i in range(output_len - 1): - synchronize(device) - tic = time.time() - next_token_ids, _ = decode(next_token_ids, batch, model_runner) - synchronize(device) - latency = time.time() - tic - tot_latency += latency - throughput = batch_size / latency - decode_latencies.append(latency) - if i < 5: - rank_print( - f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" - ) - - # record decode timing from 2nd output - if output_len > 1: - med_decode_latency = np.median(decode_latencies) - med_decode_throughput = batch_size / med_decode_latency - rank_print( - f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s" - ) - measurement_results["median_decode_latency"] = med_decode_latency - measurement_results["median_decode_throughput"] = med_decode_throughput - - throughput = (input_len + output_len) * batch_size / tot_latency - rank_print( - f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" - ) - measurement_results["total_latency"] = tot_latency - measurement_results["total_throughput"] = throughput - return measurement_results - - -def latency_test( - server_args, - port_args, - bench_args, - tp_rank, -): - configure_logger(server_args, prefix=f" TP{tp_rank}") - rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - - # Load the model - model_runner, tokenizer = load_model(server_args, port_args, tp_rank) - - # Prepare inputs for warm up - reqs = prepare_synthetic_inputs_for_latency_test( - bench_args.batch_size[0], bench_args.input_len[0] - ) - - # Warm up - rank_print("Warmup ...") - latency_test_run_once( - bench_args.run_name, - model_runner, - rank_print, - reqs, - bench_args.batch_size[0], - bench_args.input_len[0], - 8, # shorter decoding to speed up the warmup - server_args.device, - ) - rank_print("Benchmark ...") - - # Run the sweep - result_list = [] - for bs, il, ol in itertools.product( - bench_args.batch_size, bench_args.input_len, bench_args.output_len - ): - reqs = prepare_synthetic_inputs_for_latency_test(bs, il) - ret = latency_test_run_once( - bench_args.run_name, - model_runner, - rank_print, - reqs, - bs, - il, - ol, - server_args.device, - ) - if ret is not None: - result_list.append(ret) - - # Write results in jsonlines format on rank 0. - if tp_rank == 0 and bench_args.result_filename: - import jsonlines - - with jsonlines.open(bench_args.result_filename, "a") as f: - f.write_all(result_list) - - -def plot_latency_test( - server_args, - bench_args, - tp_rank, -): - assert tp_rank == 0 - - # read the jsonl file and put in sqlite - df = pd.read_json(bench_args.result_filename, lines=True) - conn = sqlite3.connect(":memory:") - cur = conn.cursor() - - # get the columns and their types - column_names = list(df.iloc[0].keys()) - type_dict = { - str: "TEXT", - np.int64: "INTEGER", - np.float64: "FLOAT", - } - column_types = [type_dict[type(i)] for i in list(df.iloc[0])] - - # create the table - cur.execute( - f""" - CREATE TABLE IF NOT EXISTS results ( - {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])} - ) - """ - ) - conn.commit() - - # write the results to DB - df.to_sql("results", conn, if_exists="replace", index=False) - conn.commit() - - # read it back using sql - df = pd.read_sql_query(bench_args.graph_sql, conn) - conn.close() - - # plot it and save to a file - import matplotlib.pyplot as plt - - assert ( - len(df.columns) == 3 - ), f"The sql should have fetched columns, not {df.columns}" - for label in df[df.columns[0]].unique(): - q = f"{df.columns[0]}=='{label}'" - series = df.query(q) - plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o") - plt.xlabel(df.columns[1]) - plt.ylabel(df.columns[2]) - plt.legend() - plt.savefig(bench_args.graph_filename, dpi=300) - - # if in kitty, just dump it to the terminal - if os.environ["TERM"] == "xterm-kitty": - os.system( - f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}" - ) - - -def main(server_args, bench_args): - _set_envs_and_config(server_args) - - if server_args.model_path: - if bench_args.correctness_test: - work_func = correctness_test - else: - work_func = latency_test - elif os.path.isfile(bench_args.result_filename): - assert bench_args.graph_filename, "please provide a filename for the graph" - work_func = plot_latency_test - else: - raise ValueError( - "Provide --model-path for running the tests or " - "provide --result-filename for plotting the results" - ) - - port_args = PortArgs.init_new(server_args) - - if server_args.tp_size == 1: - work_func(server_args, port_args, bench_args, 0) - else: - workers = [] - for tp_rank in range(server_args.tp_size): - proc = multiprocessing.Process( - target=work_func, - args=( - server_args, - port_args, - bench_args, - tp_rank, - ), - ) - proc.start() - workers.append(proc) - - for proc in workers: - proc.join() - - proc.terminate() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - BenchArgs.add_cli_args(parser) - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) - bench_args = BenchArgs.from_cli_args(args) - - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) - - try: - main(server_args, bench_args) - except Exception as e: - raise e - finally: - kill_child_process() +raise ValueError("bench_latency.py has been renamed to bench_one_batch.py") diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index cb502fa02..f1c4e8f9e 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -1,20 +1,13 @@ """ -Benchmark the throughput of using the offline LLM engine. -This script does not launch a server. +Benchmark the throughput in the offline mode. It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py). # Usage ## Sharegpt dataset with default args -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct +python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10 ## Random dataset with default args -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random - -## Shared prefix dataset with default args -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix - -## Sharegpt dataset on runtime backend -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime +python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 """ import argparse diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py new file mode 100644 index 000000000..ea65b5383 --- /dev/null +++ b/python/sglang/bench_one_batch.py @@ -0,0 +1,474 @@ +""" +Benchmark the latency of running a single static batch without a server. + +This script does not launch a server and uses the low-level APIs. +It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). + +# Usage (latency test) +## with dummy weights: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy +## sweep through multiple data points and store (append) the results in a jsonl file: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run + +# Usage (correctness test): +python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct + +## Reference output (of the correctness test above, can be gpu dependent): +input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] + +prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]], + device='cuda:0') + +prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141], + [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781], + [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]], + device='cuda:0') + +========== Prompt 0 ========== + The capital of France is Paris. +The capital of the United States is Washington, D.C. + + +========== Prompt 1 ========== + The capital of the United Kindom is London. +The capital of the United Kingdom is London. +The capital of the + +========== Prompt 2 ========== + Today is a sunny day and I like to go for a walk in the park. +I'm going to the park +""" + +import argparse +import dataclasses +import itertools +import json +import logging +import multiprocessing +import time +from typing import Tuple + +import numpy as np +import torch +import torch.distributed as dist + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server import _set_envs_and_config +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import ( + configure_logger, + kill_child_process, + suppress_other_loggers, +) + + +@dataclasses.dataclass +class BenchArgs: + run_name: str = "default" + batch_size: Tuple[int] = (1,) + input_len: Tuple[int] = (1024,) + output_len: Tuple[int] = (16,) + result_filename: str = "result.jsonl" + correctness_test: bool = False + # This is only used for correctness test + cut_len: int = 4 + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) + parser.add_argument( + "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size + ) + parser.add_argument( + "--input-len", type=int, nargs="+", default=BenchArgs.input_len + ) + parser.add_argument( + "--output-len", type=int, nargs="+", default=BenchArgs.output_len + ) + parser.add_argument( + "--result-filename", type=str, default=BenchArgs.result_filename + ) + parser.add_argument("--correctness-test", action="store_true") + parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # use the default value's type to case the args into correct types. + attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] + return cls( + **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} + ) + + +def load_model(server_args, port_args, tp_rank): + suppress_other_loggers() + rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + + model_config = ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + ) + model_runner = ModelRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=tp_rank, + tp_rank=tp_rank, + tp_size=server_args.tp_size, + nccl_port=port_args.nccl_port, + server_args=server_args, + ) + rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") + tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + if server_args.tp_size > 1: + dist.barrier() + return model_runner, tokenizer + + +def prepare_inputs_for_correctness_test(bench_args, tokenizer): + prompts = [ + "The capital of France is", + "The capital of the United Kindom is", + "Today is a sunny day and I like", + ] + input_ids = [tokenizer.encode(p) for p in prompts] + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=BenchArgs.output_len, + ) + + reqs = [] + for i in range(len(prompts)): + assert len(input_ids[i]) > bench_args.cut_len + + tmp_input_ids = input_ids[i][: bench_args.cut_len] + req = Req( + rid=i, + origin_input_text=prompts[i], + origin_input_ids=tmp_input_ids, + sampling_params=sampling_params, + ) + req.prefix_indices = [] + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + reqs.append(req) + + return input_ids, reqs + + +def prepare_extend_inputs_for_correctness_test( + bench_args, input_ids, reqs, model_runner +): + for i in range(len(reqs)): + req = reqs[i] + req.fill_ids += input_ids[i][bench_args.cut_len :] + req.prefix_indices = model_runner.req_to_token_pool.req_to_token[ + i, : bench_args.cut_len + ] + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + return reqs + + +def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): + input_ids = np.ones((batch_size, input_len), dtype=np.int32) + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=BenchArgs.output_len, + ) + + reqs = [] + for i in range(len(input_ids)): + req = Req( + rid=i, + origin_input_text="", + origin_input_ids=list(input_ids[i]), + sampling_params=sampling_params, + ) + req.prefix_indices = [] + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + reqs.append(req) + + return reqs + + +@torch.no_grad +def extend(reqs, model_runner): + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + tree_cache=None, + model_config=model_runner.model_config, + ) + batch.prepare_for_extend() + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + logits_output = model_runner.forward(forward_batch) + next_token_ids = model_runner.sample(logits_output, forward_batch) + return next_token_ids, logits_output.next_token_logits, batch + + +@torch.no_grad +def decode(input_token_ids, batch, model_runner): + batch.output_ids = input_token_ids + batch.prepare_for_decode() + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + logits_output = model_runner.forward(forward_batch) + next_token_ids = model_runner.sample(logits_output, forward_batch) + return next_token_ids, logits_output.next_token_logits + + +def correctness_test( + server_args, + port_args, + bench_args, + tp_rank, +): + # Configure the logger + configure_logger(server_args, prefix=f" TP{tp_rank}") + rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + + # Load the model + model_runner, tokenizer = load_model(server_args, port_args, tp_rank) + + # Prepare inputs + input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) + rank_print(f"\n{input_ids=}\n") + + if bench_args.cut_len > 0: + # Prefill + next_token_ids, next_token_logits, batch = extend(reqs, model_runner) + rank_print(f"prefill logits (first half): {next_token_logits} \n") + + # Prepare extend inputs + reqs = prepare_extend_inputs_for_correctness_test( + bench_args, input_ids, reqs, model_runner + ) + + # Extend (prefill w/ KV cache) + next_token_ids, next_token_logits, batch = extend(reqs, model_runner) + rank_print(f"prefill logits (final): {next_token_logits} \n") + + # Decode + output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] + for _ in range(bench_args.output_len[0] - 1): + next_token_ids, _ = decode(next_token_ids, batch, model_runner) + next_token_ids_list = next_token_ids.tolist() + for i in range(len(reqs)): + output_ids[i].append(next_token_ids_list[i]) + + # Print output texts + for i in range(len(reqs)): + rank_print(f"========== Prompt {i} ==========") + rank_print(tokenizer.decode(output_ids[i]), "\n") + + +def synchronize(device): + if device == "cuda": + torch.cuda.synchronize() + elif device == "xpu": + torch.xpu.synchronize() + + +def latency_test_run_once( + run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device +): + max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len) + if batch_size > max_batch_size: + rank_print( + f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit" + ) + return + + # Clear the pools. + model_runner.req_to_token_pool.clear() + model_runner.token_to_kv_pool.clear() + + measurement_results = { + "run_name": run_name, + "batch_size": batch_size, + "input_len": input_len, + "output_len": output_len, + } + + tot_latency = 0 + + # Prefill + synchronize(device) + tic = time.time() + next_token_ids, _, batch = extend(reqs, model_runner) + synchronize(device) + prefill_latency = time.time() - tic + tot_latency += prefill_latency + throughput = input_len * batch_size / prefill_latency + rank_print( + f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) + measurement_results["prefill_latency"] = prefill_latency + measurement_results["prefill_throughput"] = throughput + + # Decode + decode_latencies = [] + for i in range(output_len - 1): + synchronize(device) + tic = time.time() + next_token_ids, _ = decode(next_token_ids, batch, model_runner) + synchronize(device) + latency = time.time() - tic + tot_latency += latency + throughput = batch_size / latency + decode_latencies.append(latency) + if i < 5: + rank_print( + f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" + ) + + # Record decode timing from 2nd output + if output_len > 1: + med_decode_latency = np.median(decode_latencies) + med_decode_throughput = batch_size / med_decode_latency + rank_print( + f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s" + ) + measurement_results["median_decode_latency"] = med_decode_latency + measurement_results["median_decode_throughput"] = med_decode_throughput + + throughput = (input_len + output_len) * batch_size / tot_latency + rank_print( + f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" + ) + measurement_results["total_latency"] = tot_latency + measurement_results["overall_throughput"] = throughput + return measurement_results + + +def latency_test( + server_args, + port_args, + bench_args, + tp_rank, +): + # Configure the logger + configure_logger(server_args, prefix=f" TP{tp_rank}") + rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + + # Load the model + model_runner, tokenizer = load_model(server_args, port_args, tp_rank) + + # Prepare inputs for warm up + reqs = prepare_synthetic_inputs_for_latency_test( + bench_args.batch_size[0], bench_args.input_len[0] + ) + + # Warm up + rank_print("Warmup ...") + latency_test_run_once( + bench_args.run_name, + model_runner, + rank_print, + reqs, + bench_args.batch_size[0], + bench_args.input_len[0], + 8, # shorter decoding to speed up the warmup + server_args.device, + ) + rank_print("Benchmark ...") + + # Run the sweep + result_list = [] + for bs, il, ol in itertools.product( + bench_args.batch_size, bench_args.input_len, bench_args.output_len + ): + reqs = prepare_synthetic_inputs_for_latency_test(bs, il) + ret = latency_test_run_once( + bench_args.run_name, + model_runner, + rank_print, + reqs, + bs, + il, + ol, + server_args.device, + ) + if ret is not None: + result_list.append(ret) + + # Write results in jsonlines format on rank 0. + if tp_rank == 0 and bench_args.result_filename: + with open(bench_args.result_filename, "a") as fout: + for result in result_list: + fout.write(json.dumps(result) + "\n") + + +def main(server_args, bench_args): + _set_envs_and_config(server_args) + + if server_args.model_path: + if bench_args.correctness_test: + work_func = correctness_test + else: + work_func = latency_test + else: + raise ValueError( + "Provide --model-path for running the tests or " + "provide --result-filename for plotting the results" + ) + + port_args = PortArgs.init_new(server_args) + + if server_args.tp_size == 1: + work_func(server_args, port_args, bench_args, 0) + else: + workers = [] + for tp_rank in range(server_args.tp_size): + proc = multiprocessing.Process( + target=work_func, + args=( + server_args, + port_args, + bench_args, + tp_rank, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + proc.terminate() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + BenchArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + bench_args = BenchArgs.from_cli_args(args) + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + try: + main(server_args, bench_args) + except Exception as e: + raise e + finally: + kill_child_process() diff --git a/python/sglang/bench_server_latency.py b/python/sglang/bench_one_batch_server.py similarity index 96% rename from python/sglang/bench_server_latency.py rename to python/sglang/bench_one_batch_server.py index f76682c9f..9737b8bd2 100644 --- a/python/sglang/bench_server_latency.py +++ b/python/sglang/bench_one_batch_server.py @@ -1,10 +1,10 @@ """ -Benchmark the latency of serving a single batch with a real server. +Benchmark the latency of running a single batch with a server. + This script launches a server and uses the HTTP interface. -It accepts arguments similar to those of launch_server.py. +It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). Usage: - python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8 python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index 4db1f82fc..198c2e4d9 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -15,24 +15,21 @@ PACKAGE_LIST = [ "flashinfer", "triton", "transformers", - "requests", - "tqdm", + "torchao", "numpy", "aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", - "packaging", - "PIL", "psutil", "pydantic", + "multipart", + "zmq", "uvicorn", "uvloop", - "zmq", "vllm", "outlines", - "multipart", "openai", "tiktoken", "anthropic", diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 835158288..37153e8cc 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -30,10 +30,10 @@ device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,)) tensor_parallel(model, device_mesh) ``` -An end-to-end example can be found in `python/sglang/bench_latency.py`. +An end-to-end example can be found in `python/sglang/bench_one_batch.py`. You can run it with the following command: ```bash -$ python3 -m sglang.bench_latency --correct \ +$ python3 -m sglang.bench_one_batch --correct \ --model meta-llama/Meta-Llama-3-8B \ --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \ --tensor-parallel-size 2 \ diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 07f666e30..3f991b39b 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -579,11 +579,11 @@ def run_bench_serving( return res -def run_bench_latency(model, other_args): +def run_bench_one_batch(model, other_args): command = [ "python3", "-m", - "sglang.bench_latency", + "sglang.bench_one_batch", "--model-path", model, "--batch-size", diff --git a/test/srt/test_bench_latency.py b/test/srt/test_bench_one_batch.py similarity index 70% rename from test/srt/test_bench_latency.py rename to test/srt/test_bench_one_batch.py index e54f49088..c1bc98e8e 100644 --- a/test/srt/test_bench_latency.py +++ b/test/srt/test_bench_one_batch.py @@ -4,19 +4,19 @@ from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, is_in_ci, - run_bench_latency, + run_bench_one_batch, ) -class TestBenchLatency(unittest.TestCase): +class TestBenchOneBatch(unittest.TestCase): def test_default(self): - output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, []) + output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, []) if is_in_ci(): self.assertGreater(output_throughput, 135) def test_moe_default(self): - output_throughput = run_bench_latency( + output_throughput = run_bench_one_batch( DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] ) diff --git a/test/srt/test_torch_tp.py b/test/srt/test_torch_tp.py index db6b71b15..e17b212f6 100644 --- a/test/srt/test_torch_tp.py +++ b/test/srt/test_torch_tp.py @@ -1,11 +1,11 @@ import unittest -from sglang.test.test_utils import is_in_ci, run_bench_latency +from sglang.test.test_utils import is_in_ci, run_bench_one_batch class TestTorchTP(unittest.TestCase): def test_torch_native_llama(self): - output_throughput = run_bench_latency( + output_throughput = run_bench_one_batch( "meta-llama/Meta-Llama-3-8B", [ "--tp", diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py index 5a8187de5..a4d19bec0 100644 --- a/test/srt/test_triton_attention_backend.py +++ b/test/srt/test_triton_attention_backend.py @@ -14,13 +14,13 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, is_in_ci, popen_launch_server, - run_bench_latency, + run_bench_one_batch, ) class TestTritonAttnBackend(unittest.TestCase): def test_latency(self): - output_throughput = run_bench_latency( + output_throughput = run_bench_one_batch( DEFAULT_MODEL_NAME_FOR_TEST, [ "--attention-backend",