support vlm model spec bench (#10173)
```diff
@@ -16,8 +16,14 @@ from types import SimpleNamespace
 
 import numpy as np
 import requests
+from transformers import AutoTokenizer
 
-from sglang.bench_serving import DatasetRow, benchmark, set_global_args
+from sglang.bench_serving import (
+    DatasetRow,
+    benchmark,
+    sample_mmmu_requests,
+    set_global_args,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
```
```diff
@@ -48,20 +54,33 @@ class FakeTokenizer:
         return []
 
 
-def send_one_batch(base_url, num_prompts, batch_size):
-    padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
-        :num_prompts
-    ]
-
+def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
+    if is_multimodal:
+        input_requests = sample_mmmu_requests(
+            num_prompts,
+            tokenizer,
+            512,
+            apply_chat_template=False,
+        )
+        backend = "sglang-oai-chat"
+        api_url = f"{base_url}/v1/chat/completions"
+    else:
+        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
+            :num_prompts
+        ]
+        input_requests: List[DatasetRow] = [
+            DatasetRow(p, 0, 512) for p in padded_prompts
+        ]
+        backend = "sglang"
+        api_url = f"{base_url}/generate"
 
     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
         disable_ignore_eos=False,
         disable_stream=False,
         return_logprob=False,
-        backend="sglang",
+        backend=backend,
         dataset_name="custom",
         num_prompts=None,
         sharegpt_output_len=None,
```
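The text-only path keeps the old behavior: the fixed prompt list is repeated `ceil(num_prompts / len(prompts))` times and truncated. The multimodal path instead samples `num_prompts` MMMU image questions with a fixed 512-token output budget, skipping the chat template, presumably because the chat-completions endpoint applies it server-side. A minimal sketch of the padding arithmetic, with a hypothetical `prompts` list standing in for the module-level one:

```python
# Minimal sketch of the text-path padding above; `prompts` here is a
# hypothetical stand-in for the module-level prompt list.
prompts = ["What is the capital of France?", "Explain quicksort briefly."]

def pad_prompts(prompts, num_prompts):
    # ceil(num_prompts / len(prompts)) repetitions, then truncate.
    reps = (num_prompts + len(prompts) - 1) // len(prompts)
    return (prompts * reps)[:num_prompts]

assert len(pad_prompts(prompts, 5)) == 5
assert pad_prompts(prompts, 3) == prompts + prompts[:1]
```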
```diff
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
         output_details=False,
     )
     set_global_args(args)
-    tokenizer = FakeTokenizer()
 
     # Run benchmark
     results = asyncio.run(
         benchmark(
-            backend="sglang",
-            api_url=f"{base_url}/generate",
+            backend=backend,
+            api_url=api_url,
             base_url=base_url,
             model_id="default",
             tokenizer=tokenizer,
```
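`FakeTokenizer` is no longer constructed here: the multimodal path needs a real tokenizer for MMMU sampling and token accounting, so `send_one_batch` now receives it as a parameter, along with the mode-dependent backend name and endpoint. A sketch of that pairing, assuming an sglang server at `base_url` (the names mirror the hunks above):

```python
# Hedged sketch of the backend/endpoint pairing used by send_one_batch.
def select_endpoint(base_url: str, is_multimodal: bool) -> tuple[str, str]:
    if is_multimodal:
        # The OpenAI-compatible chat endpoint accepts image content parts.
        return "sglang-oai-chat", f"{base_url}/v1/chat/completions"
    # Text-only requests keep using the native generate endpoint.
    return "sglang", f"{base_url}/generate"
```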
```diff
@@ -143,8 +161,6 @@ def main(args, server_args):
             other_args = []
         else:
             other_args = [
-                "--speculative-algorithm",
-                "EAGLE",
                 "--speculative-num-steps",
                 steps,
                 "--speculative-eagle-topk",
```
```diff
@@ -157,6 +173,8 @@ def main(args, server_args):
                 [
                     "--speculative-draft-model-path",
                     server_args.speculative_draft_model_path,
+                    "--speculative-algorithm",
+                    server_args.speculative_algorithm,
                 ]
             )
 
```
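Together with the previous hunk, this replaces the hard-coded `"EAGLE"` literal with whatever `--speculative-algorithm` was parsed into `ServerArgs`, so other draft algorithms (e.g. EAGLE3) can be swept with the same script. A hedged sketch of the resulting flag assembly; the `steps == 0` baseline condition and the placeholder paths are assumptions, not values from this commit:

```python
# Hedged sketch: launch-flag assembly after this change. The speculative
# algorithm is forwarded from ServerArgs instead of being hard-coded.
def build_spec_args(steps, topk, draft_model_path, algorithm):
    if steps == 0:  # assumed baseline branch: no speculative decoding
        return []
    return [
        "--speculative-num-steps", str(steps),
        "--speculative-eagle-topk", str(topk),
        "--speculative-draft-model-path", draft_model_path,
        "--speculative-algorithm", algorithm,
    ]

flags = build_spec_args(3, 4, "path/to/draft-model", "EAGLE")  # placeholders
```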
```diff
@@ -207,13 +225,23 @@ def main(args, server_args):
         },
     )
 
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
     try:
         # Warmup
-        send_one_batch(base_url, batch_size, batch_size)
+        send_one_batch(
+            base_url, batch_size, batch_size, tokenizer, args.is_multimodal
+        )
 
         # Benchmark
         acc_length, step_time, speed, completion_tokens = send_one_batch(
-            base_url, max(args.num_prompts, batch_size), batch_size
+            base_url,
+            max(args.num_prompts, batch_size),
+            batch_size,
+            tokenizer,
+            args.is_multimodal,
         )
     finally:
         kill_process_tree(process.pid)
```
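The tokenizer is loaded once from the target model and threaded through both calls; the first `send_one_batch` is a warmup sized to exactly one batch, and only the second, larger run is recorded. A usage sketch under stated assumptions: `send_one_batch` is this script's function, and the server URL, sizes, and model path are illustrative placeholders:

```python
from transformers import AutoTokenizer

base_url = "http://127.0.0.1:30000"  # assumed local sglang server
batch_size, num_prompts = 8, 64      # illustrative sweep values
is_multimodal = True

tokenizer = AutoTokenizer.from_pretrained(
    "path/to/model", trust_remote_code=True  # placeholder for args.model_path
)
# Warmup: same function, sized to exactly one batch, results discarded.
send_one_batch(base_url, batch_size, batch_size, tokenizer, is_multimodal)
# Measured run: at least num_prompts requests, issued batch_size at a time.
acc_length, step_time, speed, completion_tokens = send_one_batch(
    base_url, max(num_prompts, batch_size), batch_size, tokenizer, is_multimodal
)
```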
```diff
@@ -273,6 +301,7 @@ if __name__ == "__main__":
     parser.add_argument("--start", type=int, default=0)
     parser.add_argument("--end", type=int)
     parser.add_argument("--output", type=str, default="output.jsonl")
+    parser.add_argument("--is-multimodal", action="store_true", default=False)
     args = parser.parse_args()
     server_args: ServerArgs = ServerArgs.from_cli_args(args)
 
```
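The new flag is a plain `store_true` boolean, so existing text-only invocations are unchanged unless `--is-multimodal` is passed. A minimal, runnable check of its semantics:

```python
# Minimal check: store_true with default=False preserves the old
# text-only behavior when the flag is absent.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--is-multimodal", action="store_true", default=False)

assert parser.parse_args([]).is_multimodal is False
assert parser.parse_args(["--is-multimodal"]).is_multimodal is True
```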