Reasoning parser (#4000)
Co-authored-by: Lucas Pickup <lupickup@microsoft.com>
This commit is contained in:
154
python/sglang/srt/reasoning_parser.py
Normal file
154
python/sglang/srt/reasoning_parser.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import re
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
def __init__(self, normal_text: str = "", reasoning_text: str = ""):
|
||||
self.normal_text = normal_text
|
||||
self.reasoning_text = reasoning_text
|
||||
|
||||
|
||||
class BaseReasoningFormatDetector:
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
think_start_token: str,
|
||||
think_end_token: str,
|
||||
force_reasoning: bool = False,
|
||||
stream_reasoning: bool = True,
|
||||
):
|
||||
self.think_start_token = think_start_token
|
||||
self.think_end_token = think_end_token
|
||||
self._in_reasoning = force_reasoning
|
||||
self.stream_reasoning = stream_reasoning
|
||||
|
||||
self._buffer = ""
|
||||
self.stripped_think_start = False
|
||||
|
||||
def detect_and_parse(self, text: str) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses reasoning sections in the provided text.
|
||||
Returns both reasoning content and normal text separately.
|
||||
"""
|
||||
text = text.replace(self.think_start_token, "").strip()
|
||||
if self.think_end_token not in text:
|
||||
# Assume reasoning was truncated before `</think>` token
|
||||
return StreamingParseResult(reasoning_text=text)
|
||||
|
||||
# Extract reasoning content
|
||||
splits = text.split(self.think_end_token, maxsplit=1)
|
||||
reasoning_text = splits[0]
|
||||
text = splits[1].strip()
|
||||
|
||||
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
|
||||
|
||||
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing for reasoning content.
|
||||
Handles partial reasoning tags and content.
|
||||
|
||||
If stream_reasoning is False:
|
||||
Accumulates reasoning content until the end tag is found
|
||||
If stream_reasoning is True:
|
||||
Streams reasoning content as it arrives
|
||||
"""
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
|
||||
# Strip `<think>` token if present
|
||||
if not self.stripped_think_start and self.think_start_token in current_text:
|
||||
current_text = current_text.replace(self.think_start_token, "")
|
||||
self.stripped_think_start = True
|
||||
|
||||
# Handle end of reasoning block
|
||||
if self._in_reasoning and self.think_end_token in current_text:
|
||||
end_idx = current_text.find(self.think_end_token)
|
||||
|
||||
reasoning_text = current_text[:end_idx]
|
||||
|
||||
self._buffer = ""
|
||||
self._in_reasoning = False
|
||||
normal_text = current_text[end_idx + len(self.think_end_token) :]
|
||||
|
||||
return StreamingParseResult(
|
||||
normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
|
||||
)
|
||||
|
||||
# Continue with reasoning content
|
||||
if self._in_reasoning:
|
||||
if self.stream_reasoning:
|
||||
# Stream the content immediately
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(reasoning_text=current_text)
|
||||
else:
|
||||
return StreamingParseResult()
|
||||
|
||||
# If we're not in a reasoning block return as normal text
|
||||
if not self._in_reasoning:
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
return StreamingParseResult()
|
||||
|
||||
|
||||
class DeepSeekR1Detector(BaseReasoningFormatDetector):
|
||||
"""
|
||||
Detector for DeepSeek-R1 model.
|
||||
Assumes reasoning format:
|
||||
(<think>)*(.*)</think>
|
||||
Returns all the text before the </think> tag as `reasoning_text`
|
||||
and the rest of the text as `normal_text`.
|
||||
|
||||
Args:
|
||||
stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
|
||||
If True, streams reasoning content as it arrives.
|
||||
"""
|
||||
|
||||
def __init__(self, stream_reasoning: bool = True):
|
||||
# DeepSeek-R1 is assumed to be reasoning until `</think>` token
|
||||
super().__init__(
|
||||
"<think>",
|
||||
"</think>",
|
||||
force_reasoning=True,
|
||||
stream_reasoning=stream_reasoning,
|
||||
)
|
||||
# https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599
|
||||
|
||||
|
||||
class ReasoningParser:
|
||||
"""
|
||||
Parser that handles both streaming and non-streaming scenarios for extracting
|
||||
reasoning content from model outputs.
|
||||
|
||||
Args:
|
||||
model_type (str): Type of model to parse reasoning from
|
||||
stream_reasoning (bool): If Flase, accumulates reasoning content until complete.
|
||||
If True, streams reasoning content as it arrives.
|
||||
"""
|
||||
|
||||
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
|
||||
"deepseek-r1": DeepSeekR1Detector
|
||||
}
|
||||
|
||||
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
|
||||
if not model_type:
|
||||
raise ValueError("Model type must be specified")
|
||||
|
||||
detector_class = self.DetectorMap.get(model_type.lower())
|
||||
if not detector_class:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
self.detector = detector_class(stream_reasoning=stream_reasoning)
|
||||
|
||||
def parse_non_stream(self, full_text: str) -> Tuple[str, str]:
|
||||
"""Non-streaming call: one-time parsing"""
|
||||
ret = self.detector.detect_and_parse(full_text)
|
||||
return ret.reasoning_text, ret.normal_text
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:
|
||||
"""Streaming call: incremental parsing"""
|
||||
ret = self.detector.parse_streaming_increment(chunk_text)
|
||||
return ret.reasoning_text, ret.normal_text
|
||||
Reference in New Issue
Block a user