Reasoning parser (#4000)
Co-authored-by: Lucas Pickup <lupickup@microsoft.com>
This commit is contained in:
@@ -55,6 +55,7 @@ from sglang.srt.managers.io_struct import (
|
||||
ProfileReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
SeparateReasoningReqInput,
|
||||
SetInternalStateReq,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
@@ -75,6 +76,7 @@ from sglang.srt.openai_api.adapter import (
|
||||
v1_retrieve_file_content,
|
||||
)
|
||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
add_api_key_middleware,
|
||||
@@ -460,6 +462,26 @@ async def parse_function_call_request(obj: ParseFunctionCallReq, request: Reques
|
||||
return ORJSONResponse(content=response_data, status_code=200)
|
||||
|
||||
|
||||
@app.post("/separate_reasoning")
|
||||
async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):
|
||||
"""
|
||||
A native API endpoint to separate reasoning from a text.
|
||||
"""
|
||||
# 1) Initialize the parser based on the request body
|
||||
parser = ReasoningParser(model_type=obj.reasoning_parser)
|
||||
|
||||
# 2) Call the non-stream parsing method (non-stream)
|
||||
reasoning_text, normal_text = parser.parse_non_stream(obj.text)
|
||||
|
||||
# 3) Organize the response content
|
||||
response_data = {
|
||||
"reasoning_text": reasoning_text,
|
||||
"text": normal_text,
|
||||
}
|
||||
|
||||
return ORJSONResponse(content=response_data, status_code=200)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
|
||||
|
||||
|
||||
@@ -678,6 +678,12 @@ class ParseFunctionCallReq:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeparateReasoningReqInput:
|
||||
text: str # The text to parse.
|
||||
reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1".
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertexGenerateReqInput:
|
||||
instances: List[dict]
|
||||
|
||||
@@ -72,6 +72,7 @@ from sglang.srt.openai_api.protocol import (
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1038,7 +1039,12 @@ def v1_chat_generate_request(
|
||||
|
||||
|
||||
def v1_chat_generate_response(
|
||||
request, ret, to_file=False, cache_report=False, tool_call_parser=None
|
||||
request,
|
||||
ret,
|
||||
to_file=False,
|
||||
cache_report=False,
|
||||
tool_call_parser=None,
|
||||
reasoning_parser=None,
|
||||
):
|
||||
choices = []
|
||||
|
||||
@@ -1092,9 +1098,26 @@ def v1_chat_generate_response(
|
||||
if isinstance(request, list):
|
||||
tool_choice = request[idx].tool_choice
|
||||
tools = request[idx].tools
|
||||
separate_reasoning = request[idx].separate_reasoning
|
||||
else:
|
||||
tool_choice = request.tool_choice
|
||||
tools = request.tools
|
||||
separate_reasoning = request.separate_reasoning
|
||||
|
||||
if reasoning_parser and separate_reasoning:
|
||||
try:
|
||||
parser = ReasoningParser(
|
||||
model_type=reasoning_parser, stream_reasoning=False
|
||||
)
|
||||
reasoning_text, text = parser.parse_non_stream(text)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Failed to parse reasoning related info to json format!",
|
||||
)
|
||||
else:
|
||||
reasoning_text = None
|
||||
|
||||
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
|
||||
if finish_reason == "stop":
|
||||
@@ -1124,8 +1147,9 @@ def v1_chat_generate_response(
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ret_item["text"] if tool_calls is None else None,
|
||||
"content": text if tool_calls is None else None,
|
||||
"tool_calls": tool_calls,
|
||||
"reasoning_content": reasoning_text,
|
||||
},
|
||||
"logprobs": choice_logprobs,
|
||||
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
||||
@@ -1140,8 +1164,9 @@ def v1_chat_generate_response(
|
||||
index=idx,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=ret_item["text"] if tool_calls is None else None,
|
||||
content=text if tool_calls is None else None,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_text,
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
@@ -1208,6 +1233,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
if adapted_request.stream:
|
||||
parser_dict = {}
|
||||
reasoning_parser_dict = {}
|
||||
|
||||
async def generate_stream_resp():
|
||||
is_firsts = {}
|
||||
@@ -1274,15 +1300,27 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
choice_logprobs = None
|
||||
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
finish_reason_type = (
|
||||
finish_reason["type"] if finish_reason else None
|
||||
)
|
||||
|
||||
if is_first:
|
||||
# First chunk with role
|
||||
is_first = False
|
||||
if (
|
||||
tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
):
|
||||
delta = DeltaMessage(role="assistant", reasoning_content="")
|
||||
else:
|
||||
delta = DeltaMessage(role="assistant", content="")
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(role="assistant", content=""),
|
||||
delta=delta,
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
None
|
||||
if finish_reason_type and len(finish_reason_type) == 0
|
||||
else finish_reason_type
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
@@ -1302,6 +1340,41 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
delta = text[len(stream_buffer) :]
|
||||
new_stream_buffer = stream_buffer + delta
|
||||
|
||||
if (
|
||||
tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
):
|
||||
if index not in reasoning_parser_dict:
|
||||
reasoning_parser_dict[index] = ReasoningParser(
|
||||
tokenizer_manager.server_args.reasoning_parser,
|
||||
request.stream_reasoning,
|
||||
)
|
||||
reasoning_parser = reasoning_parser_dict[index]
|
||||
reasoning_text, delta = reasoning_parser.parse_stream_chunk(
|
||||
delta
|
||||
)
|
||||
if reasoning_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
||||
finish_reason=(
|
||||
None
|
||||
if finish_reason_type
|
||||
and len(finish_reason_type) == 0
|
||||
else finish_reason_type
|
||||
),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
if (delta and len(delta) == 0) or not delta:
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
continue
|
||||
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
if index not in parser_dict:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
@@ -1319,7 +1392,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
index=index,
|
||||
delta=DeltaMessage(content=normal_text),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
None
|
||||
if finish_reason_type
|
||||
and len(finish_reason_type) == 0
|
||||
else finish_reason_type
|
||||
),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
@@ -1388,7 +1464,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
None
|
||||
if finish_reason_type and len(finish_reason_type) == 0
|
||||
else finish_reason_type
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
@@ -1456,6 +1534,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
ret,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
reasoning_parser=tokenizer_manager.server_args.reasoning_parser,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -336,6 +336,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
separate_reasoning: bool = True
|
||||
stream_reasoning: bool = True
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
@@ -356,6 +358,7 @@ class ToolCall(BaseModel):
|
||||
class ChatMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
@@ -379,6 +382,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
|
||||
154
python/sglang/srt/reasoning_parser.py
Normal file
154
python/sglang/srt/reasoning_parser.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import re
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
def __init__(self, normal_text: str = "", reasoning_text: str = ""):
|
||||
self.normal_text = normal_text
|
||||
self.reasoning_text = reasoning_text
|
||||
|
||||
|
||||
class BaseReasoningFormatDetector:
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
think_start_token: str,
|
||||
think_end_token: str,
|
||||
force_reasoning: bool = False,
|
||||
stream_reasoning: bool = True,
|
||||
):
|
||||
self.think_start_token = think_start_token
|
||||
self.think_end_token = think_end_token
|
||||
self._in_reasoning = force_reasoning
|
||||
self.stream_reasoning = stream_reasoning
|
||||
|
||||
self._buffer = ""
|
||||
self.stripped_think_start = False
|
||||
|
||||
def detect_and_parse(self, text: str) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses reasoning sections in the provided text.
|
||||
Returns both reasoning content and normal text separately.
|
||||
"""
|
||||
text = text.replace(self.think_start_token, "").strip()
|
||||
if self.think_end_token not in text:
|
||||
# Assume reasoning was truncated before `</think>` token
|
||||
return StreamingParseResult(reasoning_text=text)
|
||||
|
||||
# Extract reasoning content
|
||||
splits = text.split(self.think_end_token, maxsplit=1)
|
||||
reasoning_text = splits[0]
|
||||
text = splits[1].strip()
|
||||
|
||||
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
|
||||
|
||||
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing for reasoning content.
|
||||
Handles partial reasoning tags and content.
|
||||
|
||||
If stream_reasoning is False:
|
||||
Accumulates reasoning content until the end tag is found
|
||||
If stream_reasoning is True:
|
||||
Streams reasoning content as it arrives
|
||||
"""
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
|
||||
# Strip `<think>` token if present
|
||||
if not self.stripped_think_start and self.think_start_token in current_text:
|
||||
current_text = current_text.replace(self.think_start_token, "")
|
||||
self.stripped_think_start = True
|
||||
|
||||
# Handle end of reasoning block
|
||||
if self._in_reasoning and self.think_end_token in current_text:
|
||||
end_idx = current_text.find(self.think_end_token)
|
||||
|
||||
reasoning_text = current_text[:end_idx]
|
||||
|
||||
self._buffer = ""
|
||||
self._in_reasoning = False
|
||||
normal_text = current_text[end_idx + len(self.think_end_token) :]
|
||||
|
||||
return StreamingParseResult(
|
||||
normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
|
||||
)
|
||||
|
||||
# Continue with reasoning content
|
||||
if self._in_reasoning:
|
||||
if self.stream_reasoning:
|
||||
# Stream the content immediately
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(reasoning_text=current_text)
|
||||
else:
|
||||
return StreamingParseResult()
|
||||
|
||||
# If we're not in a reasoning block return as normal text
|
||||
if not self._in_reasoning:
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
return StreamingParseResult()
|
||||
|
||||
|
||||
class DeepSeekR1Detector(BaseReasoningFormatDetector):
|
||||
"""
|
||||
Detector for DeepSeek-R1 model.
|
||||
Assumes reasoning format:
|
||||
(<think>)*(.*)</think>
|
||||
Returns all the text before the </think> tag as `reasoning_text`
|
||||
and the rest of the text as `normal_text`.
|
||||
|
||||
Args:
|
||||
stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
|
||||
If True, streams reasoning content as it arrives.
|
||||
"""
|
||||
|
||||
def __init__(self, stream_reasoning: bool = True):
|
||||
# DeepSeek-R1 is assumed to be reasoning until `</think>` token
|
||||
super().__init__(
|
||||
"<think>",
|
||||
"</think>",
|
||||
force_reasoning=True,
|
||||
stream_reasoning=stream_reasoning,
|
||||
)
|
||||
# https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599
|
||||
|
||||
|
||||
class ReasoningParser:
|
||||
"""
|
||||
Parser that handles both streaming and non-streaming scenarios for extracting
|
||||
reasoning content from model outputs.
|
||||
|
||||
Args:
|
||||
model_type (str): Type of model to parse reasoning from
|
||||
stream_reasoning (bool): If Flase, accumulates reasoning content until complete.
|
||||
If True, streams reasoning content as it arrives.
|
||||
"""
|
||||
|
||||
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
|
||||
"deepseek-r1": DeepSeekR1Detector
|
||||
}
|
||||
|
||||
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
|
||||
if not model_type:
|
||||
raise ValueError("Model type must be specified")
|
||||
|
||||
detector_class = self.DetectorMap.get(model_type.lower())
|
||||
if not detector_class:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
self.detector = detector_class(stream_reasoning=stream_reasoning)
|
||||
|
||||
def parse_non_stream(self, full_text: str) -> Tuple[str, str]:
|
||||
"""Non-streaming call: one-time parsing"""
|
||||
ret = self.detector.detect_and_parse(full_text)
|
||||
return ret.reasoning_text, ret.normal_text
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:
|
||||
"""Streaming call: incremental parsing"""
|
||||
ret = self.detector.parse_streaming_increment(chunk_text)
|
||||
return ret.reasoning_text, ret.normal_text
|
||||
@@ -23,6 +23,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.utils import (
|
||||
get_amdgpu_memory_capacity,
|
||||
get_hpu_memory_capacity,
|
||||
@@ -97,6 +98,7 @@ class ServerArgs:
|
||||
api_key: Optional[str] = None
|
||||
file_storage_path: str = "sglang_storage"
|
||||
enable_cache_report: bool = False
|
||||
reasoning_parser: Optional[str] = None
|
||||
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
@@ -631,6 +633,13 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reasoning-parser",
|
||||
type=str,
|
||||
choices=list(ReasoningParser.DetectorMap.keys()),
|
||||
default=ServerArgs.reasoning_parser,
|
||||
help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
|
||||
)
|
||||
|
||||
# Data parallelism
|
||||
parser.add_argument(
|
||||
|
||||
@@ -35,6 +35,7 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
|
||||
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
|
||||
|
||||
Reference in New Issue
Block a user