Fix bench_serving with random-ids (#5214)
This commit is contained in:
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
|
||||
prompt_suffix=args.prompt_suffix,
|
||||
apply_chat_template=args.apply_chat_template,
|
||||
)
|
||||
- elif args.dataset_name == "random":
|
||||
+ elif args.dataset_name.startswith("random"):
|
||||
input_requests = sample_random_requests(
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
|
||||
range_ratio=args.random_range_ratio,
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
+ random_sample=args.dataset_name == "random",
|
||||
)
|
||||
elif args.dataset_name == "generated-shared-prefix":
|
||||
input_requests = sample_generated_shared_prefix_requests(
|
||||
@@ -687,6 +688,7 @@ def sample_random_requests(
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dataset_path: str,
|
||||
+ random_sample: bool = True,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
|
||||
input_lens = np.random.randint(
|
||||
@@ -700,11 +702,15 @@ def sample_random_requests(
|
||||
size=num_prompts,
|
||||
)
|
||||
|
||||
- if True:
|
||||
+ if random_sample:
|
||||
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
|
||||
|
||||
# Download sharegpt if necessary
|
||||
if not os.path.isfile(dataset_path):
|
||||
+ print(
|
||||
+ "If you do not want to randomly sample from a dataset,"
|
||||
+ " please use --dataset-name random-ids."
|
||||
+ )
|
||||
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||
|
||||
# Load the dataset.
|
||||
@@ -1223,7 +1229,7 @@ async def benchmark(
|
||||
output_file_name = args.output_file
|
||||
else:
|
||||
now = datetime.now().strftime("%m%d")
|
||||
- if args.dataset_name == "random":
|
||||
+ if args.dataset_name.startswith("random"):
|
||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
||||
else:
|
||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
|
||||
@@ -1442,7 +1448,7 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
- choices=["sharegpt", "random", "generated-shared-prefix"],
|
||||
+ choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user