minor: add dataset dump and questions shuffle (#2093)

This commit is contained in:
Yineng Zhang
2024-11-20 06:07:27 +08:00
committed by GitHub
parent e57c3e12b8
commit 55bd97f3e5

View File

@@ -15,6 +15,7 @@ import argparse
import asyncio
import json
import os
import pickle
import random
import resource
import sys
@@ -682,6 +683,11 @@ def sample_generated_shared_prefix_requests(
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
if args.generated_input_path and os.path.exists(args.generated_input_path):
print(f"\nloading generated input data from {args.generated_input_path}")
with open(args.generated_input_path, "rb") as f:
return pickle.load(f)
"""Generate benchmark requests with shared system prompts using random tokens."""
# Generate system prompts for each group
system_prompts = []
@@ -695,6 +701,9 @@ def sample_generated_shared_prefix_requests(
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Shuffle questions
random.shuffle(questions)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
@@ -723,6 +732,11 @@ def sample_generated_shared_prefix_requests(
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
if args.generated_input_save_path:
print(f"Saving generated input data to {args.generated_input_save_path}")
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
with open(args.generated_input_save_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
@@ -1331,6 +1345,16 @@ if __name__ == "__main__":
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
"--generated-input-save-path",
type=str,
help="Path to save generated input data",
)
parser.add_argument(
"--generated-input-path",
type=str,
help="Path to load previously generated input data",
)
args = parser.parse_args()
run_benchmark(args)