Support accurate length control for bench serving (#6594)
This commit is contained in:
@@ -340,7 +340,7 @@ async def async_request_sglang_generate(
|
|||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
payload = {
|
payload = {
|
||||||
"text": prompt,
|
("text" if isinstance(prompt, str) else "input_ids"): prompt,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"max_new_tokens": request_func_input.output_len,
|
"max_new_tokens": request_func_input.output_len,
|
||||||
@@ -494,6 +494,7 @@ def get_tokenizer(
|
|||||||
|
|
||||||
def get_dataset(args, tokenizer):
|
def get_dataset(args, tokenizer):
|
||||||
if args.dataset_name == "sharegpt":
|
if args.dataset_name == "sharegpt":
|
||||||
|
assert not args.tokenize_prompt
|
||||||
input_requests = sample_sharegpt_requests(
|
input_requests = sample_sharegpt_requests(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@@ -512,8 +513,10 @@ def get_dataset(args, tokenizer):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
random_sample=args.dataset_name == "random",
|
random_sample=args.dataset_name == "random",
|
||||||
|
return_text=not args.tokenize_prompt,
|
||||||
)
|
)
|
||||||
elif args.dataset_name == "generated-shared-prefix":
|
elif args.dataset_name == "generated-shared-prefix":
|
||||||
|
assert not args.tokenize_prompt
|
||||||
input_requests = sample_generated_shared_prefix_requests(
|
input_requests = sample_generated_shared_prefix_requests(
|
||||||
num_groups=args.gsp_num_groups,
|
num_groups=args.gsp_num_groups,
|
||||||
prompts_per_group=args.gsp_prompts_per_group,
|
prompts_per_group=args.gsp_prompts_per_group,
|
||||||
@@ -524,6 +527,7 @@ def get_dataset(args, tokenizer):
|
|||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
elif args.dataset_name == "mmmu":
|
elif args.dataset_name == "mmmu":
|
||||||
|
assert not args.tokenize_prompt
|
||||||
input_requests = sample_mmmu_requests(
|
input_requests = sample_mmmu_requests(
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
if not hasattr(args, "output_details"):
|
if not hasattr(args, "output_details"):
|
||||||
args.output_details = False
|
args.output_details = False
|
||||||
|
|
||||||
|
if not hasattr(args, "tokenize_prompt"):
|
||||||
|
args.tokenize_prompt = False
|
||||||
|
|
||||||
print(f"benchmark_args={args}")
|
print(f"benchmark_args={args}")
|
||||||
|
|
||||||
# Set global environments
|
# Set global environments
|
||||||
@@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
if args.extra_request_body:
|
if args.extra_request_body:
|
||||||
extra_request_body = json.loads(args.extra_request_body)
|
extra_request_body = json.loads(args.extra_request_body)
|
||||||
|
|
||||||
|
if args.tokenize_prompt:
|
||||||
|
assert (
|
||||||
|
args.backend == "sglang"
|
||||||
|
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
|
||||||
|
|
||||||
# Set url
|
# Set url
|
||||||
if args.port is None:
|
if args.port is None:
|
||||||
args.port = {
|
args.port = {
|
||||||
@@ -1812,6 +1824,11 @@ if __name__ == "__main__":
|
|||||||
default=1,
|
default=1,
|
||||||
help="Number of warmup requests to run before the benchmark",
|
help="Number of warmup requests to run before the benchmark",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenize-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
|
||||||
|
)
|
||||||
|
|
||||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user