diff --git a/scripts/playground/bench_speculative.py b/scripts/playground/bench_speculative.py
index f16ff4460..c89e99242 100644
--- a/scripts/playground/bench_speculative.py
+++ b/scripts/playground/bench_speculative.py
@@ -16,8 +16,14 @@ from types import SimpleNamespace
 
 import numpy as np
 import requests
+from transformers import AutoTokenizer
 
-from sglang.bench_serving import DatasetRow, benchmark, set_global_args
+from sglang.bench_serving import (
+    DatasetRow,
+    benchmark,
+    sample_mmmu_requests,
+    set_global_args,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -48,20 +54,33 @@ class FakeTokenizer:
         return []
 
 
-def send_one_batch(base_url, num_prompts, batch_size):
-    padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
-        :num_prompts
-    ]
-
+def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
+    if is_multimodal:
+        input_requests = sample_mmmu_requests(
+            num_prompts,
+            tokenizer,
+            512,
+            apply_chat_template=False,
+        )
+        backend = "sglang-oai-chat"
+        api_url = f"{base_url}/v1/chat/completions"
+    else:
+        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
+            :num_prompts
+        ]
+        input_requests: List[DatasetRow] = [
+            DatasetRow(p, 0, 512) for p in padded_prompts
+        ]
+        backend = "sglang"
+        api_url = f"{base_url}/generate"
 
     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
         disable_ignore_eos=False,
         disable_stream=False,
         return_logprob=False,
-        backend="sglang",
+        backend=backend,
         dataset_name="custom",
         num_prompts=None,
         sharegpt_output_len=None,
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
         output_details=False,
     )
     set_global_args(args)
-    tokenizer = FakeTokenizer()
 
     # Run benchmark
     results = asyncio.run(
         benchmark(
-            backend="sglang",
-            api_url=f"{base_url}/generate",
+            backend=backend,
+            api_url=api_url,
             base_url=base_url,
             model_id="default",
             tokenizer=tokenizer,
@@ -143,8 +161,6 @@ def main(args, server_args):
         other_args = []
     else:
         other_args = [
-            "--speculative-algorithm",
-            "EAGLE",
             "--speculative-num-steps",
             steps,
             "--speculative-eagle-topk",
@@ -157,6 +173,8 @@ def main(args, server_args):
             [
                 "--speculative-draft-model-path",
                 server_args.speculative_draft_model_path,
+                "--speculative-algorithm",
+                server_args.speculative_algorithm,
             ]
         )
 
@@ -207,13 +225,23 @@ def main(args, server_args):
         },
     )
 
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
     try:
         # Warmup
-        send_one_batch(base_url, batch_size, batch_size)
+        send_one_batch(
+            base_url, batch_size, batch_size, tokenizer, args.is_multimodal
+        )
 
         # Benchmark
         acc_length, step_time, speed, completion_tokens = send_one_batch(
-            base_url, max(args.num_prompts, batch_size), batch_size
+            base_url,
+            max(args.num_prompts, batch_size),
+            batch_size,
+            tokenizer,
+            args.is_multimodal,
         )
     finally:
         kill_process_tree(process.pid)
@@ -273,6 +301,7 @@ if __name__ == "__main__":
     parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int)
     parser.add_argument("--output", type=str, default="output.jsonl")
+    parser.add_argument("--is-multimodal", action="store_true", default=False)
     args = parser.parse_args()
     server_args: ServerArgs = ServerArgs.from_cli_args(args)
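
A quick sanity sketch of the endpoint dispatch the patched send_one_batch performs (pick_backend is a hypothetical helper invented here for illustration; only the backend names and endpoint paths come from the hunks above):

    # Illustrative sketch only; mirrors the dispatch inside the patched
    # send_one_batch. Multimodal runs go through the OpenAI-compatible chat
    # endpoint, since the MMMU samples are formatted as chat requests;
    # text-only prompts keep the original /generate path.
    def pick_backend(base_url: str, is_multimodal: bool) -> tuple[str, str]:
        if is_multimodal:
            return "sglang-oai-chat", f"{base_url}/v1/chat/completions"
        return "sglang", f"{base_url}/generate"

    assert pick_backend("http://localhost:30000", True) == (
        "sglang-oai-chat",
        "http://localhost:30000/v1/chat/completions",
    )
    assert pick_backend("http://localhost:30000", False) == (
        "sglang",
        "http://localhost:30000/generate",
    )

The same split likely explains why the patch swaps the FakeTokenizer stub for a real AutoTokenizer loaded from args.model_path: sample_mmmu_requests takes a tokenizer argument and needs one that can actually tokenize the sampled prompts.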