From 95575aa76a0716aa018ac2359dc63b934f94ff51 Mon Sep 17 00:00:00 2001 From: Xihuai Wang Date: Tue, 4 Mar 2025 13:16:36 +0800 Subject: [PATCH] Reasoning parser (#4000) Co-authored-by: Lucas Pickup --- docs/backend/separate_reasoning.ipynb | 417 +++++++++++++++++++ docs/index.rst | 1 + docs/references/deepseek.md | 4 + python/sglang/srt/entrypoints/http_server.py | 22 + python/sglang/srt/managers/io_struct.py | 6 + python/sglang/srt/openai_api/adapter.py | 93 ++++- python/sglang/srt/openai_api/protocol.py | 4 + python/sglang/srt/reasoning_parser.py | 154 +++++++ python/sglang/srt/server_args.py | 9 + python/sglang/test/test_utils.py | 1 + test/srt/run_suite.py | 1 + test/srt/test_reasoning_content.py | 342 +++++++++++++++ 12 files changed, 1047 insertions(+), 7 deletions(-) create mode 100644 docs/backend/separate_reasoning.ipynb create mode 100644 python/sglang/srt/reasoning_parser.py create mode 100644 test/srt/test_reasoning_content.py diff --git a/docs/backend/separate_reasoning.ipynb b/docs/backend/separate_reasoning.ipynb new file mode 100644 index 000000000..d9a927c19 --- /dev/null +++ b/docs/backend/separate_reasoning.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reasoning Parser\n", + "\n", + "SGLang supports parsing reasoning content out from \"normal\" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).\n", + "\n", + "## Supported Models\n", + "\n", + "Currently, SGLang supports the following reasoning models:\n", + "- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage\n", "\n", "### Launching the Server" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specify the `--reasoning-parser` option." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import requests\n", "from openai import OpenAI\n", "from sglang.test.test_utils import is_in_ci\n", "\n", "if is_in_ci():\n", " from patch import launch_server_cmd\n", "else:\n", " from sglang.utils import launch_server_cmd\n", "\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", "\n", "\n", "server_process, port = launch_server_cmd(\n", " \"python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1\"\n", ")\n", "\n", "wait_for_server(f\"http://localhost:{port}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### OpenAI Compatible API\n", "\n", "Using the OpenAI compatible API, the contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:\n", "\n", "- `reasoning_content`: The content of the CoT.\n", "- `content`: The content of the final answer."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize OpenAI-like client\n", + "client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n", + "model_name = client.models.list().data[0].id\n", + "\n", + "messages = [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is 1+3?\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Non-Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.6,\n", + " top_p=0.95,\n", + " stream=False, # Non-streaming\n", + " extra_body={\"separate_reasoning\": True},\n", + ")\n", + "print_highlight(\"==== Reasoning ====\")\n", + "print_highlight(response_non_stream.choices[0].message.reasoning_content)\n", + "\n", + "print_highlight(\"==== Text ====\")\n", + "print_highlight(response_non_stream.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.6,\n", + " top_p=0.95,\n", + " stream=True, # Streaming\n", + " extra_body={\"separate_reasoning\": True},\n", + ")\n", + "\n", + "reasoning_content = \"\"\n", + "content = \"\"\n", + "for chunk in response_stream:\n", + " if chunk.choices[0].delta.content:\n", + " content += chunk.choices[0].delta.content\n", + " if chunk.choices[0].delta.reasoning_content:\n", + " reasoning_content += chunk.choices[0].delta.reasoning_content\n", + "\n", + "print_highlight(\"==== Reasoning ====\")\n", + "print_highlight(reasoning_content)\n", + "\n", + "print_highlight(\"==== Text ====\")\n", + "print_highlight(content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optionally, you can buffer the reasoning content and return it all at once in the last reasoning chunk (the first chunk emitted after the reasoning content ends)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.6,\n", + " top_p=0.95,\n", + " stream=True, # Streaming\n", + " extra_body={\"separate_reasoning\": True, \"stream_reasoning\": False},\n", + ")\n", + "\n", + "reasoning_content = \"\"\n", + "content = \"\"\n", + "for chunk in response_stream:\n", + " if chunk.choices[0].delta.content:\n", + " content += chunk.choices[0].delta.content\n", + " if chunk.choices[0].delta.reasoning_content:\n", + " reasoning_content = chunk.choices[0].delta.reasoning_content\n", + "\n", + "print_highlight(\"==== Reasoning ====\")\n", + "print_highlight(reasoning_content)\n", + "\n", + "print_highlight(\"==== Text ====\")\n", + "print_highlight(content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The reasoning separation is enabled by default when the `--reasoning-parser` option is specified.\n", + "**To disable it, set the `separate_reasoning` option to `False` in the request.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.6,\n", + " top_p=0.95,\n", + " stream=False, # Non-streaming\n", + " extra_body={\"separate_reasoning\": False},\n", + ")\n", + "\n", + "print_highlight(\"==== Original Output ====\")\n", + "print_highlight(response_non_stream.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SGLang Native API " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", + "input = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + ")\n", + "\n", + "gen_url = f\"http://localhost:{port}/generate\"\n", + "gen_data = {\n", + " \"text\": input,\n", + " \"sampling_params\": {\n", + " \"skip_special_tokens\": False,\n", + " \"max_new_tokens\": 1024,\n", + " \"temperature\": 0.6,\n", + " \"top_p\": 0.95,\n", + " },\n", + "}\n", + "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", + "\n", + "print_highlight(\"==== Original Output ====\")\n", + "print_highlight(gen_response)\n", + "\n", + "parse_url = f\"http://localhost:{port}/separate_reasoning\"\n", + "separate_reasoning_data = {\n", + " \"text\": gen_response,\n", + " \"reasoning_parser\": \"deepseek-r1\",\n", + "}\n", + "separate_reasoning_response_json = requests.post(\n", + " parse_url, json=separate_reasoning_data\n", + ").json()\n", + "print_highlight(\"==== Reasoning ====\")\n", + "print_highlight(separate_reasoning_response_json[\"reasoning_text\"])\n", + "print_highlight(\"==== Text ====\")\n", + "print_highlight(separate_reasoning_response_json[\"text\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Offline Engine API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sglang as sgl\n", + "from sglang.srt.reasoning_parser import ReasoningParser\n", + "from sglang.utils import print_highlight\n", + "\n", + "llm = sgl.Engine(model_path=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", + "input = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + ")\n", + "sampling_params = {\n", + " \"max_new_tokens\": 1024,\n", + " \"skip_special_tokens\": False,\n", + " \"temperature\": 0.6,\n", + " \"top_p\": 0.95,\n", + "}\n", + "result = llm.generate(prompt=input, sampling_params=sampling_params)\n", + "\n", + "generated_text = result[\"text\"] # Assume there is only one prompt\n", + "\n", + "print_highlight(\"==== Original Output ====\")\n", + "print_highlight(generated_text)\n", + "\n", + "parser = ReasoningParser(\"deepseek-r1\")\n", + "reasoning_text, text = parser.parse_non_stream(generated_text)\n", + "print_highlight(\"==== Reasoning ====\")\n", + "print_highlight(reasoning_text)\n", + "print_highlight(\"==== Text ====\")\n", + "print_highlight(text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm.shutdown()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Supporting New Reasoning Model Schemas\n", + "\n", + "For future reasoning models, you can implement the reasoning parser as a subclass of `BaseReasoningFormatDetector` in `python/sglang/srt/reasoning_parser.py` and specify the reasoning parser for new reasoning model schemas accordingly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```python\n", + "class DeepSeekR1Detector(BaseReasoningFormatDetector):\n", + " \"\"\"\n", + " Detector for DeepSeek-R1 model.\n", + " Assumes reasoning format:\n", + " (<think>)*(.*)</think>\n", + " Returns all the text before the </think> tag as `reasoning_text`\n", + " and the rest of the text as `normal_text`.\n", + "\n", + " Args:\n", + " stream_reasoning (bool): If False, accumulates reasoning content until the end tag.\n", + " If True, streams reasoning content as it arrives.\n", + " \"\"\"\n", + "\n", + " def __init__(self, stream_reasoning: bool = True):\n", + " # DeepSeek-R1 is assumed to be reasoning until `</think>` token\n", + " super().__init__(\"<think>\", \"</think>\", True, stream_reasoning=stream_reasoning)\n", + " # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599\n", + "\n", + "\n", + "class ReasoningParser:\n", + " \"\"\"\n", + " Parser that handles both streaming and non-streaming scenarios for extracting\n", + " reasoning content from model outputs.\n", + "\n", + " Args:\n", + " model_type (str): Type of model to parse reasoning from\n", + " stream_reasoning (bool): If False, accumulates reasoning content until complete.\n", + " If True, streams reasoning content as it arrives.\n", + " \"\"\"\n", + "\n", + " DetectorMap: Dict[str, BaseReasoningFormatDetector] = {\n", + " \"deepseek-r1\": DeepSeekR1Detector\n", + " }\n", + "\n", + " def __init__(self, model_type: str = None, stream_reasoning: bool = True):\n", + " if not model_type:\n", + " raise ValueError(\"Model type must be specified\")\n", + "\n", + " detector_class = self.DetectorMap.get(model_type.lower())\n", + " if not detector_class:\n", + " raise ValueError(f\"Unsupported model type: {model_type}\")\n", + "\n", + " self.detector = detector_class(stream_reasoning=stream_reasoning)\n", + "\n", + " def parse_non_stream(self, full_text: str) -> Tuple[str, str]:\n", + " \"\"\"Non-streaming call: one-time parsing\"\"\"\n", + " ret = self.detector.detect_and_parse(full_text)\n", + " return ret.reasoning_text, ret.normal_text\n", + "\n", + " def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:\n", + " \"\"\"Streaming call: incremental parsing\"\"\"\n", + " ret = self.detector.parse_streaming_increment(chunk_text)\n", + " return ret.reasoning_text, ret.normal_text\n", + "```" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/index.rst b/docs/index.rst index 8553b5e47..385a4de22 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,6 +37,7 @@ The core features include: backend/speculative_decoding.ipynb backend/structured_outputs.ipynb backend/function_calling.ipynb + backend/separate_reasoning.ipynb
backend/custom_chat_template.md backend/quantization.md diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index ad180d1bd..058fd6ae0 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -131,6 +131,10 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o **Usage**: turn on by default for DeepSeek V3 models. +### Reasoning Content for DeepSeek R1 + +See [Separate Reasoning](https://docs.sglang.ai/backend/separate_reasoning.html). + ## FAQ 1. **Question**: What should I do if model loading takes too long and NCCL timeout occurs? diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 175b89c06..3f4fb6177 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -55,6 +55,7 @@ from sglang.srt.managers.io_struct import ( ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, + SeparateReasoningReqInput, SetInternalStateReq, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, @@ -75,6 +76,7 @@ from sglang.srt.openai_api.adapter import ( v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, @@ -460,6 +462,26 @@ async def parse_function_call_request(obj: ParseFunctionCallReq, request: Reques return ORJSONResponse(content=response_data, status_code=200) +@app.post("/separate_reasoning") +async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request): + """ + A native API endpoint to separate reasoning from a text. + """ + # 1) Initialize the parser based on the request body + parser = ReasoningParser(model_type=obj.reasoning_parser) + + # 2) Call the non-stream parsing method (non-stream) + reasoning_text, normal_text = parser.parse_non_stream(obj.text) + + # 3) Organize the response content + response_data = { + "reasoning_text": reasoning_text, + "text": normal_text, + } + + return ORJSONResponse(content=response_data, status_code=200) + + ##### OpenAI-compatible API endpoints ##### diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index ef185269d..f252aa159 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -678,6 +678,12 @@ class ParseFunctionCallReq: ) +@dataclass +class SeparateReasoningReqInput: + text: str # The text to parse. + reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1". 
+ + @dataclass class VertexGenerateReqInput: instances: List[dict] diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index ca6291cc2..50464ba4b 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -72,6 +72,7 @@ from sglang.srt.openai_api.protocol import ( TopLogprob, UsageInfo, ) +from sglang.srt.reasoning_parser import ReasoningParser from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -1038,7 +1039,12 @@ def v1_chat_generate_request( def v1_chat_generate_response( - request, ret, to_file=False, cache_report=False, tool_call_parser=None + request, + ret, + to_file=False, + cache_report=False, + tool_call_parser=None, + reasoning_parser=None, ): choices = [] @@ -1092,9 +1098,26 @@ def v1_chat_generate_response( if isinstance(request, list): tool_choice = request[idx].tool_choice tools = request[idx].tools + separate_reasoning = request[idx].separate_reasoning else: tool_choice = request.tool_choice tools = request.tools + separate_reasoning = request.separate_reasoning + + if reasoning_parser and separate_reasoning: + try: + parser = ReasoningParser( + model_type=reasoning_parser, stream_reasoning=False + ) + reasoning_text, text = parser.parse_non_stream(text) + except Exception as e: + logger.error(f"Exception: {e}") + return create_error_response( + HTTPStatus.BAD_REQUEST, + "Failed to parse reasoning related info to json format!", + ) + else: + reasoning_text = None if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]): if finish_reason == "stop": @@ -1124,8 +1147,9 @@ def v1_chat_generate_response( "index": 0, "message": { "role": "assistant", - "content": ret_item["text"] if tool_calls is None else None, + "content": text if tool_calls is None else None, "tool_calls": tool_calls, + "reasoning_content": reasoning_text, }, "logprobs": choice_logprobs, "finish_reason": (finish_reason["type"] if finish_reason else ""), @@ -1140,8 +1164,9 @@ def v1_chat_generate_response( index=idx, message=ChatMessage( role="assistant", - content=ret_item["text"] if tool_calls is None else None, + content=text if tool_calls is None else None, tool_calls=tool_calls, + reasoning_content=reasoning_text, ), logprobs=choice_logprobs, finish_reason=(finish_reason["type"] if finish_reason else ""), @@ -1208,6 +1233,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): if adapted_request.stream: parser_dict = {} + reasoning_parser_dict = {} async def generate_stream_resp(): is_firsts = {} @@ -1274,15 +1300,27 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): choice_logprobs = None finish_reason = content["meta_info"]["finish_reason"] + finish_reason_type = ( + finish_reason["type"] if finish_reason else None + ) if is_first: # First chunk with role is_first = False + if ( + tokenizer_manager.server_args.reasoning_parser + and request.separate_reasoning + ): + delta = DeltaMessage(role="assistant", reasoning_content="") + else: + delta = DeltaMessage(role="assistant", content="") choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(role="assistant", content=""), + delta=delta, finish_reason=( - finish_reason["type"] if finish_reason else "" + None + if finish_reason_type and len(finish_reason_type) == 0 + else finish_reason_type ), matched_stop=( finish_reason["matched"] @@ -1302,6 +1340,41 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): delta = 
text[len(stream_buffer) :] new_stream_buffer = stream_buffer + delta + if ( + tokenizer_manager.server_args.reasoning_parser + and request.separate_reasoning + ): + if index not in reasoning_parser_dict: + reasoning_parser_dict[index] = ReasoningParser( + tokenizer_manager.server_args.reasoning_parser, + request.stream_reasoning, + ) + reasoning_parser = reasoning_parser_dict[index] + reasoning_text, delta = reasoning_parser.parse_stream_chunk( + delta + ) + if reasoning_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(reasoning_content=reasoning_text), + finish_reason=( + None + if finish_reason_type + and len(finish_reason_type) == 0 + else finish_reason_type + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + if (delta and len(delta) == 0) or not delta: + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + continue + if request.tool_choice != "none" and request.tools: if index not in parser_dict: parser_dict[index] = FunctionCallParser( @@ -1319,7 +1392,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): index=index, delta=DeltaMessage(content=normal_text), finish_reason=( - finish_reason["type"] if finish_reason else "" + None + if finish_reason_type + and len(finish_reason_type) == 0 + else finish_reason_type ), ) chunk = ChatCompletionStreamResponse( @@ -1388,7 +1464,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): index=index, delta=DeltaMessage(content=delta), finish_reason=( - finish_reason["type"] if finish_reason else "" + None + if finish_reason_type and len(finish_reason_type) == 0 + else finish_reason_type ), matched_stop=( finish_reason["matched"] @@ -1456,6 +1534,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): ret, cache_report=tokenizer_manager.server_args.enable_cache_report, tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + reasoning_parser=tokenizer_manager.server_args.reasoning_parser, ) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 5f1ba431a..0c0aa0961 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -336,6 +336,8 @@ class ChatCompletionRequest(BaseModel): skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None session_params: Optional[Dict] = None + separate_reasoning: bool = True + stream_reasoning: bool = True class FunctionResponse(BaseModel): @@ -356,6 +358,7 @@ class ToolCall(BaseModel): class ChatMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) @@ -379,6 +382,7 @@ class ChatCompletionResponse(BaseModel): class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py new file mode 100644 index 000000000..fe369896f --- /dev/null +++ b/python/sglang/srt/reasoning_parser.py @@ -0,0 +1,154 @@ +import re +from typing import Dict, Tuple + + +class StreamingParseResult: + """Result of streaming incremental parsing.""" + + def 
__init__(self, normal_text: str = "", reasoning_text: str = ""): + self.normal_text = normal_text + self.reasoning_text = reasoning_text + + +class BaseReasoningFormatDetector: + """Base class providing two sets of interfaces: one-time and streaming incremental.""" + + def __init__( + self, + think_start_token: str, + think_end_token: str, + force_reasoning: bool = False, + stream_reasoning: bool = True, + ): + self.think_start_token = think_start_token + self.think_end_token = think_end_token + self._in_reasoning = force_reasoning + self.stream_reasoning = stream_reasoning + + self._buffer = "" + self.stripped_think_start = False + + def detect_and_parse(self, text: str) -> StreamingParseResult: + """ + One-time parsing: Detects and parses reasoning sections in the provided text. + Returns both reasoning content and normal text separately. + """ + text = text.replace(self.think_start_token, "").strip() + if self.think_end_token not in text: + # Assume reasoning was truncated before `</think>` token + return StreamingParseResult(reasoning_text=text) + + # Extract reasoning content + splits = text.split(self.think_end_token, maxsplit=1) + reasoning_text = splits[0] + text = splits[1].strip() + + return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text) + + def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: + """ + Streaming incremental parsing for reasoning content. + Handles partial reasoning tags and content. + + If stream_reasoning is False: + Accumulates reasoning content until the end tag is found + If stream_reasoning is True: + Streams reasoning content as it arrives + """ + self._buffer += new_text + current_text = self._buffer + + # Strip `<think>` token if present + if not self.stripped_think_start and self.think_start_token in current_text: + current_text = current_text.replace(self.think_start_token, "") + self.stripped_think_start = True + + # Handle end of reasoning block + if self._in_reasoning and self.think_end_token in current_text: + end_idx = current_text.find(self.think_end_token) + + reasoning_text = current_text[:end_idx] + + self._buffer = "" + self._in_reasoning = False + normal_text = current_text[end_idx + len(self.think_end_token) :] + + return StreamingParseResult( + normal_text=normal_text, reasoning_text=reasoning_text.rstrip() + ) + + # Continue with reasoning content + if self._in_reasoning: + if self.stream_reasoning: + # Stream the content immediately + self._buffer = "" + return StreamingParseResult(reasoning_text=current_text) + else: + return StreamingParseResult() + + # If we're not in a reasoning block, return as normal text + if not self._in_reasoning: + self._buffer = "" + return StreamingParseResult(normal_text=new_text) + + return StreamingParseResult() + + +class DeepSeekR1Detector(BaseReasoningFormatDetector): + """ + Detector for DeepSeek-R1 model. + Assumes reasoning format: + (<think>)*(.*)</think> + Returns all the text before the </think> tag as `reasoning_text` + and the rest of the text as `normal_text`. + + Args: + stream_reasoning (bool): If False, accumulates reasoning content until the end tag. + If True, streams reasoning content as it arrives.
+ """ + + def __init__(self, stream_reasoning: bool = True): + # DeepSeek-R1 is assumed to be reasoning until `</think>` token + super().__init__( + "<think>", + "</think>", + force_reasoning=True, + stream_reasoning=stream_reasoning, + ) + # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599 + + +class ReasoningParser: + """ + Parser that handles both streaming and non-streaming scenarios for extracting + reasoning content from model outputs. + + Args: + model_type (str): Type of model to parse reasoning from + stream_reasoning (bool): If False, accumulates reasoning content until complete. + If True, streams reasoning content as it arrives. + """ + + DetectorMap: Dict[str, BaseReasoningFormatDetector] = { + "deepseek-r1": DeepSeekR1Detector + } + + def __init__(self, model_type: str = None, stream_reasoning: bool = True): + if not model_type: + raise ValueError("Model type must be specified") + + detector_class = self.DetectorMap.get(model_type.lower()) + if not detector_class: + raise ValueError(f"Unsupported model type: {model_type}") + + self.detector = detector_class(stream_reasoning=stream_reasoning) + + def parse_non_stream(self, full_text: str) -> Tuple[str, str]: + """Non-streaming call: one-time parsing""" + ret = self.detector.detect_and_parse(full_text) + return ret.reasoning_text, ret.normal_text + + def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]: + """Streaming call: incremental parsing""" + ret = self.detector.parse_streaming_increment(chunk_text) + return ret.reasoning_text, ret.normal_text diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5833e1266..ac433b1eb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -23,6 +23,7 @@ from typing import List, Optional import torch from sglang.srt.hf_transformers_utils import check_gguf_file +from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_hpu_memory_capacity, @@ -97,6 +98,7 @@ class ServerArgs: api_key: Optional[str] = None file_storage_path: str = "sglang_storage" enable_cache_report: bool = False + reasoning_parser: Optional[str] = None # Data parallelism dp_size: int = 1 @@ -631,6 +633,13 @@ class ServerArgs: action="store_true", help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.", ) + parser.add_argument( + "--reasoning-parser", + type=str, + choices=list(ReasoningParser.DetectorMap.keys()), + default=ServerArgs.reasoning_parser, + help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.", + ) # Data parallelism parser.add_argument( diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index e8c4ce08c..5e42ba425 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -35,6 +35,7 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B" DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" +DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 326b96e33..8fb0a4314 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -59,6 +59,7 @@ suites = { "test_w8a8_quantization.py", "test_fp8_kernel.py", "test_block_int8.py", + "test_reasoning_content.py", ], "nightly": [ "test_nightly_gsm8k_eval.py", diff --git a/test/srt/test_reasoning_content.py b/test/srt/test_reasoning_content.py new file mode 100644 index 000000000..f07dd6339 --- /dev/null +++ b/test/srt/test_reasoning_content.py @@ -0,0 +1,342 @@ +""" +Usage: +python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false +python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true +python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false +python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false +python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true +python3 -m unittest test_reasoning_content.TestReasoningContentWithoutParser.test_nonstreaming_separate_reasoning_true +python3 -m unittest test_reasoning_content.TestReasoningContentWithoutParser.test_streaming_separate_reasoning_true +""" + +import json +import unittest + +import openai +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_REASONING_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestReasoningContentAPI(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-1234" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--reasoning-parser", + "deepseek-r1", + ], + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_streaming_separate_reasoning_false(self): + # Test streaming with separate_reasoning=False, reasoning_content should be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": False}, + } + response = client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + for chunk in response: + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + elif chunk.choices[0].delta.reasoning_content: + reasoning_content += chunk.choices[0].delta.reasoning_content + + assert len(reasoning_content) == 0 + assert len(content) > 0 + + def test_streaming_separate_reasoning_true(self): + # Test streaming with separate_reasoning=True, reasoning_content should not be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": True}, + } + response = client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + for chunk in 
response: + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + elif chunk.choices[0].delta.reasoning_content: + reasoning_content += chunk.choices[0].delta.reasoning_content + + assert len(reasoning_content) > 0 + assert len(content) > 0 + + def test_streaming_separate_reasoning_true_stream_reasoning_false(self): + # Test streaming with separate_reasoning=True, reasoning_content should not be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": True, "stream_reasoning": False}, + } + response = client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + first_chunk = False + for chunk in response: + if chunk.choices[0].delta.reasoning_content: + reasoning_content = chunk.choices[0].delta.reasoning_content + first_chunk = True + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + if not first_chunk: + reasoning_content = chunk.choices[0].delta.reasoning_content + first_chunk = True + if not first_chunk: + assert ( + not chunk.choices[0].delta.reasoning_content + or len(chunk.choices[0].delta.reasoning_content) == 0 + ) + assert len(reasoning_content) > 0 + assert len(content) > 0 + + def test_nonstreaming_separate_reasoning_false(self): + # Test non-streaming with separate_reasoning=False, reasoning_content should be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "extra_body": {"separate_reasoning": False}, + } + response = client.chat.completions.create(**payload) + + assert ( + not response.choices[0].message.reasoning_content + or len(response.choices[0].message.reasoning_content) == 0 + ) + assert len(response.choices[0].message.content) > 0 + + def test_nonstreaming_separate_reasoning_true(self): + # Test non-streaming with separate_reasoning=True, reasoning_content should not be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "extra_body": {"separate_reasoning": True}, + } + response = client.chat.completions.create(**payload) + + assert len(response.choices[0].message.reasoning_content) > 0 + assert len(response.choices[0].message.content) > 0 + + +class TestReasoningContentWithoutParser(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-1234" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[], # No reasoning parser + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_streaming_separate_reasoning_false(self): + # Test streaming with separate_reasoning=False, reasoning_content should be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": False}, + } + response = 
client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + for chunk in response: + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + elif chunk.choices[0].delta.reasoning_content: + reasoning_content += chunk.choices[0].delta.reasoning_content + + assert len(reasoning_content) == 0 + assert len(content) > 0 + + def test_streaming_separate_reasoning_true(self): + # Test streaming with separate_reasoning=True; without a reasoning parser, reasoning_content should still be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": True}, + } + response = client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + for chunk in response: + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + elif chunk.choices[0].delta.reasoning_content: + reasoning_content += chunk.choices[0].delta.reasoning_content + + assert len(reasoning_content) == 0 + assert len(content) > 0 + + def test_streaming_separate_reasoning_true_stream_reasoning_false(self): + # Test streaming with separate_reasoning=True and stream_reasoning=False; without a reasoning parser, reasoning_content should still be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "stream": True, + "extra_body": {"separate_reasoning": True, "stream_reasoning": False}, + } + response = client.chat.completions.create(**payload) + + reasoning_content = "" + content = "" + first_chunk = False + for chunk in response: + if chunk.choices[0].delta.reasoning_content: + reasoning_content = chunk.choices[0].delta.reasoning_content + first_chunk = True + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + if not first_chunk: + reasoning_content = chunk.choices[0].delta.reasoning_content + first_chunk = True + if not first_chunk: + assert ( + not chunk.choices[0].delta.reasoning_content + or len(chunk.choices[0].delta.reasoning_content) == 0 + ) + assert not reasoning_content or len(reasoning_content) == 0 + assert len(content) > 0 + + def test_nonstreaming_separate_reasoning_false(self): + # Test non-streaming with separate_reasoning=False, reasoning_content should be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "extra_body": {"separate_reasoning": False}, + } + response = client.chat.completions.create(**payload) + + assert ( + not response.choices[0].message.reasoning_content + or len(response.choices[0].message.reasoning_content) == 0 + ) + assert len(response.choices[0].message.content) > 0 + + def test_nonstreaming_separate_reasoning_true(self): + # Test non-streaming with separate_reasoning=True; without a reasoning parser, reasoning_content should still be empty + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": "What is 1+3?", + } + ], + "max_tokens": 100, + "extra_body": {"separate_reasoning": True}, + } + response = client.chat.completions.create(**payload) + + assert ( + not response.choices[0].message.reasoning_content + or len(response.choices[0].message.reasoning_content) == 0 + ) + assert 
len(response.choices[0].message.content) > 0 + + +if __name__ == "__main__": + unittest.main()