Support VLM (multimodal) models in the speculative decoding benchmark (#10173)
This commit is contained in:
@@ -16,8 +16,14 @@ from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.bench_serving import DatasetRow, benchmark, set_global_args
|
||||
from sglang.bench_serving import (
|
||||
DatasetRow,
|
||||
benchmark,
|
||||
sample_mmmu_requests,
|
||||
set_global_args,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
@@ -48,20 +54,33 @@ class FakeTokenizer:
|
||||
return []
|
||||
|
||||
|
||||
def send_one_batch(base_url, num_prompts, batch_size):
|
||||
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||
:num_prompts
|
||||
]
|
||||
|
||||
def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
|
||||
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
|
||||
input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
|
||||
if is_multimodal:
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_prompts,
|
||||
tokenizer,
|
||||
512,
|
||||
apply_chat_template=False,
|
||||
)
|
||||
backend = "sglang-oai-chat"
|
||||
api_url = f"{base_url}/v1/chat/completions"
|
||||
else:
|
||||
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||
:num_prompts
|
||||
]
|
||||
input_requests: List[DatasetRow] = [
|
||||
DatasetRow(p, 0, 512) for p in padded_prompts
|
||||
]
|
||||
backend = "sglang"
|
||||
api_url = f"{base_url}/generate"
|
||||
|
||||
# We need to set some dummy values in order to call `benchmark` below.
|
||||
args = SimpleNamespace(
|
||||
disable_ignore_eos=False,
|
||||
disable_stream=False,
|
||||
return_logprob=False,
|
||||
backend="sglang",
|
||||
backend=backend,
|
||||
dataset_name="custom",
|
||||
num_prompts=None,
|
||||
sharegpt_output_len=None,
|
||||
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
|
||||
output_details=False,
|
||||
)
|
||||
set_global_args(args)
|
||||
tokenizer = FakeTokenizer()
|
||||
|
||||
# Run benchmark
|
||||
results = asyncio.run(
|
||||
benchmark(
|
||||
backend="sglang",
|
||||
api_url=f"{base_url}/generate",
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id="default",
|
||||
tokenizer=tokenizer,
|
||||
@@ -143,8 +161,6 @@ def main(args, server_args):
|
||||
other_args = []
|
||||
else:
|
||||
other_args = [
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-num-steps",
|
||||
steps,
|
||||
"--speculative-eagle-topk",
|
||||
@@ -157,6 +173,8 @@ def main(args, server_args):
|
||||
[
|
||||
"--speculative-draft-model-path",
|
||||
server_args.speculative_draft_model_path,
|
||||
"--speculative-algorithm",
|
||||
server_args.speculative_algorithm,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -207,13 +225,23 @@ def main(args, server_args):
|
||||
},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
# Warmup
|
||||
send_one_batch(base_url, batch_size, batch_size)
|
||||
send_one_batch(
|
||||
base_url, batch_size, batch_size, tokenizer, args.is_multimodal
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
acc_length, step_time, speed, completion_tokens = send_one_batch(
|
||||
base_url, max(args.num_prompts, batch_size), batch_size
|
||||
base_url,
|
||||
max(args.num_prompts, batch_size),
|
||||
batch_size,
|
||||
tokenizer,
|
||||
args.is_multimodal,
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
@@ -273,6 +301,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--start", type=int, default=0)
|
||||
parser.add_argument("--end", type=int)
|
||||
parser.add_argument("--output", type=str, default="output.jsonl")
|
||||
parser.add_argument("--is-multimodal", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user