diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 17cf850f6..955a3bfa5 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -36,6 +36,7 @@ class EvalArgs: profile: bool = False profile_number: int = 5 concurrency: int = 1 + max_new_tokens: int = 30 response_answer_regex: str = "(.*)" lora_path: Optional[str] = None @@ -94,6 +95,12 @@ class EvalArgs: default=EvalArgs.concurrency, help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.", ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=EvalArgs.max_new_tokens, + help="Maximum number of new tokens to generate per sample.", + ) parser.add_argument( "--response-answer-regex", type=str, @@ -234,7 +241,7 @@ def prepare_samples(eval_args: EvalArgs): def get_sampling_params(eval_args): - max_new_tokens = 30 + max_new_tokens = eval_args.max_new_tokens temperature = 0.001 extra_request_body = {}