Improve benchmark scripts and error message printing (#2922)

This commit is contained in:
Lianmin Zheng
2025-01-16 06:24:31 -08:00
committed by GitHub
parent 7596417732
commit 8f2c522aba
8 changed files with 125 additions and 70 deletions

View File

@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gen_num_groups,
prompts_per_group=args.gen_prompts_per_group,
system_prompt_len=args.gen_system_prompt_len,
question_len=args.gen_question_len,
output_len=args.gen_output_len,
num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group,
system_prompt_len=args.gsp_system_prompt_len,
question_len=args.gsp_question_len,
output_len=args.gsp_output_len,
tokenizer=tokenizer,
)
else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
if prompt_len < 1 or output_len < 1:
# Prune too short sequences.
continue
if prompt_len > 1024 or (
prompt_len + output_len > 2048 and fixed_output_len is None
):
if context_len and prompt_len + output_len > context_len:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
)
parser.add_argument(
"--sharegpt-context-len",
type=int,
default=None,
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
)
parser.add_argument(
"--random-input-len",
type=int,
@@ -1453,38 +1462,6 @@ if __name__ == "__main__":
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
"--gen-num-groups",
type=int,
default=64,
help="Number of system prompt groups for generated-shared-prefix dataset",
)
group.add_argument(
"--gen-prompts-per-group",
type=int,
default=16,
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
)
group.add_argument(
"--gen-system-prompt-len",
type=int,
default=2048,
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
)
group.add_argument(
"--gen-question-len",
type=int,
default=128,
help="Target length in tokens for questions in generated-shared-prefix dataset",
)
group.add_argument(
"--gen-output-len",
type=int,
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
"--profile",
action="store_true",
@@ -1497,5 +1474,37 @@ if __name__ == "__main__":
default=None,
help="The name of LoRA adapter",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
"--gsp-num-groups",
type=int,
default=64,
help="Number of system prompt groups for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-prompts-per-group",
type=int,
default=16,
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-system-prompt-len",
type=int,
default=2048,
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-question-len",
type=int,
default=128,
help="Target length in tokens for questions in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-output-len",
type=int,
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
args = parser.parse_args()
run_benchmark(args)