Add --thinking-mode to run_eval (#11189)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
```diff
@@ -16,13 +16,29 @@ from sglang.test.simple_eval_common import (
 )
 
 
+def get_thinking_kwargs(args):
+    if args.thinking_mode in THINKING_MODE_CHOICES:
+        thinking_param = (
+            "thinking" if args.thinking_mode == "deepseek-v3" else "enable_thinking"
+        )
+        return {
+            "chat_template_kwargs": {thinking_param: True},
+            "separate_reasoning": True,
+        }
+    return {}
+
+
 def run_eval_once(args, base_url: str, eval_obj: Eval) -> dict:
+    # Get thinking kwargs based on user's choice
+    thinking_kwargs = get_thinking_kwargs(args)
+
     sampler = ChatCompletionSampler(
         model=args.model,
         max_tokens=getattr(args, "max_tokens", 2048),
         base_url=base_url,
         temperature=getattr(args, "temperature", 0.0),
         reasoning_effort=getattr(args, "reasoning_effort", None),
+        extra_body=thinking_kwargs,
     )
 
     # Run eval
```
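For reference, a standalone sketch of what the new helper returns for each mode. The `SimpleNamespace` below stands in for the parsed CLI args; it is illustration only, not part of the commit:

```python
from types import SimpleNamespace

THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]


def get_thinking_kwargs(args):
    # Mirrors the helper added above: DeepSeek V3 uses the "thinking"
    # chat-template key, while DeepSeek R1 and Qwen3 use "enable_thinking".
    if args.thinking_mode in THINKING_MODE_CHOICES:
        thinking_param = (
            "thinking" if args.thinking_mode == "deepseek-v3" else "enable_thinking"
        )
        return {
            "chat_template_kwargs": {thinking_param: True},
            "separate_reasoning": True,
        }
    return {}


print(get_thinking_kwargs(SimpleNamespace(thinking_mode="qwen3")))
# {'chat_template_kwargs': {'enable_thinking': True}, 'separate_reasoning': True}
print(get_thinking_kwargs(SimpleNamespace(thinking_mode="deepseek-v3")))
# {'chat_template_kwargs': {'thinking': True}, 'separate_reasoning': True}
print(get_thinking_kwargs(SimpleNamespace(thinking_mode=None)))
# {} -- the flag defaults to None, so existing runs are unchanged
```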
```diff
@@ -136,6 +152,8 @@ def run_eval(args):
     return metrics
 
 
+THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -166,6 +184,13 @@ if __name__ == "__main__":
     parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
     parser.add_argument("--reasoning-effort", type=str)
+    parser.add_argument(
+        "--thinking-mode",
+        default=None,
+        type=str,
+        choices=THINKING_MODE_CHOICES,
+        help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
+    )
     args = parser.parse_args()
 
     run_eval(args)
```
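Because the flag sets `choices=THINKING_MODE_CHOICES`, argparse rejects any other value at parse time. A minimal sketch of the flag in isolation (a hypothetical standalone parser, not the full run_eval CLI):

```python
import argparse

THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--thinking-mode",
    default=None,
    type=str,
    choices=THINKING_MODE_CHOICES,
    help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
)

print(parser.parse_args(["--thinking-mode", "qwen3"]).thinking_mode)  # qwen3
print(parser.parse_args([]).thinking_mode)  # None
# parser.parse_args(["--thinking-mode", "gpt-4"]) would exit with an
# "invalid choice" error, since argparse enforces the choices list.
```

The remaining hunks are in `simple_eval_common`, where `ChatCompletionSampler` gains an `extra_body` parameter and forwards it to the chat-completions request: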
```diff
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
         temperature: float = 0.0,
         reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
+        extra_body: Optional[Dict[str, Any]] = None,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
 
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.reasoning_effort = reasoning_effort
+        self.extra_body = extra_body
         self.image_format = "url"
         print(
-            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
         )
 
     def _handle_image(
```
```diff
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
                 reasoning_effort=self.reasoning_effort,
+                extra_body=self.extra_body,
             )
             return response.choices[0].message.content
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
```
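Downstream, the OpenAI Python client merges `extra_body` entries into the request JSON as top-level fields, which is how `chat_template_kwargs` and `separate_reasoning` reach the sglang server. A minimal end-to-end sketch; the server URL and model name are assumptions for illustration, not taken from the commit:

```python
# What the new plumbing amounts to at the wire level: thinking kwargs are
# passed via extra_body and serialized into the chat-completions request.
# Assumes a local sglang server at a hypothetical URL serving a Qwen3 model.
from openai import OpenAI

extra_body = {
    "chat_template_kwargs": {"enable_thinking": True},  # --thinking-mode qwen3
    "separate_reasoning": True,
}

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",  # hypothetical model name
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
    temperature=0.0,
    max_tokens=2048,
    extra_body=extra_body,  # merged into the JSON body as extra fields
)
print(response.choices[0].message.content)
```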