Allow passing extra request body to bench_offline_throughput.py (#2085)
This commit is contained in:
@@ -23,7 +23,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -55,6 +55,7 @@ class BenchArgs:
|
||||
gen_question_len: int = 128
|
||||
gen_output_len: int = 256
|
||||
disable_ignore_eos: bool = False
|
||||
extra_request_body: Optional[str] = None
|
||||
seed: int = 1
|
||||
do_not_exit: bool = False
|
||||
|
||||
@@ -143,6 +144,13 @@ class BenchArgs:
|
||||
default=BenchArgs.disable_ignore_eos,
|
||||
help="Disable ignore EOS token",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-request-body",
|
||||
metavar='{"key1": "value1", "key2": "value2"}',
|
||||
type=str,
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--do-not-exit",
|
||||
@@ -161,6 +169,7 @@ def throughput_test_once(
|
||||
backend,
|
||||
reqs: List[Tuple[str, int, int]],
|
||||
ignore_eos: bool,
|
||||
extra_request_body: Dict,
|
||||
):
|
||||
measurement_results = {
|
||||
"backend": backend_name,
|
||||
@@ -180,6 +189,7 @@ def throughput_test_once(
|
||||
"temperature": 0,
|
||||
"max_new_tokens": r[2],
|
||||
"ignore_eos": ignore_eos,
|
||||
**extra_request_body,
|
||||
}
|
||||
for r in reqs
|
||||
]
|
||||
@@ -233,6 +243,11 @@ def throughput_test(
|
||||
random.seed(bench_args.seed)
|
||||
np.random.seed(bench_args.seed)
|
||||
|
||||
# Parse args
|
||||
extra_request_body = {}
|
||||
if bench_args.extra_request_body:
|
||||
extra_request_body = json.loads(args.extra_request_body)
|
||||
|
||||
# Read dataset
|
||||
input_requests = get_dataset(bench_args, tokenizer)
|
||||
|
||||
@@ -252,6 +267,7 @@ def throughput_test(
|
||||
backend=backend,
|
||||
reqs=warmup_requests,
|
||||
ignore_eos=not bench_args.disable_ignore_eos,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
|
||||
logging.info("\nBenchmark...")
|
||||
@@ -260,6 +276,7 @@ def throughput_test(
|
||||
backend=backend,
|
||||
reqs=input_requests,
|
||||
ignore_eos=not bench_args.disable_ignore_eos,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
|
||||
if bench_args.result_filename:
|
||||
|
||||
Reference in New Issue
Block a user