Support Thinking Budget (via custom_logit_processor for OpenAI API) [Fix #6572] (#11416)

Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: YorkSu <york_su@qq.com>
This commit is contained in:
ybyang
2025-10-21 16:27:56 +08:00
committed by GitHub
parent c1e1600373
commit dbb16bedd5
7 changed files with 239 additions and 1 deletions

View File

@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
custom_params: Optional[Dict] = None
custom_logit_processor: Optional[str] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# Custom logit processor for advanced sampling control
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
custom_params: Optional[Dict] = None
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
"ignore_eos": self.ignore_eos,
"skip_special_tokens": self.skip_special_tokens,
"logit_bias": self.logit_bias,
"custom_params": self.custom_params,
}
if self.response_format and self.response_format.type == "json_schema":

View File

@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
custom_logit_processor=request.custom_logit_processor,
)
return adapted_request, request

View File

@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
custom_logit_processor=request.custom_logit_processor,
)
return adapted_request, request
@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
"custom_params": request.custom_params,
}
# Handle response_format constraints

View File

@@ -1,12 +1,15 @@
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import dill
import orjson
import torch
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
), f"{custom_param_list=}"
logits[..., disallowed_token_ids] = -float("inf")
return logits
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that caps how many tokens a request may spend "thinking".

    Concrete subclasses supply the model-specific token ids below. While a
    request is inside an open thinking section (the start marker has been seen
    but the end marker has not), once ``thinking_budget`` tokens have been
    generated after the start marker the processor forces a newline (if the
    last output token is not already one) and then the thinking-end token, by
    masking every other logit to ``-inf``.
    """

    # Model-specific token ids; must be defined by each concrete subclass.
    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
        """Apply the thinking-budget constraint row-by-row.

        Args:
            logits: A 2-D tensor of next-token logits, one row per request in
                the batch (row ``i`` pairs with ``custom_param_list[i]``).
            custom_param_list: Per-request parameter dicts. Each may carry a
                ``thinking_budget`` int and the scheduler-injected ``__req__``
                request object (read for ``origin_input_ids``/``output_ids``).

        Returns:
            The same ``logits`` tensor, modified in place where a request has
            exhausted its budget.
        """
        if not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue
            thinking_budget = param_dict.get("thinking_budget")
            # Act only on a non-negative int budget. bool is rejected
            # explicitly: isinstance(True, int) is True in Python, so a stray
            # boolean would otherwise behave as a budget of 1 or 0. A None /
            # missing / wrong-typed budget disables the processor for this row.
            if (
                not isinstance(thinking_budget, int)
                or isinstance(thinking_budget, bool)
                or thinking_budget < 0
            ):
                continue

            req: Req = param_dict.get("__req__")
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]

            # Skip rows that never entered, or have already left, the
            # thinking section.
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Tokens spent so far inside the thinking section.
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            num_tokens_after_start = len(cur_ids) - start_index - 1
            if num_tokens_after_start < thinking_budget:
                continue

            # Budget exhausted: first force a newline so the end marker sits
            # on its own line, then (next step) force the end marker itself.
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                logits[i, :] = -float("inf")
                logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                continue
            logits[i, :] = -float("inf")
            logits[i, self.THINKING_END_TOKEN_ID] = 0.0
        return logits
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for Qwen3 models."""

    # Token id of the thinking-start marker (presumably "<think>" in the
    # Qwen3 vocabulary — TODO confirm against the tokenizer).
    THINKING_START_TOKEN_ID: int = 151667
    # Token id of the thinking-end marker (presumably "</think>" — confirm).
    THINKING_END_TOKEN_ID: int = 151668
    # Token id forced immediately before the end marker, so the marker lands
    # on its own line (presumably "\n" — confirm against the tokenizer).
    NEW_LINE_TOKEN_ID: int = 198
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for DeepSeek-R1 models."""

    # Token id of the thinking-start marker (presumably "<think>" in the
    # DeepSeek-R1 vocabulary — TODO confirm against the tokenizer).
    THINKING_START_TOKEN_ID: int = 128798
    # Token id of the thinking-end marker (presumably "</think>" — confirm).
    THINKING_END_TOKEN_ID: int = 128799
    # Token id forced immediately before the end marker, so the marker lands
    # on its own line (presumably "\n" — confirm against the tokenizer).
    NEW_LINE_TOKEN_ID: int = 201