Improve benchmark scripts and error message printing (#2922)
This commit is contained in:
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
context_len=args.sharegpt_context_len,
|
||||
)
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
|
||||
)
|
||||
elif args.dataset_name == "generated-shared-prefix":
|
||||
input_requests = sample_generated_shared_prefix_requests(
|
||||
num_groups=args.gen_num_groups,
|
||||
prompts_per_group=args.gen_prompts_per_group,
|
||||
system_prompt_len=args.gen_system_prompt_len,
|
||||
question_len=args.gen_question_len,
|
||||
output_len=args.gen_output_len,
|
||||
num_groups=args.gsp_num_groups,
|
||||
prompts_per_group=args.gsp_prompts_per_group,
|
||||
system_prompt_len=args.gsp_system_prompt_len,
|
||||
question_len=args.gsp_question_len,
|
||||
output_len=args.gsp_output_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
else:
|
||||
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
context_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
|
||||
output_len = (
|
||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||
)
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
|
||||
if prompt_len < 1 or output_len < 1:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or (
|
||||
prompt_len + output_len > 2048 and fixed_output_len is None
|
||||
):
|
||||
|
||||
if context_len and prompt_len + output_len > context_len:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
|
||||
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
|
||||
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
|
||||
|
||||
# Create a unique cache filename based on the generation parameters
|
||||
cache_key = (
|
||||
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
|
||||
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
|
||||
f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
|
||||
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
|
||||
f"{tokenizer.__class__.__name__}.pkl"
|
||||
)
|
||||
return cache_dir / cache_key
|
||||
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sharegpt-context-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
@@ -1453,38 +1462,6 @@ if __name__ == "__main__":
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
"--gen-num-groups",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of system prompt groups for generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen-prompts-per-group",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen-system-prompt-len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen-question-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Target length in tokens for questions in generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gen-output-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
@@ -1497,5 +1474,37 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="The name of LoRA adapter",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
"--gsp-num-groups",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of system prompt groups for generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gsp-prompts-per-group",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gsp-system-prompt-len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gsp-question-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Target length in tokens for questions in generated-shared-prefix dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gsp-output-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_benchmark(args)
|
||||
|
||||
Reference in New Issue
Block a user