From 7222e1dacc3c9ba8b0506be3c8d2999465145186 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Wed, 21 May 2025 17:08:43 +0800 Subject: [PATCH] Let bench_one_batch_server use sharegpt data to make expert distribution more natural (#5573) --- python/sglang/bench_one_batch_server.py | 27 ++++++++++++++++--------- python/sglang/bench_serving.py | 25 ++++++++++++++--------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index da091bf98..adb433ead 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -22,6 +22,7 @@ from typing import Tuple import numpy as np import requests +from sglang.bench_serving import get_tokenizer, sample_random_requests from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree @@ -117,16 +118,19 @@ def run_one_case( input_len_step_percentage: float, run_name: str, result_filename: str, + tokenizer, ): requests.post(url + "/flush_cache") - input_lens = [ - int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage)) - for i in range(batch_size) - ] - input_ids = [ - [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))] - for i in range(batch_size) - ] + input_requests = sample_random_requests( + input_len=input_len, + output_len=output_len, + num_prompts=batch_size, + range_ratio=1.0, + tokenizer=tokenizer, + dataset_path="", + random_sample=True, + return_text=False, + ) use_structured_outputs = False if use_structured_outputs: @@ -145,8 +149,7 @@ def run_one_case( response = requests.post( url + "/generate", json={ - # "text": texts, - "input_ids": input_ids, + "input_ids": [input_ids for input_ids, _, _ in input_requests], "sampling_params": { "temperature": temperature, "max_new_tokens": output_len, @@ -228,6 +231,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): else: proc, base_url = launch_server_process(server_args) + tokenizer_id = server_args.tokenizer_path or server_args.model_path + tokenizer = get_tokenizer(tokenizer_id) + # warmup if not bench_args.skip_warmup: print("=" * 8 + " Warmup Begin " + "=" * 8) @@ -241,6 +247,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): input_len_step_percentage=bench_args.input_len_step_percentage, run_name="", result_filename="", + tokenizer=tokenizer, ) print("=" * 8 + " Warmup End " + "=" * 8 + "\n") diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 1624aaacc..9bab23b61 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -471,6 +471,10 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + assert ( + pretrained_model_name_or_path is not None + and pretrained_model_name_or_path != "" + ) if pretrained_model_name_or_path.endswith( ".json" ) or pretrained_model_name_or_path.endswith(".model"): @@ -832,6 +836,7 @@ def sample_random_requests( tokenizer: PreTrainedTokenizerBase, dataset_path: str, random_sample: bool = True, + return_text: bool = True, ) -> List[DatasetRow]: input_lens = np.random.randint( max(int(input_len * range_ratio), 1), @@ -892,10 +897,12 @@ def sample_random_requests( else: ratio = (input_lens[i] + prompt_len - 1) // prompt_len input_ids = (prompt_token_ids * ratio)[: input_lens[i]] - prompt = tokenizer.decode(input_ids) + input_content = input_ids + if return_text: + input_content = tokenizer.decode(input_content) input_requests.append( DatasetRow( - prompt=prompt, + prompt=input_content, prompt_len=int(input_lens[i]), output_len=int(output_lens[i]), ) @@ -905,15 +912,15 @@ def sample_random_requests( offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): - prompt = tokenizer.decode( - [ - (offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i]) - ] - ) + input_content = [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + if return_text: + input_content = tokenizer.decode(input_content) input_requests.append( DatasetRow( - prompt=prompt, + prompt=input_content, prompt_len=int(input_lens[i]), output_len=int(output_lens[i]), )