Let bench_one_batch_server use sharegpt data to make expert distribution more natural (#5573)

This commit is contained in:
fzyzcjy
2025-05-21 17:08:43 +08:00
committed by GitHub
parent 505eec4dc9
commit 7222e1dacc
2 changed files with 33 additions and 19 deletions

View File

@@ -22,6 +22,7 @@ from typing import Tuple
import numpy as np
import requests
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
@@ -117,16 +118,19 @@ def run_one_case(
input_len_step_percentage: float,
run_name: str,
result_filename: str,
tokenizer,
):
requests.post(url + "/flush_cache")
input_lens = [
int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
for i in range(batch_size)
]
input_ids = [
[int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
for i in range(batch_size)
]
input_requests = sample_random_requests(
input_len=input_len,
output_len=output_len,
num_prompts=batch_size,
range_ratio=1.0,
tokenizer=tokenizer,
dataset_path="",
random_sample=True,
return_text=False,
)
use_structured_outputs = False
if use_structured_outputs:
@@ -145,8 +149,7 @@ def run_one_case(
response = requests.post(
url + "/generate",
json={
# "text": texts,
"input_ids": input_ids,
"input_ids": [input_ids for input_ids, _, _ in input_requests],
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
@@ -228,6 +231,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
else:
proc, base_url = launch_server_process(server_args)
tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)
# warmup
if not bench_args.skip_warmup:
print("=" * 8 + " Warmup Begin " + "=" * 8)
@@ -241,6 +247,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name="",
result_filename="",
tokenizer=tokenizer,
)
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")