diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index f557cae79..7e4ae3c43 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -1,5 +1,7 @@
 """
-Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+Benchmark the latency of running a single static batch.
+This script does not launch a server and uses the low-level APIs.
+It accepts arguments similar to those of launch_server.py.
 
 # Usage (latency test)
 ## with dummy weights:
diff --git a/python/sglang/bench_server_latency.py b/python/sglang/bench_server_latency.py
new file mode 100644
index 000000000..45852daf7
--- /dev/null
+++ b/python/sglang/bench_server_latency.py
@@ -0,0 +1,185 @@
+"""
+Benchmark the latency of serving a single batch with a real server.
+This script launches a server and uses the HTTP interface.
+It accepts arguments similar to those of launch_server.py.
+
+Usage:
+
+python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import multiprocessing
+import os
+import time
+from typing import Tuple
+
+import numpy as np
+import requests
+
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int, ...] = (1,)
+    input_len: Tuple[int, ...] = (1024,)
+    output_len: Tuple[int, ...] = (16,)
+    result_filename: str = "result.jsonl"
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # Use the default value's type to cast the args into the correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
+
+
+def launch_server_process(server_args: ServerArgs):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = 600
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc, base_url
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_one_case(
+    url: str,
+    batch_size: int,
+    input_len: int,
+    output_len: int,
+    run_name: str,
+    result_filename: str,
+):
+    input_ids = [
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
+        for _ in range(batch_size)
+    ]
+
+    tic = time.time()
+    response = requests.post(
+        url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+        },
+    )
+    latency = time.time() - tic
+
+    _ = response.json()
+    output_throughput = batch_size * output_len / latency
+    overall_throughput = batch_size * (input_len + output_len) / latency
+
+    print(f"batch size: {batch_size}")
+    print(f"latency: {latency:.2f} s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+
+    if result_filename:
+        with open(result_filename, "a") as fout:
+            res = {
+                "run_name": run_name,
+                "batch_size": batch_size,
+                "input_len": input_len,
+                "output_len": output_len,
+                "latency": round(latency, 4),
+                "output_throughput": round(output_throughput, 2),
+                "overall_throughput": round(overall_throughput, 2),
+            }
+            fout.write(json.dumps(res) + "\n")
+
+
+def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
+    proc, base_url = launch_server_process(server_args)
+
+    # warmup
+    run_one_case(
+        base_url,
+        batch_size=16,
+        input_len=1024,
+        output_len=16,
+        run_name="",
+        result_filename="",
+    )
+
+    # benchmark
+    try:
+        for bs, il, ol in itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        ):
+            run_one_case(
+                base_url,
+                bs,
+                il,
+                ol,
+                bench_args.run_name,
+                bench_args.result_filename,
+            )
+    finally:
+        kill_child_process(proc.pid)
+
+    print(f"\nResults are saved to {bench_args.result_filename}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code needs to be updated"
+    parser._actions[1].required = False
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    run_benchmark(server_args, bench_args)
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index d51aee4ec..2f68b39bb 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -2,7 +2,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
 
 """
-Benchmark online serving.
+Benchmark online serving with dynamic requests.
 
 Usage:
 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7eef08b71..ee4fedabb 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -26,17 +26,6 @@ from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)
 
-class LoRAPathAction(argparse.Action):
-    def __call__(self, parser, namespace, values, option_string=None):
-        setattr(namespace, self.dest, {})
-        for lora_path in values:
-            if "=" in lora_path:
-                name, path = lora_path.split("=", 1)
-                getattr(namespace, self.dest)[name] = path
-            else:
-                getattr(namespace, self.dest)[lora_path] = lora_path
-
-
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -619,3 +608,14 @@ class PortArgs:
     controller_port: int
     detokenizer_port: int
     nccl_ports: List[int]
+
+
+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py
index 18ae2d8c3..960f9c2e0 100644
--- a/python/sglang/test/few_shot_gsm8k.py
+++ b/python/sglang/test/few_shot_gsm8k.py
@@ -44,7 +44,7 @@ def get_answer_value(answer_str):
     return INVALID
 
 
-def main(args):
+def run_eval(args):
     # Select backend
     set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
 
@@ -119,6 +119,12 @@ def main(args):
     # Dump results
     dump_state_text("tmp_output_gsm8k.txt", states)
 
+    return {
+        "accuracy": acc,
+        "latency": latency,
+        "output_throughput": output_throughput,
+    }
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -129,4 +135,4 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
     args = parser.parse_args()
-    main(args)
+    run_eval(args)
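
A note on BenchArgs.from_cli_args above: argparse returns a plain Python list for nargs="+" options, and the classmethod re-casts each parsed value to the type of the corresponding dataclass field's default, so tuple-typed fields come back as tuples. A minimal, self-contained sketch of the same idiom (DemoArgs is an illustrative name, not part of the patch):

import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class DemoArgs:
    batch_size: Tuple[int, ...] = (1,)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # type(f.default) is tuple here, so argparse's list becomes a tuple.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: typ(getattr(args, name)) for name, typ in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=DemoArgs.batch_size)
args = parser.parse_args(["--batch-size", "1", "16", "64"])
print(DemoArgs.from_cli_args(args))  # DemoArgs(batch_size=(1, 16, 64))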
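
Each run_one_case call appends one JSON object per line to --result-filename, so sweeps under different --run-name values accumulate in the same JSONL file. A short sketch of reading the records back for comparison, assuming the default result.jsonl:

import json

with open("result.jsonl") as fin:
    for line in fin:
        res = json.loads(line)
        print(res["run_name"], res["batch_size"], res["input_len"],
              res["output_len"], res["latency"], res["output_throughput"])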
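
The main -> run_eval rename in few_shot_gsm8k.py, together with the returned metrics dict, makes the eval callable from other code (e.g., CI tests) rather than only as a script. A hedged sketch of such a call; apart from host and port, the Namespace fields below (num_shots, data_path, num_questions, parallel) are assumptions about the script's remaining CLI arguments and may differ:

import argparse

from sglang.test.few_shot_gsm8k import run_eval

# Only host and port are visible in this patch; the other fields are assumed.
args = argparse.Namespace(
    num_shots=5,
    data_path="test.jsonl",
    num_questions=200,
    parallel=128,
    host="http://127.0.0.1",
    port=30000,
)
metrics = run_eval(args)
print(metrics)  # {"accuracy": ..., "latency": ..., "output_throughput": ...}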