Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: YorkSu <york_su@qq.com>
This commit is contained in:
@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||
custom_params: Optional[Dict] = None
|
||||
custom_logit_processor: Optional[str] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
stream_reasoning: bool = True
|
||||
chat_template_kwargs: Optional[Dict] = None
|
||||
|
||||
# Custom logit processor for advanced sampling control
|
||||
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||
custom_params: Optional[Dict] = None
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"skip_special_tokens": self.skip_special_tokens,
|
||||
"logit_bias": self.logit_bias,
|
||||
"custom_params": self.custom_params,
|
||||
}
|
||||
|
||||
if self.response_format and self.response_format.type == "json_schema":
|
||||
|
||||
@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
custom_labels=custom_labels,
|
||||
custom_logit_processor=request.custom_logit_processor,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
custom_labels=custom_labels,
|
||||
custom_logit_processor=request.custom_logit_processor,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
"ignore_eos": request.ignore_eos,
|
||||
"skip_special_tokens": request.skip_special_tokens,
|
||||
"logit_bias": request.logit_bias,
|
||||
"custom_params": request.custom_params,
|
||||
}
|
||||
|
||||
# Handle response_format constraints
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import dill
|
||||
import orjson
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _cache_from_str(json_str: str):
|
||||
@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
|
||||
), f"{custom_param_list=}"
|
||||
logits[..., disallowed_token_ids] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
||||
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that caps how many "thinking" tokens a request may emit.

    Concrete subclasses pin the model-specific token ids. Once a request has
    spent its ``thinking_budget`` tokens inside an open thinking section, the
    processor forces a newline (if the last emitted token is not already one)
    and then the thinking-end token, closing the section.
    """

    # Model-specific vocabulary ids; concrete subclasses must define these.
    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
        """Enforce each request's thinking budget by masking ``logits`` rows.

        Args:
            logits: 2-D tensor of next-token logits, one row per request,
                aligned with ``custom_param_list``.
            custom_param_list: Per-request parameter dicts. Each may carry a
                ``thinking_budget`` int and the scheduler-injected ``__req__``
                request object (assumed present when a valid budget is set —
                TODO confirm the scheduler always injects it).

        Returns:
            The (possibly modified in place) ``logits`` tensor.
        """
        # Covers both None and an empty list.
        if not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue

            thinking_budget = param_dict.get("thinking_budget")

            # Skip unless thinking_budget is a non-negative int.
            # NOTE: bool is a subclass of int, so exclude it explicitly —
            # otherwise thinking_budget=True would silently act as a budget of 1.
            if (
                not isinstance(thinking_budget, int)
                or isinstance(thinking_budget, bool)
                or thinking_budget < 0
            ):
                continue

            req: Req = param_dict.get("__req__")
            # Full token stream so far: prompt followed by generated tokens.
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]

            # Only act while we are inside an open thinking section.
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Tokens emitted since the thinking section opened.
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            num_tokens_after_start = len(cur_ids) - start_index - 1

            if num_tokens_after_start < thinking_budget:
                continue

            # Budget exhausted. First force a newline so the end token sits on
            # its own line; on the next step the branch below fires.
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                logits[i, :] = -float("inf")
                logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                continue

            # Then force the thinking-end token.
            logits[i, :] = -float("inf")
            logits[i, self.THINKING_END_TOKEN_ID] = 0.0

        return logits
|
||||
|
||||
|
||||
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for Qwen3 models."""

    # Hard-coded Qwen3 vocabulary ids — presumably the <think> / </think>
    # special tokens and "\n"; TODO(review): confirm against the Qwen3 tokenizer.
    THINKING_START_TOKEN_ID: int = 151667
    THINKING_END_TOKEN_ID: int = 151668
    NEW_LINE_TOKEN_ID: int = 198
|
||||
|
||||
|
||||
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for DeepSeek-R1 models."""

    # Hard-coded DeepSeek-R1 vocabulary ids — presumably the <think> / </think>
    # special tokens and "\n"; TODO(review): confirm against the DeepSeek-R1 tokenizer.
    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201
|
||||
|
||||
Reference in New Issue
Block a user