Rename sglang.bench_latency to sglang.bench_one_batch (#2118)
.github/workflows/pr-test.yml (4 changes)
@@ -118,7 +118,7 @@ jobs:
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_latency.TestBenchLatency.test_default
+          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default

       - name: Benchmark online latency
         timeout-minutes: 10
@@ -194,7 +194,7 @@ jobs:
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_latency.TestBenchLatency.test_moe_default
+          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default

   accuracy-test-1-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -1,11 +1,16 @@
 # Benchmark and Profiling

 ## Benchmark
-- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`.
+- Benchmark the latency of running a single static batch without a server. The arguments are the same as for `launch_server.py`.
+  Note that this is a simplified test script without a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this simplified script does not.
 ```
-python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32
 ```
-- Benchmark online serving. Launch a server first and run the following command.
+- Benchmark offline processing. This script will start an offline engine and run the benchmark.
+```
+python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
+```
+- Benchmark online serving. Please use `sglang.launch_server` to launch a server first and run the following command.
 ```
 python3 -m sglang.bench_serving --backend sglang --num-prompts 10
 ```
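`sglang.bench_one_batch` can also append one JSON object per measured point to a results file (its `--result-filename` argument, default `result.jsonl`, visible in the new file later in this diff). A minimal reader for that output, as a sketch that is not part of the repo, assuming the key names built in `latency_test_run_once`:

```python
import json

# Read the JSONL that bench_one_batch appends and print one summary line per run.
with open("result.jsonl") as f:
    for line in f:
        r = json.loads(line)
        print(r["run_name"], r["batch_size"], f"{r['prefill_throughput']:.1f} token/s")
```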
@@ -23,7 +28,7 @@ apt update
 apt install nsight-systems-cli
 ```

-1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
+1. To profile a single batch, use `nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`

 2. To profile a server, e.g.
@@ -33,7 +38,7 @@ apt install nsight-systems-cli
 nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache

 # client
-python3 -m sglang.bench_serving --backend sglang --num-prompts 6000 --dataset-name random --random-input 4096 --random-output 2048
+python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512
 ```

 3. Use NVTX, e.g.
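The hunk is cut off right after item 3; as a hedged illustration (not the doc's own text), an NVTX annotation in PyTorch typically looks like this, giving the region a named range in the nsys timeline:

```python
import torch

# Wrap a region of interest so it shows up as a named NVTX range in nsys.
torch.cuda.nvtx.range_push("decode_step")
# ... run one decode forward pass here ...
torch.cuda.nvtx.range_pop()
```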
@@ -59,7 +59,7 @@ For interactive debugging, you can compare the outputs of huggingface/transformers
 The following two commands should give the same text output and very similar prefill logits.

 - Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]`
-- Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]`
+- Get the SGLang output by `python3 -m sglang.bench_one_batch --correct --model [new model]`

 #### Add the model to the test suite
 To make sure the new model is well maintained in the future, it is better to add it to the test suite.
@@ -59,7 +59,7 @@ drun -p 30000:30000 \
   python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000

 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default
-drun v0.3.5.post2-rocm620 python3 -m sglang.bench_latency --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
+drun v0.3.5.post2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
 ```

 ## Method 4: Using docker compose
@@ -16,10 +16,13 @@ classifiers = [
 dependencies = ["requests", "tqdm", "numpy", "IPython"]

 [project.optional-dependencies]
-runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
-  "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
-  "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
-  "outlines>=0.0.44,<0.1.0", "modelscope"]
+runtime_common = ["aiohttp", "decord", "fastapi",
+  "hf_transfer", "huggingface_hub", "interegular",
+  "orjson", "outlines>=0.0.44,<0.1.0",
+  "packaging", "pillow", "prometheus-client>=0.20.0",
+  "psutil", "pydantic", "python-multipart",
+  "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
+  "modelscope"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
@@ -4,9 +4,11 @@
 - `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
 - `test`: The test utilities.
 - `api.py`: The public APIs.
-- `bench_latency.py`: Benchmark the latency of running a single static batch.
-- `bench_server_latency.py`: Benchmark the latency of serving a single batch with a real server.
+- `bench_offline_throughput.py`: Benchmark the throughput in the offline mode.
+- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
+- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
 - `bench_serving.py`: Benchmark online serving with dynamic requests.
 - `check_env.py`: Check the environment variables.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point for launching the local server.
 - `utils.py`: Common utilities.
@@ -1,553 +1 @@
"""
Benchmark the latency of running a single static batch.
This script does not launch a server and uses the low-level APIs.
It accepts arguments similar to those of launch_server.py.

# Usage (latency test)
## with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl
## do some changes, and store the results under a different run_name:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --result-filename out.jsonl --run-name after
## plot the results in series of lines:
python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results"

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
        [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
        [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import sqlite3
import time
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_child_process,
    suppress_other_loggers,
)


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "before"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    result_filename: str = ""
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    # Plotting args
    graph_sql: str = (
        "select run_name, batch_size, prefill_throughput from results where run_name='before'"
    )
    graph_filename: str = "out.png"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        # graphing
        parser.add_argument("--graph-sql", type=str, default=BenchArgs.graph_sql)
        parser.add_argument(
            "--graph-filename", type=str, default=BenchArgs.graph_filename
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, port_args, tp_rank):
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    model_config = ModelConfig(
        server_args.model_path,
        trust_remote_code=server_args.trust_remote_code,
        context_length=server_args.context_length,
        model_override_args=server_args.json_model_override_args,
    )
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs_for_correctness_test(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    return reqs


def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return reqs


@torch.no_grad
def extend(reqs, model_runner):
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
        model_config=model_runner.model_config,
    )
    batch.prepare_for_extend()
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch


@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits


def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")


def synchronize(device):
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()


def latency_test_run_once(
    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
):
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    # Prefill
    synchronize(device)
    tic = time.time()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.time() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        tic = time.time()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.time() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5:
            rank_print(
                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["total_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        8,  # shorter decoding to speed up the warmup
        server_args.device,
    )
    rank_print("Benchmark ...")

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        import jsonlines

        with jsonlines.open(bench_args.result_filename, "a") as f:
            f.write_all(result_list)


def plot_latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    assert tp_rank == 0

    # read the jsonl file and put in sqlite
    df = pd.read_json(bench_args.result_filename, lines=True)
    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # get the columns and their types
    column_names = list(df.iloc[0].keys())
    type_dict = {
        str: "TEXT",
        np.int64: "INTEGER",
        np.float64: "FLOAT",
    }
    column_types = [type_dict[type(i)] for i in list(df.iloc[0])]

    # create the table
    cur.execute(
        f"""
        CREATE TABLE IF NOT EXISTS results (
            {", ".join([f"{name} {type}" for name, type in zip(column_names, column_types)])}
        )
        """
    )
    conn.commit()

    # write the results to DB
    df.to_sql("results", conn, if_exists="replace", index=False)
    conn.commit()

    # read it back using sql
    df = pd.read_sql_query(bench_args.graph_sql, conn)
    conn.close()

    # plot it and save to a file
    import matplotlib.pyplot as plt

    assert (
        len(df.columns) == 3
    ), f"The sql should have fetched <series, x, y> columns, not {df.columns}"
    for label in df[df.columns[0]].unique():
        q = f"{df.columns[0]}=='{label}'"
        series = df.query(q)
        plt.plot(series[df.columns[1]], series[df.columns[2]], label=q, marker="o")
    plt.xlabel(df.columns[1])
    plt.ylabel(df.columns[2])
    plt.legend()
    plt.savefig(bench_args.graph_filename, dpi=300)

    # if in kitty, just dump it to the terminal
    if os.environ["TERM"] == "xterm-kitty":
        os.system(
            f"kitty icat --use-window-size 1,1,600,600 {bench_args.graph_filename}"
        )


def main(server_args, bench_args):
    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    elif os.path.isfile(bench_args.result_filename):
        assert bench_args.graph_filename, "please provide a filename for the graph"
        work_func = plot_latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    except Exception as e:
        raise e
    finally:
        kill_child_process()

+raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")
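The deleted module is replaced by a one-line tombstone (the `+` line above), so any stale `python -m sglang.bench_latency` invocation fails loudly with a message pointing at the new name. For comparison, a softer deprecation shim (hypothetical; not what this commit ships) would warn and re-export instead:

```python
# Hypothetical alternative stub -- the commit raises instead of forwarding.
import warnings

warnings.warn(
    "sglang.bench_latency has been renamed to sglang.bench_one_batch",
    DeprecationWarning,
    stacklevel=2,
)
from sglang.bench_one_batch import *  # noqa: E402,F401,F403
```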
@@ -1,20 +1,13 @@
 """
-Benchmark the throughput of using the offline LLM engine.
-This script does not launch a server.
+Benchmark the throughput in the offline mode.
 It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).

 # Usage
 ## Sharegpt dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10

 ## Random dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
-
-## Shared prefix dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
-
-## Sharegpt dataset on runtime backend
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
 """

 import argparse
python/sglang/bench_one_batch.py (new file, 474 lines)
@@ -0,0 +1,474 @@
"""
Benchmark the latency of running a single static batch without a server.

This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run

# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
        [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
        [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
       device='cuda:0')

prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
        [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
        [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
       device='cuda:0')

========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.


========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the

========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""

import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import time
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_child_process,
    suppress_other_loggers,
)


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def load_model(server_args, port_args, tp_rank):
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    model_config = ModelConfig(
        server_args.model_path,
        trust_remote_code=server_args.trust_remote_code,
        context_length=server_args.context_length,
        model_override_args=server_args.json_model_override_args,
    )
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs_for_correctness_test(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    return reqs


def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)

    return reqs


@torch.no_grad
def extend(reqs, model_runner):
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
        model_config=model_runner.model_config,
    )
    batch.prepare_for_extend()
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch


@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits


def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")


def synchronize(device):
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()


def latency_test_run_once(
    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
):
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    # Prefill
    synchronize(device)
    tic = time.time()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.time() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        tic = time.time()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.time() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5:
            rank_print(
                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results


def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        8,  # shorter decoding to speed up the warmup
        server_args.device,
    )
    rank_print("Benchmark ...")

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")


def main(server_args, bench_args):
    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    except Exception as e:
        raise e
    finally:
        kill_child_process()
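Both the old and the new script use the same small pattern for their CLI: dataclass field defaults double as argparse defaults, and `from_cli_args` casts the parsed values back with each field's default type. A self-contained sketch of that pattern (the `DemoArgs` name is hypothetical):

```python
import argparse
import dataclasses


@dataclasses.dataclass
class DemoArgs:
    run_name: str = "default"
    cut_len: int = 4

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "DemoArgs":
        # Use each field's default-value type to cast the parsed values back.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: typ(getattr(args, name)) for name, typ in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--run-name", type=str, default=DemoArgs.run_name)
parser.add_argument("--cut-len", type=int, default=DemoArgs.cut_len)
print(DemoArgs.from_cli_args(parser.parse_args([])))  # DemoArgs(run_name='default', cut_len=4)
```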
@@ -1,10 +1,10 @@
 """
-Benchmark the latency of serving a single batch with a real server.
+Benchmark the latency of running a single batch with a server.

 This script launches a server and uses the HTTP interface.
-It accepts arguments similar to those of launch_server.py.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8

 python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
    "flashinfer",
    "triton",
    "transformers",
    "requests",
    "tqdm",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "packaging",
    "PIL",
    "psutil",
    "pydantic",
    "multipart",
    "zmq",
    "uvicorn",
    "uvloop",
    "zmq",
    "vllm",
    "outlines",
    "multipart",
    "openai",
    "tiktoken",
    "anthropic",
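`check_env.py` walks this PACKAGE_LIST and reports each package's installed version. A hedged sketch of such a loop (not the file's actual code, and using distribution names rather than import names):

```python
import importlib.metadata

PACKAGE_LIST = ["torch", "numpy", "requests"]  # abbreviated for this sketch

for name in PACKAGE_LIST:
    try:
        # Look up the installed distribution's version string.
        print(f"{name}: {importlib.metadata.version(name)}")
    except importlib.metadata.PackageNotFoundError:
        print(f"{name}: not installed")
```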
@@ -30,10 +30,10 @@ device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
 tensor_parallel(model, device_mesh)
 ```

-An end-to-end example can be found in `python/sglang/bench_latency.py`.
+An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
 You can run it with the following command:
 ```bash
-$ python3 -m sglang.bench_latency --correct \
+$ python3 -m sglang.bench_one_batch --correct \
   --model meta-llama/Meta-Llama-3-8B \
   --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
   --tensor-parallel-size 2 \
@@ -579,11 +579,11 @@ def run_bench_serving(
     return res


-def run_bench_latency(model, other_args):
+def run_bench_one_batch(model, other_args):
     command = [
         "python3",
         "-m",
-        "sglang.bench_latency",
+        "sglang.bench_one_batch",
         "--model-path",
         model,
         "--batch-size",
@@ -4,19 +4,19 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
     is_in_ci,
-    run_bench_latency,
+    run_bench_one_batch,
 )


-class TestBenchLatency(unittest.TestCase):
+class TestBenchOneBatch(unittest.TestCase):
     def test_default(self):
-        output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])
+        output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, [])

         if is_in_ci():
             self.assertGreater(output_throughput, 135)

     def test_moe_default(self):
-        output_throughput = run_bench_latency(
+        output_throughput = run_bench_one_batch(
             DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
         )
@@ -1,11 +1,11 @@
 import unittest

-from sglang.test.test_utils import is_in_ci, run_bench_latency
+from sglang.test.test_utils import is_in_ci, run_bench_one_batch


 class TestTorchTP(unittest.TestCase):
     def test_torch_native_llama(self):
-        output_throughput = run_bench_latency(
+        output_throughput = run_bench_one_batch(
             "meta-llama/Meta-Llama-3-8B",
             [
                 "--tp",
@@ -14,13 +14,13 @@ from sglang.test.test_utils import (
     DEFAULT_URL_FOR_TEST,
     is_in_ci,
     popen_launch_server,
-    run_bench_latency,
+    run_bench_one_batch,
 )


 class TestTritonAttnBackend(unittest.TestCase):
     def test_latency(self):
-        output_throughput = run_bench_latency(
+        output_throughput = run_bench_one_batch(
             DEFAULT_MODEL_NAME_FOR_TEST,
             [
                 "--attention-backend",