Fix bench_serving with random-ids (#5214)
This commit is contained in:
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
|
|||||||
prompt_suffix=args.prompt_suffix,
|
prompt_suffix=args.prompt_suffix,
|
||||||
apply_chat_template=args.apply_chat_template,
|
apply_chat_template=args.apply_chat_template,
|
||||||
)
|
)
|
||||||
elif args.dataset_name == "random":
|
elif args.dataset_name.startswith("random"):
|
||||||
input_requests = sample_random_requests(
|
input_requests = sample_random_requests(
|
||||||
input_len=args.random_input_len,
|
input_len=args.random_input_len,
|
||||||
output_len=args.random_output_len,
|
output_len=args.random_output_len,
|
||||||
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
|
|||||||
range_ratio=args.random_range_ratio,
|
range_ratio=args.random_range_ratio,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
|
random_sample=args.dataset_name == "random",
|
||||||
)
|
)
|
||||||
elif args.dataset_name == "generated-shared-prefix":
|
elif args.dataset_name == "generated-shared-prefix":
|
||||||
input_requests = sample_generated_shared_prefix_requests(
|
input_requests = sample_generated_shared_prefix_requests(
|
||||||
@@ -687,6 +688,7 @@ def sample_random_requests(
|
|||||||
range_ratio: float,
|
range_ratio: float,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
|
random_sample: bool = True,
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
|
||||||
input_lens = np.random.randint(
|
input_lens = np.random.randint(
|
||||||
@@ -700,11 +702,15 @@ def sample_random_requests(
|
|||||||
size=num_prompts,
|
size=num_prompts,
|
||||||
)
|
)
|
||||||
|
|
||||||
if True:
|
if random_sample:
|
||||||
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
|
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
|
||||||
|
|
||||||
# Download sharegpt if necessary
|
# Download sharegpt if necessary
|
||||||
if not os.path.isfile(dataset_path):
|
if not os.path.isfile(dataset_path):
|
||||||
|
print(
|
||||||
|
"If you do not want to randomly sample from a dataset,"
|
||||||
|
" please use --dataset-name random-ids."
|
||||||
|
)
|
||||||
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
@@ -1223,7 +1229,7 @@ async def benchmark(
|
|||||||
output_file_name = args.output_file
|
output_file_name = args.output_file
|
||||||
else:
|
else:
|
||||||
now = datetime.now().strftime("%m%d")
|
now = datetime.now().strftime("%m%d")
|
||||||
if args.dataset_name == "random":
|
if args.dataset_name.startswith("random"):
|
||||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
||||||
else:
|
else:
|
||||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
|
||||||
@@ -1442,7 +1448,7 @@ if __name__ == "__main__":
|
|||||||
"--dataset-name",
|
"--dataset-name",
|
||||||
type=str,
|
type=str,
|
||||||
default="sharegpt",
|
default="sharegpt",
|
||||||
choices=["sharegpt", "random", "generated-shared-prefix"],
|
choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user