minor: add dataset dump and questions shuffle (#2093)
This commit is contained in:
@@ -15,6 +15,7 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import resource
|
||||
import sys
|
||||
@@ -682,6 +683,11 @@ def sample_generated_shared_prefix_requests(
|
||||
output_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if args.generated_input_path and os.path.exists(args.generated_input_path):
|
||||
print(f"\nloading generated input data from {args.generated_input_path}")
|
||||
with open(args.generated_input_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
"""Generate benchmark requests with shared system prompts using random tokens."""
|
||||
# Generate system prompts for each group
|
||||
system_prompts = []
|
||||
@@ -695,6 +701,9 @@ def sample_generated_shared_prefix_requests(
|
||||
question = gen_prompt(tokenizer, question_len)
|
||||
questions.append(question)
|
||||
|
||||
# Shuffle questions
|
||||
random.shuffle(questions)
|
||||
|
||||
# Combine system prompts with questions
|
||||
input_requests = []
|
||||
total_input_tokens = 0
|
||||
@@ -723,6 +732,11 @@ def sample_generated_shared_prefix_requests(
|
||||
print(
|
||||
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
||||
)
|
||||
if args.generated_input_save_path:
|
||||
print(f"Saving generated input data to {args.generated_input_save_path}")
|
||||
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
|
||||
with open(args.generated_input_save_path, "wb") as f:
|
||||
pickle.dump(input_requests, f)
|
||||
|
||||
return input_requests
|
||||
|
||||
@@ -1331,6 +1345,16 @@ if __name__ == "__main__":
|
||||
default=256,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generated-input-save-path",
|
||||
type=str,
|
||||
help="Path to save generated input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generated-input-path",
|
||||
type=str,
|
||||
help="Path to load previously generated input data",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_benchmark(args)
|
||||
|
||||
Reference in New Issue
Block a user