diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index f6e03a308..7d434782d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -490,7 +490,7 @@ def get_dataset(args, tokenizer): prompt_suffix=args.prompt_suffix, apply_chat_template=args.apply_chat_template, ) - elif args.dataset_name == "random": + elif args.dataset_name.startswith("random"): input_requests = sample_random_requests( input_len=args.random_input_len, output_len=args.random_output_len, @@ -498,6 +498,7 @@ def get_dataset(args, tokenizer): range_ratio=args.random_range_ratio, tokenizer=tokenizer, dataset_path=args.dataset_path, + random_sample=args.dataset_name == "random", ) elif args.dataset_name == "generated-shared-prefix": input_requests = sample_generated_shared_prefix_requests( @@ -687,6 +688,7 @@ def sample_random_requests( range_ratio: float, tokenizer: PreTrainedTokenizerBase, dataset_path: str, + random_sample: bool = True, ) -> List[Tuple[str, int, int]]: input_lens = np.random.randint( @@ -700,11 +702,15 @@ def sample_random_requests( size=num_prompts, ) - if True: + if random_sample: # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens # Download sharegpt if necessary if not os.path.isfile(dataset_path): + print( + "If you do not want to randomly sample from a dataset," + " please use --dataset-name random-ids." + ) dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. @@ -1223,7 +1229,7 @@ async def benchmark( output_file_name = args.output_file else: now = datetime.now().strftime("%m%d") - if args.dataset_name == "random": + if args.dataset_name.startswith("random"): output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" else: output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" @@ -1442,7 +1448,7 @@ if __name__ == "__main__": "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "random", "generated-shared-prefix"], + choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"], help="Name of the dataset to benchmark on.", ) parser.add_argument(