minor: add dataset dump and questions shuffle (#2093)
This commit is contained in:
@@ -15,6 +15,7 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
import resource
|
import resource
|
||||||
import sys
|
import sys
|
||||||
@@ -682,6 +683,11 @@ def sample_generated_shared_prefix_requests(
|
|||||||
output_len: int,
|
output_len: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
if args.generated_input_path and os.path.exists(args.generated_input_path):
|
||||||
|
print(f"\nloading generated input data from {args.generated_input_path}")
|
||||||
|
with open(args.generated_input_path, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
"""Generate benchmark requests with shared system prompts using random tokens."""
|
"""Generate benchmark requests with shared system prompts using random tokens."""
|
||||||
# Generate system prompts for each group
|
# Generate system prompts for each group
|
||||||
system_prompts = []
|
system_prompts = []
|
||||||
@@ -695,6 +701,9 @@ def sample_generated_shared_prefix_requests(
|
|||||||
question = gen_prompt(tokenizer, question_len)
|
question = gen_prompt(tokenizer, question_len)
|
||||||
questions.append(question)
|
questions.append(question)
|
||||||
|
|
||||||
|
# Shuffle questions
|
||||||
|
random.shuffle(questions)
|
||||||
|
|
||||||
# Combine system prompts with questions
|
# Combine system prompts with questions
|
||||||
input_requests = []
|
input_requests = []
|
||||||
total_input_tokens = 0
|
total_input_tokens = 0
|
||||||
@@ -723,6 +732,11 @@ def sample_generated_shared_prefix_requests(
|
|||||||
print(
|
print(
|
||||||
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
||||||
)
|
)
|
||||||
|
if args.generated_input_save_path:
|
||||||
|
print(f"Saving generated input data to {args.generated_input_save_path}")
|
||||||
|
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
|
||||||
|
with open(args.generated_input_save_path, "wb") as f:
|
||||||
|
pickle.dump(input_requests, f)
|
||||||
|
|
||||||
return input_requests
|
return input_requests
|
||||||
|
|
||||||
@@ -1331,6 +1345,16 @@ if __name__ == "__main__":
|
|||||||
default=256,
|
default=256,
|
||||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated-input-save-path",
|
||||||
|
type=str,
|
||||||
|
help="Path to save generated input data",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--generated-input-path",
|
||||||
|
type=str,
|
||||||
|
help="Path to load previously generated input data",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_benchmark(args)
|
run_benchmark(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user