Allow passing extra request body to bench_offline_throughput.py (#2085)
This commit is contained in:
@@ -23,7 +23,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -55,6 +55,7 @@ class BenchArgs:
|
|||||||
gen_question_len: int = 128
|
gen_question_len: int = 128
|
||||||
gen_output_len: int = 256
|
gen_output_len: int = 256
|
||||||
disable_ignore_eos: bool = False
|
disable_ignore_eos: bool = False
|
||||||
|
extra_request_body: Optional[str] = None
|
||||||
seed: int = 1
|
seed: int = 1
|
||||||
do_not_exit: bool = False
|
do_not_exit: bool = False
|
||||||
|
|
||||||
@@ -143,6 +144,13 @@ class BenchArgs:
|
|||||||
default=BenchArgs.disable_ignore_eos,
|
default=BenchArgs.disable_ignore_eos,
|
||||||
help="Disable ignore EOS token",
|
help="Disable ignore EOS token",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--extra-request-body",
|
||||||
|
metavar='{"key1": "value1", "key2": "value2"}',
|
||||||
|
type=str,
|
||||||
|
help="Append given JSON object to the request payload. You can use this to specify "
|
||||||
|
"additional generate params like sampling params.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--do-not-exit",
|
"--do-not-exit",
|
||||||
@@ -161,6 +169,7 @@ def throughput_test_once(
|
|||||||
backend,
|
backend,
|
||||||
reqs: List[Tuple[str, int, int]],
|
reqs: List[Tuple[str, int, int]],
|
||||||
ignore_eos: bool,
|
ignore_eos: bool,
|
||||||
|
extra_request_body: Dict,
|
||||||
):
|
):
|
||||||
measurement_results = {
|
measurement_results = {
|
||||||
"backend": backend_name,
|
"backend": backend_name,
|
||||||
@@ -180,6 +189,7 @@ def throughput_test_once(
|
|||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": r[2],
|
"max_new_tokens": r[2],
|
||||||
"ignore_eos": ignore_eos,
|
"ignore_eos": ignore_eos,
|
||||||
|
**extra_request_body,
|
||||||
}
|
}
|
||||||
for r in reqs
|
for r in reqs
|
||||||
]
|
]
|
||||||
@@ -233,6 +243,11 @@ def throughput_test(
|
|||||||
random.seed(bench_args.seed)
|
random.seed(bench_args.seed)
|
||||||
np.random.seed(bench_args.seed)
|
np.random.seed(bench_args.seed)
|
||||||
|
|
||||||
|
# Parse args
|
||||||
|
extra_request_body = {}
|
||||||
|
if bench_args.extra_request_body:
|
||||||
|
extra_request_body = json.loads(bench_args.extra_request_body)
|
||||||
|
|
||||||
# Read dataset
|
# Read dataset
|
||||||
input_requests = get_dataset(bench_args, tokenizer)
|
input_requests = get_dataset(bench_args, tokenizer)
|
||||||
|
|
||||||
@@ -252,6 +267,7 @@ def throughput_test(
|
|||||||
backend=backend,
|
backend=backend,
|
||||||
reqs=warmup_requests,
|
reqs=warmup_requests,
|
||||||
ignore_eos=not bench_args.disable_ignore_eos,
|
ignore_eos=not bench_args.disable_ignore_eos,
|
||||||
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("\nBenchmark...")
|
logging.info("\nBenchmark...")
|
||||||
@@ -260,6 +276,7 @@ def throughput_test(
|
|||||||
backend=backend,
|
backend=backend,
|
||||||
reqs=input_requests,
|
reqs=input_requests,
|
||||||
ignore_eos=not bench_args.disable_ignore_eos,
|
ignore_eos=not bench_args.disable_ignore_eos,
|
||||||
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
if bench_args.result_filename:
|
if bench_args.result_filename:
|
||||||
|
|||||||
Reference in New Issue
Block a user