Allow passing extra request body to bench_offline_throughput.py (#2085)

This commit is contained in:
Lianmin Zheng
2024-11-18 14:59:15 -08:00
committed by GitHub
parent 80e2c4a8de
commit 3b44bbeecf

View File

@@ -23,7 +23,7 @@ import json
import logging import logging
import random import random
import time import time
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
@@ -55,6 +55,7 @@ class BenchArgs:
gen_question_len: int = 128 gen_question_len: int = 128
gen_output_len: int = 256 gen_output_len: int = 256
disable_ignore_eos: bool = False disable_ignore_eos: bool = False
extra_request_body: Optional[str] = None
seed: int = 1 seed: int = 1
do_not_exit: bool = False do_not_exit: bool = False
@@ -143,6 +144,13 @@ class BenchArgs:
default=BenchArgs.disable_ignore_eos, default=BenchArgs.disable_ignore_eos,
help="Disable ignore EOS token", help="Disable ignore EOS token",
) )
# Optional JSON object (passed as a string) whose keys are merged into every
# generate request, e.g. to override sampling parameters.
parser.add_argument(
    "--extra-request-body",
    metavar='{"key1": "value1", "key2": "value2"}',
    type=str,
    default=BenchArgs.extra_request_body,
    # Adjacent literals are concatenated verbatim, so the first one must end
    # with a space to keep "specify additional" from running together.
    help="Append given JSON object to the request payload. You can use this to specify "
    "additional generate params like sampling params.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument( parser.add_argument(
"--do-not-exit", "--do-not-exit",
@@ -161,6 +169,7 @@ def throughput_test_once(
backend, backend,
reqs: List[Tuple[str, int, int]], reqs: List[Tuple[str, int, int]],
ignore_eos: bool, ignore_eos: bool,
extra_request_body: Dict,
): ):
measurement_results = { measurement_results = {
"backend": backend_name, "backend": backend_name,
@@ -180,6 +189,7 @@ def throughput_test_once(
"temperature": 0, "temperature": 0,
"max_new_tokens": r[2], "max_new_tokens": r[2],
"ignore_eos": ignore_eos, "ignore_eos": ignore_eos,
**extra_request_body,
} }
for r in reqs for r in reqs
] ]
@@ -233,6 +243,11 @@ def throughput_test(
random.seed(bench_args.seed) random.seed(bench_args.seed)
np.random.seed(bench_args.seed) np.random.seed(bench_args.seed)
# Parse the optional --extra-request-body JSON string into a dict that is
# merged into the payload of every request.
extra_request_body = {}
if bench_args.extra_request_body:
    # Read from bench_args, the same object the guard above checks, instead
    # of relying on a name `args` that is not defined in this scope.
    extra_request_body = json.loads(bench_args.extra_request_body)
    # The value is later expanded with ** into the request dict, which only
    # works for a mapping; reject other JSON values up front with a clear error.
    if not isinstance(extra_request_body, dict):
        raise ValueError(
            "--extra-request-body must be a JSON object, got "
            f"{type(extra_request_body).__name__}"
        )
# Read dataset # Read dataset
input_requests = get_dataset(bench_args, tokenizer) input_requests = get_dataset(bench_args, tokenizer)
@@ -252,6 +267,7 @@ def throughput_test(
backend=backend, backend=backend,
reqs=warmup_requests, reqs=warmup_requests,
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
) )
logging.info("\nBenchmark...") logging.info("\nBenchmark...")
@@ -260,6 +276,7 @@ def throughput_test(
backend=backend, backend=backend,
reqs=input_requests, reqs=input_requests,
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
) )
if bench_args.result_filename: if bench_args.result_filename: