latency test enhancement - part 1 (#909)
@@ -1,13 +1,13 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output:
+### Reference output (of the correctness test above, can be gpu dependent):
 prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
@@ -31,7 +31,9 @@ import dataclasses
 import logging
 import multiprocessing
 import time
+from typing import Tuple
 
+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@ from sglang.srt.utils import suppress_other_loggers
 
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
         parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # Use the default value's type to cast the args into the correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
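A note on the `from_cli_args` change above: with `nargs="+"`, argparse hands back a `list`, while the dataclass default for `batch_size` is a tuple, so casting each value through `type(attr.default)` restores the declared types. A minimal, self-contained sketch of the mechanism (the trimmed-down `MiniArgs` class is illustrative, not part of the PR):

```python
import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class MiniArgs:
    # Trimmed-down stand-in for BenchArgs, just enough to show the cast.
    batch_size: Tuple[int] = (1,)
    input_len: int = 1024

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # type(attr.default) is `tuple` for batch_size and `int` for input_len,
        # so the list produced by nargs="+" comes back as a tuple.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=MiniArgs.batch_size)
parser.add_argument("--input-len", type=int, default=MiniArgs.input_len)
args = parser.parse_args(["--batch-size", "1", "8", "16"])
print(MiniArgs.from_cli_args(args))
# MiniArgs(batch_size=(1, 8, 16), input_len=1024)
```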
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kingdom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs
 
 
-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
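The renamed helper takes plain `batch_size` and `input_len` instead of the whole `bench_args`, since a latency test only needs dummy token ids of the right shape. Roughly, the array it builds:

```python
import numpy as np

# Every synthetic prompt is `input_len` copies of token id 1; the content
# does not matter for a latency measurement, only the shape does.
batch_size, input_len = 8, 1024
input_ids = np.ones((batch_size, input_len), dtype=np.int32)
print(input_ids.shape)  # (8, 1024)
```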
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
     rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
+    # To make this PR easier to review, for now, only use the first element of the batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )
 
     def clear():
         model_runner.req_to_token_pool.clear()
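Since `batch_size` is now a tuple, a follow-up part can presumably sweep every requested size instead of taking element 0. A hypothetical sketch of that loop (not in this PR; it reuses the names above and assumes `run_once` is reworked to pick up the current batch's requests):

```python
# Hypothetical follow-up: benchmark each requested batch size in turn.
result_list = []
for bs in bench_args.batch_size:
    reqs = prepare_synthetic_inputs_for_latency_test(bs, bench_args.input_len)
    result_list.append(run_once(bench_args.output_len))
    clear()  # release req/token pools between runs
```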
@@ -227,6 +247,11 @@ def latency_test(
 
     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
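With the additions in this hunk and the two below, every `run_once` call returns one flat record. An illustrative record for `batch_size=1`, `input_len=1024`, `output_len=16` (the numbers are made up, but arithmetically consistent):

```python
{
    "batch_size": 1,
    "output_len": 16,
    "prefill_latency": 0.12345,       # s; 1024 tokens / 0.12345 s
    "prefill_throughput": 8294.86,    # token/s
    "avg_decode_latency": 0.00912,    # s per generated token
    "avg_decode_throughput": 109.65,  # token/s
    "total_latency": 0.26937,         # s; prefill + 16 decode steps
    "total_throughput": 3860.86,      # token/s; (1024 + 16) / 0.26937
}
```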
@@ -239,6 +264,8 @@ def latency_test(
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput
 
         # Decode
         for i in range(output_len):
@@ -258,6 +285,8 @@ def latency_test(
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput
 
         throughput = (
             (bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def latency_test(
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results
 
     # Warm up
     run_once(4)
     clear()
 
     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
 
 def main(server_args, bench_args):
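`jsonlines` is append-friendly: `f.write_all(result_list)` emits one JSON object per line, so successive benchmark runs with the same `--result-filename` accumulate rather than overwrite, and reading the file back is symmetric. A small sketch (the file name is just an example):

```python
import jsonlines

# Read every accumulated benchmark record back as a dict.
with jsonlines.open("bench_latency_results.jsonl") as reader:
    for record in reader:
        print(record["batch_size"], record["total_throughput"])
```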