diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index a3793f921..7717c16f0 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -23,7 +23,7 @@ import json import logging import random import time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np @@ -55,6 +55,7 @@ class BenchArgs: gen_question_len: int = 128 gen_output_len: int = 256 disable_ignore_eos: bool = False + extra_request_body: Optional[str] = None seed: int = 1 do_not_exit: bool = False @@ -143,6 +144,13 @@ class BenchArgs: default=BenchArgs.disable_ignore_eos, help="Disable ignore EOS token", ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--do-not-exit", @@ -161,6 +169,7 @@ def throughput_test_once( backend, reqs: List[Tuple[str, int, int]], ignore_eos: bool, + extra_request_body: Dict, ): measurement_results = { "backend": backend_name, @@ -180,6 +189,7 @@ def throughput_test_once( "temperature": 0, "max_new_tokens": r[2], "ignore_eos": ignore_eos, + **extra_request_body, } for r in reqs ] @@ -233,6 +243,11 @@ def throughput_test( random.seed(bench_args.seed) np.random.seed(bench_args.seed) + # Parse args + extra_request_body = {} + if bench_args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + # Read dataset input_requests = get_dataset(bench_args, tokenizer) @@ -252,6 +267,7 @@ def throughput_test( backend=backend, reqs=warmup_requests, ignore_eos=not bench_args.disable_ignore_eos, + extra_request_body=extra_request_body, ) logging.info("\nBenchmark...") @@ -260,6 +276,7 @@ def throughput_test( backend=backend, reqs=input_requests, ignore_eos=not bench_args.disable_ignore_eos, + extra_request_body=extra_request_body, ) if bench_args.result_filename: