Tiny fix CI (#6611)
This commit is contained in:
@@ -493,8 +493,9 @@ def get_tokenizer(
|
||||
|
||||
|
||||
def get_dataset(args, tokenizer):
|
||||
tokenize_prompt = getattr(args, "tokenize_prompt", False)
|
||||
if args.dataset_name == "sharegpt":
|
||||
assert not args.tokenize_prompt
|
||||
assert not tokenize_prompt
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
@@ -513,10 +514,10 @@ def get_dataset(args, tokenizer):
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=args.dataset_name == "random",
|
||||
return_text=not args.tokenize_prompt,
|
||||
return_text=not tokenize_prompt,
|
||||
)
|
||||
elif args.dataset_name == "generated-shared-prefix":
|
||||
assert not args.tokenize_prompt
|
||||
assert not tokenize_prompt
|
||||
input_requests = sample_generated_shared_prefix_requests(
|
||||
num_groups=args.gsp_num_groups,
|
||||
prompts_per_group=args.gsp_prompts_per_group,
|
||||
@@ -527,7 +528,7 @@ def get_dataset(args, tokenizer):
|
||||
args=args,
|
||||
)
|
||||
elif args.dataset_name == "mmmu":
|
||||
assert not args.tokenize_prompt
|
||||
assert not tokenize_prompt
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user