From 6e316588f87f5a428b0fc46adb505b28a189a96d Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 18 Aug 2025 09:26:09 -0700 Subject: [PATCH] [router] add reasoning parser base structure (#9310) Co-authored-by: Chang Su --- .pre-commit-config.yaml | 2 +- sgl-router/src/lib.rs | 1 + sgl-router/src/reasoning_parser/factory.rs | 232 +++++++++++ sgl-router/src/reasoning_parser/mod.rs | 7 + .../src/reasoning_parser/parsers/base.rs | 382 ++++++++++++++++++ .../src/reasoning_parser/parsers/mod.rs | 3 + sgl-router/src/reasoning_parser/traits.rs | 130 ++++++ 7 files changed, 756 insertions(+), 1 deletion(-) create mode 100644 sgl-router/src/reasoning_parser/factory.rs create mode 100644 sgl-router/src/reasoning_parser/mod.rs create mode 100644 sgl-router/src/reasoning_parser/parsers/base.rs create mode 100644 sgl-router/src/reasoning_parser/parsers/mod.rs create mode 100644 sgl-router/src/reasoning_parser/traits.rs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 346d8adf0..8f7455904 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: hooks: - id: codespell additional_dependencies: ['tomli'] - args: ['--toml', 'python/pyproject.toml', '-L', 'cann'] + args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi'] exclude: | (?x)^( test/srt/test_reasoning_parser\.py| diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 299dfdcfa..00c8e910d 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -7,6 +7,7 @@ pub mod metrics; pub mod middleware; pub mod openai_api_types; pub mod policies; +pub mod reasoning_parser; pub mod routers; pub mod server; pub mod service_discovery; diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs new file mode 100644 index 000000000..1ac2232b6 --- /dev/null +++ b/sgl-router/src/reasoning_parser/factory.rs @@ -0,0 +1,232 @@ +// Factory and registry for creating model-specific reasoning parsers. 
+
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
+
+use crate::reasoning_parser::parsers::BaseReasoningParser;
+use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
+
+/// Type alias for parser creator functions.
+type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
+
+/// Registry for model-specific parsers.
+#[derive(Clone)]
+pub struct ParserRegistry {
+    parsers: Arc<RwLock<HashMap<String, ParserCreator>>>,
+    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
+}
+
+impl ParserRegistry {
+    /// Create a new empty registry.
+    pub fn new() -> Self {
+        Self {
+            parsers: Arc::new(RwLock::new(HashMap::new())),
+            patterns: Arc::new(RwLock::new(Vec::new())),
+        }
+    }
+
+    /// Register a parser creator for a given parser type.
+    pub fn register_parser<F>(&self, name: &str, creator: F)
+    where
+        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
+    {
+        let mut parsers = self.parsers.write().unwrap();
+        parsers.insert(name.to_string(), Arc::new(creator));
+    }
+
+    /// Register a model pattern to parser mapping.
+    /// Patterns are checked in order, first match wins.
+    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
+        let mut patterns = self.patterns.write().unwrap();
+        patterns.push((pattern.to_string(), parser_name.to_string()));
+    }
+
+    /// Get a parser by exact name.
+    pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
+        let parsers = self.parsers.read().unwrap();
+        parsers.get(name).map(|creator| creator())
+    }
+
+    /// Find a parser for a given model ID by pattern matching.
+    pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
+        let patterns = self.patterns.read().unwrap();
+        let model_lower = model_id.to_lowercase();
+
+        for (pattern, parser_name) in patterns.iter() {
+            if model_lower.contains(&pattern.to_lowercase()) {
+                return self.get_parser(parser_name);
+            }
+        }
+        None
+    }
+}
+
+impl Default for ParserRegistry {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Factory for creating reasoning parsers based on model type.
+pub struct ParserFactory {
+    registry: ParserRegistry,
+}
+
+impl ParserFactory {
+    /// Create a new factory with default parsers registered.
+    pub fn new() -> Self {
+        let registry = ParserRegistry::new();
+
+        // Register base parser
+        registry.register_parser("base", || {
+            Box::new(BaseReasoningParser::new(ParserConfig::default()))
+        });
+
+        // Register DeepSeek-R1 parser
+        registry.register_parser("deepseek_r1", || {
+            let config = ParserConfig {
+                think_start_token: "\u{3c}think>".to_string(), // \u{3c} is '<'
+                think_end_token: "\u{3c}/think>".to_string(),
+                force_reasoning: true,
+                stream_reasoning: true,
+                max_buffer_size: 65536,
+            };
+            Box::new(BaseReasoningParser::new(config).with_model_type("deepseek_r1".to_string()))
+        });
+
+        // Register Qwen3 parser
+        registry.register_parser("qwen3", || {
+            let config = ParserConfig {
+                think_start_token: "\u{3c}think>".to_string(),
+                think_end_token: "\u{3c}/think>".to_string(),
+                force_reasoning: false,
+                stream_reasoning: true,
+                max_buffer_size: 65536,
+            };
+            Box::new(BaseReasoningParser::new(config).with_model_type("qwen3".to_string()))
+        });
+
+        // Register Qwen3-thinking parser (forced reasoning)
+        registry.register_parser("qwen3_thinking", || {
+            let config = ParserConfig {
+                think_start_token: "\u{3c}think>".to_string(),
+                think_end_token: "\u{3c}/think>".to_string(),
+                force_reasoning: true,
+                stream_reasoning: true,
+                max_buffer_size: 65536,
+            };
+            Box::new(BaseReasoningParser::new(config).with_model_type("qwen3_thinking".to_string()))
+        });
+
+        // Register Kimi parser with Unicode tokens
+        registry.register_parser("kimi", || {
+            let config = ParserConfig {
+                think_start_token: "◁think▷".to_string(),
+                think_end_token: "◁/think▷".to_string(),
+                force_reasoning: false,
+                stream_reasoning: true,
+                max_buffer_size: 65536,
+            };
+            Box::new(BaseReasoningParser::new(config).with_model_type("kimi".to_string()))
+        });
+
+        // Register model patterns
+        registry.register_pattern("deepseek-r1", "deepseek_r1");
+        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
+        registry.register_pattern("qwen-thinking", "qwen3_thinking");
+        registry.register_pattern("qwen3", "qwen3");
+        registry.register_pattern("qwen", "qwen3");
+        registry.register_pattern("glm45", "qwen3"); // GLM45 uses same format as Qwen3
+        registry.register_pattern("kimi", "kimi");
+        registry.register_pattern("step3", "deepseek_r1"); // Step3 alias for DeepSeek-R1
+
+        Self { registry }
+    }
+
+    /// Create a parser for the given model ID.
+    /// Returns a no-op parser if model is not recognized.
+    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
+        // First try to find by pattern
+        if let Some(parser) = self.registry.find_parser_for_model(model_id) {
+            return Ok(parser);
+        }
+
+        // Fall back to no-op parser (base parser without reasoning detection)
+        let config = ParserConfig {
+            think_start_token: "".to_string(),
+            think_end_token: "".to_string(),
+            force_reasoning: false,
+            stream_reasoning: true,
+            max_buffer_size: 65536,
+        };
+        Ok(Box::new(
+            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
+        ))
+    }
+
+    /// Get the internal registry for custom registration.
+    pub fn registry(&self) -> &ParserRegistry {
+        &self.registry
+    }
+}
+
+impl Default for ParserFactory {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_factory_creates_deepseek_r1() {
+        let factory = ParserFactory::new();
+        let parser = factory.create("deepseek-r1-distill").unwrap();
+        assert_eq!(parser.model_type(), "deepseek_r1");
+    }
+
+    #[test]
+    fn test_factory_creates_qwen3() {
+        let factory = ParserFactory::new();
+        let parser = factory.create("qwen3-7b").unwrap();
+        assert_eq!(parser.model_type(), "qwen3");
+    }
+
+    #[test]
+    fn test_factory_creates_kimi() {
+        let factory = ParserFactory::new();
+        let parser = factory.create("kimi-chat").unwrap();
+        assert_eq!(parser.model_type(), "kimi");
+    }
+
+    #[test]
+    fn test_factory_fallback_to_passthrough() {
+        let factory = ParserFactory::new();
+        let parser = factory.create("unknown-model").unwrap();
+        assert_eq!(parser.model_type(), "passthrough");
+    }
+
+    #[test]
+    fn test_case_insensitive_matching() {
+        let factory = ParserFactory::new();
+        let parser1 = factory.create("DeepSeek-R1").unwrap();
+        let parser2 = factory.create("QWEN3").unwrap();
+        let parser3 = factory.create("Kimi").unwrap();
+
+        assert_eq!(parser1.model_type(), "deepseek_r1");
+        assert_eq!(parser2.model_type(), "qwen3");
+        assert_eq!(parser3.model_type(), "kimi");
+    }
+
+    #[test]
+    fn test_alias_models() {
+        let factory = ParserFactory::new();
+        let step3 = factory.create("step3-model").unwrap();
+        let glm45 = factory.create("glm45-v2").unwrap();
+
+        assert_eq!(step3.model_type(), "deepseek_r1");
+        assert_eq!(glm45.model_type(), "qwen3");
+    }
+}
diff --git a/sgl-router/src/reasoning_parser/mod.rs b/sgl-router/src/reasoning_parser/mod.rs
new file mode 100644
index 000000000..fd975a7bf
--- /dev/null
+++ b/sgl-router/src/reasoning_parser/mod.rs
@@ -0,0 +1,7 @@
+pub mod factory;
+pub mod parsers;
+pub mod traits;
+
+pub use factory::{ParserFactory, ParserRegistry};
+pub use parsers::BaseReasoningParser;
+pub use traits::{ParseError, ParserResult, ReasoningParser};
diff --git a/sgl-router/src/reasoning_parser/parsers/base.rs b/sgl-router/src/reasoning_parser/parsers/base.rs
new file mode 100644
index 000000000..78743b13d
--- /dev/null
+++ b/sgl-router/src/reasoning_parser/parsers/base.rs
@@ -0,0 +1,382 @@
+// Base implementation of reasoning parser that handles common logic
+// for detecting and extracting reasoning blocks from text.
+
+use crate::reasoning_parser::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
+use tracing as log;
+
+/// Base reasoning parser implementation.
+///
+/// This parser handles the common logic for detecting reasoning blocks
+/// delimited by start and end tokens (e.g., &lt;think&gt; and &lt;/think&gt;).
+#[derive(Debug, Clone)]
+pub struct BaseReasoningParser {
+    config: ParserConfig,
+    in_reasoning: bool,
+    buffer: String,
+    stripped_think_start: bool,
+    model_type: String,
+}
+
+impl BaseReasoningParser {
+    /// Create a new BaseReasoningParser with the given configuration.
+    pub fn new(config: ParserConfig) -> Self {
+        let in_reasoning = config.force_reasoning;
+        Self {
+            config,
+            in_reasoning,
+            buffer: String::new(),
+            stripped_think_start: false,
+            model_type: "base".to_string(),
+        }
+    }
+
+    /// Create with custom model type identifier.
+    pub fn with_model_type(mut self, model_type: String) -> Self {
+        self.model_type = model_type;
+        self
+    }
+
+    /// Check if the current buffer is a prefix of one of the tokens.
+    fn is_partial_token(&self, text: &str) -> bool {
+        (self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
+            || (self.config.think_end_token.starts_with(text)
+                && self.config.think_end_token != text)
+    }
+}
+
+impl ReasoningParser for BaseReasoningParser {
+    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
+        log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
+
+        // Check input size against buffer limit
+        if text.len() > self.config.max_buffer_size {
+            return Err(ParseError::BufferOverflow(text.len()));
+        }
+
+        let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
+        log::debug!("in_reasoning: {}", in_reasoning);
+
+        if !in_reasoning {
+            log::debug!("No reasoning detected, returning normal text.");
+            return Ok(ParserResult::normal(text.to_string()));
+        }
+
+        // The text is considered to be in a reasoning block.
+        let processed_text = text
+            .replace(&self.config.think_start_token, "")
+            .trim()
+            .to_string();
+        log::debug!(
+            "Processed text after removing think_start_token: {:?}",
+            processed_text
+        );
+
+        if !processed_text.contains(&self.config.think_end_token) {
+            log::debug!(
+                "Reasoning truncated, think_end_token not found. Returning reasoning text."
+            );
+            // Assume reasoning was truncated before end token
+            return Ok(ParserResult::reasoning(processed_text));
+        }
+
+        // Extract reasoning content
+        let splits: Vec<&str> = processed_text
+            .splitn(2, &self.config.think_end_token)
+            .collect();
+        let reasoning_text = splits.first().unwrap_or(&"").to_string();
+        let normal_text = splits
+            .get(1)
+            .map(|s| s.trim().to_string())
+            .unwrap_or_default();
+
+        log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
+        log::debug!("Extracted normal_text: {:?}", normal_text);
+
+        Ok(ParserResult::new(normal_text, reasoning_text))
+    }
+
+    fn parse_reasoning_streaming_incremental(
+        &mut self,
+        text: &str,
+    ) -> Result<ParserResult, ParseError> {
+        // Check if adding this text would exceed buffer limit
+        if self.buffer.len() + text.len() > self.config.max_buffer_size {
+            return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
+        }
+
+        // Incrementally parse the streaming text
+        self.buffer.push_str(text);
+        let mut current_text = self.buffer.clone();
+
+        log::debug!(
+            "parse_reasoning_streaming_incremental called with text: {:?}",
+            text
+        );
+        log::debug!("current buffer: {:?}", self.buffer);
+        log::debug!("current_text: {:?}", current_text);
+        log::debug!(
+            "in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
+            self.in_reasoning,
+            self.stripped_think_start,
+            self.config.stream_reasoning
+        );
+
+        // If the current text is a prefix of a token, keep buffering
+        if self.is_partial_token(&current_text) {
+            return Ok(ParserResult::default());
+        }
+
+        // Strip start token if present
+        if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
+            current_text = current_text.replace(&self.config.think_start_token, "");
+            self.buffer = current_text.clone();
+            self.stripped_think_start = true;
+            self.in_reasoning = true;
+        }
+
+        // Handle end of reasoning block
+        let think_end_idx = if self.in_reasoning {
+            current_text
+                .find(&self.config.think_end_token)
+                .unwrap_or(current_text.len())
+        } else {
+            current_text.len()
+        };
+
+        if self.in_reasoning && think_end_idx < current_text.len() {
+            let reasoning_text = &current_text[..think_end_idx];
+            self.buffer.clear();
+            self.in_reasoning = false;
+            let start_idx = think_end_idx + self.config.think_end_token.len();
+            let normal_text = if start_idx < current_text.len() {
+                &current_text[start_idx..]
+            } else {
+                ""
+            };
+            return Ok(ParserResult::new(
+                normal_text.to_string(),
+                reasoning_text.trim().to_string(),
+            ));
+        }
+
+        // Continue with reasoning content
+        if self.in_reasoning && self.config.stream_reasoning {
+            // Stream the content immediately
+            let reasoning_text = current_text;
+            self.buffer.clear();
+            Ok(ParserResult::reasoning(reasoning_text))
+        } else if !self.in_reasoning {
+            // If we're not in a reasoning block, return as normal text
+            // CRITICAL FIX: Return current_text (with buffer) not just text
+            // This prevents buffer loss when partial tokens are followed by normal text
+            let normal_text = current_text;
+            self.buffer.clear();
+            Ok(ParserResult::normal(normal_text))
+        } else {
+            // If we are in a reasoning block but no end token is found, buffer it
+            Ok(ParserResult::default())
+        }
+    }
+
+    fn reset(&mut self) {
+        self.in_reasoning = self.config.force_reasoning;
+        self.buffer.clear();
+        self.stripped_think_start = false;
+    }
+
+    fn model_type(&self) -> &str {
+        &self.model_type
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn create_test_parser(force_reasoning: bool, stream_reasoning: bool) -> BaseReasoningParser {
+        let config = ParserConfig {
+            think_start_token: "\u{3c}think>".to_string(), // \u{3c} is '<'
+            think_end_token: "\u{3c}/think>".to_string(),
+            force_reasoning,
+            stream_reasoning,
+            max_buffer_size: 65536,
+        };
+        BaseReasoningParser::new(config)
+    }
+
+    #[test]
+    fn test_detect_and_parse_reasoning() {
+        let mut parser = create_test_parser(false, true);
+        let result = parser
+            .detect_and_parse_reasoning("\u{3c}think>with reasoning\u{3c}/think> and more text.")
+            .unwrap();
+        assert_eq!(result.normal_text, "and more text.");
+        assert_eq!(result.reasoning_text, "with reasoning");
+    }
+
+    #[test]
+    fn test_detect_and_parse_no_reasoning() {
+        let mut parser = create_test_parser(false, true);
+        let result = parser
+            .detect_and_parse_reasoning("This is a test without reasoning.")
+            .unwrap();
+        assert_eq!(result.normal_text, "This is a test without reasoning.");
+        assert_eq!(result.reasoning_text, "");
+    }
+
+    #[test]
+    fn test_detect_and_parse_truncated_reasoning() {
+        let mut parser = create_test_parser(false, true);
+        let result = parser
+            .detect_and_parse_reasoning("\u{3c}think>with truncated reasoning")
+            .unwrap();
+        assert_eq!(result.normal_text, "");
+        assert_eq!(result.reasoning_text, "with truncated reasoning");
+    }
+
+    #[test]
+    fn test_parse_streaming_partial_token() {
+        let mut parser = create_test_parser(false, true);
+        let result = parser
+            .parse_reasoning_streaming_incremental(
+                "\u{3c}think>with reasoning\u{3c}/think> and more text.",
+            )
+            .unwrap();
+        assert_eq!(result.normal_text, " and more text.");
+        assert_eq!(result.reasoning_text, "with reasoning");
+    }
+
+    #[test]
+    fn test_parse_streaming_no_end_token() {
+        let mut parser = create_test_parser(true, true);
+        let result = parser
+            .parse_reasoning_streaming_incremental("\u{3c}think>with reasoning")
+            .unwrap();
+        assert_eq!(result.normal_text, "");
+        assert_eq!(result.reasoning_text, "with reasoning");
+    }
+
+    #[test]
+    fn test_force_reasoning_mode() {
+        let mut parser = create_test_parser(true, true);
+        let result = parser
+            .detect_and_parse_reasoning("no think tags here")
+            .unwrap();
+        assert_eq!(result.normal_text, "");
+        assert_eq!(result.reasoning_text, "no think tags here");
+    }
+
+    #[test]
+    fn test_buffer_loss_bug_fix() {
+        // Critical test for buffer preservation
+        let mut parser = create_test_parser(false, true);
+
+        // Step 1: Send partial end tag when not in reasoning mode
+        let result1 = parser.parse_reasoning_streaming_incremental("</").unwrap();
+        assert_eq!(result1.normal_text, "");
+        assert_eq!(result1.reasoning_text, "");
+
+        // Step 2: Send normal text; the buffered prefix must be returned with it
+        let result2 = parser
+            .parse_reasoning_streaming_incremental("answer")
+            .unwrap();
+        assert_eq!(result2.normal_text, "</answer");
+        assert_eq!(result2.reasoning_text, "");
+    }
+
+    #[test]
+    fn test_streaming_with_stripped_think_start() {
+        let mut parser = create_test_parser(false, true);
+        let result1 = parser
+            .parse_reasoning_streaming_incremental("\u{3c}think>reasoning ")
+            .unwrap();
+        assert_eq!(result1.normal_text, "");
+        assert_eq!(result1.reasoning_text, "reasoning ");
+
+        // Continue streaming reasoning
+        let result2 = parser
+            .parse_reasoning_streaming_incremental("content ")
+            .unwrap();
+        assert_eq!(result2.normal_text, "");
+        assert_eq!(result2.reasoning_text, "content ");
+
+        // End reasoning block
+        let result3 = parser
+            .parse_reasoning_streaming_incremental("more\u{3c}/think> normal")
+            .unwrap();
+        assert_eq!(result3.normal_text, " normal");
+        assert_eq!(result3.reasoning_text, "more");
+    }
+
+    #[test]
+    fn test_reset_state() {
+        let mut parser = create_test_parser(false, true);
+
+        // Process some text
+        parser
+            .parse_reasoning_streaming_incremental("\u{3c}think>reasoning\u{3c}/think> normal")
+            .unwrap();
+
+        // Reset and verify state
+        parser.reset();
+        assert!(!parser.in_reasoning);
+        assert!(parser.buffer.is_empty());
+        assert!(!parser.stripped_think_start);
+    }
+
+    #[test]
+    fn test_buffer_overflow_detect_and_parse() {
+        let config = ParserConfig {
+            max_buffer_size: 10, // Set a very small buffer
+            ..Default::default()
+        };
+        let mut parser = BaseReasoningParser::new(config);
+
+        let large_text = "a".repeat(20);
+        let result = parser.detect_and_parse_reasoning(&large_text);
+
+        assert!(result.is_err());
+        match result {
+            Err(ParseError::BufferOverflow(size)) => {
+                assert_eq!(size, 20);
+            }
+            _ => panic!("Expected BufferOverflow error"),
+        }
+    }
+
+    #[test]
+    fn test_buffer_overflow_streaming() {
+        let config = ParserConfig {
+            max_buffer_size: 10, // Set a very small buffer
+            ..Default::default()
+        };
+        let mut parser = BaseReasoningParser::new(config);
+
+        // Send a partial token that will be buffered
+        let result1 = parser.parse_reasoning_streaming_incremental("<thi");
+        assert!(result1.is_ok());
+        assert_eq!(result1.unwrap().normal_text, "");
+
+        // Second chunk would exceed buffer
+        let result2 = parser.parse_reasoning_streaming_incremental("this_is_too_large");
+        assert!(result2.is_err());
+        match result2 {
+            Err(ParseError::BufferOverflow(size)) => {
+                assert_eq!(size, 21); // 4 + 17
+            }
+            _ => panic!("Expected BufferOverflow error"),
+        }
+    }
+}
diff --git a/sgl-router/src/reasoning_parser/parsers/mod.rs b/sgl-router/src/reasoning_parser/parsers/mod.rs
new file mode 100644
index 000000000..64a00f864
--- /dev/null
+++ b/sgl-router/src/reasoning_parser/parsers/mod.rs
@@ -0,0 +1,3 @@
+pub mod base;
+
+pub use base::BaseReasoningParser;
diff --git a/sgl-router/src/reasoning_parser/traits.rs b/sgl-router/src/reasoning_parser/traits.rs
new file mode 100644
index 000000000..672b76813
--- /dev/null
+++ b/sgl-router/src/reasoning_parser/traits.rs
@@ -0,0 +1,130 @@
+use std::fmt;
+
+/// Result of parsing text for reasoning content.
+#[derive(Debug, Clone, Default, PartialEq)]
+pub struct ParserResult {
+    /// The normal text outside of reasoning blocks.
+    pub normal_text: String,
+
+    /// The extracted reasoning text from within reasoning blocks.
+    pub reasoning_text: String,
+}
+
+impl ParserResult {
+    /// Create a new ParserResult with the given normal and reasoning text.
+    pub fn new(normal_text: String, reasoning_text: String) -> Self {
+        Self {
+            normal_text,
+            reasoning_text,
+        }
+    }
+
+    /// Create a result with only normal text.
+    pub fn normal(text: String) -> Self {
+        Self {
+            normal_text: text,
+            reasoning_text: String::new(),
+        }
+    }
+
+    /// Create a result with only reasoning text.
+    pub fn reasoning(text: String) -> Self {
+        Self {
+            normal_text: String::new(),
+            reasoning_text: text,
+        }
+    }
+
+    /// Check if this result contains any text.
+    pub fn is_empty(&self) -> bool {
+        self.normal_text.is_empty() && self.reasoning_text.is_empty()
+    }
+}
+
+/// Trait for parsing reasoning content from LLM outputs.
+pub trait ReasoningParser: Send + Sync {
+    /// Detects and parses reasoning from the input text (one-time parsing).
+    ///
+    /// This method is used for non-streaming scenarios where the complete
+    /// text is available at once.
+    ///
+    /// Returns an error if the text exceeds buffer limits or contains invalid UTF-8.
+    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError>;
+
+    /// Parses reasoning incrementally from streaming input.
+    ///
+    /// This method maintains internal state across calls to handle partial
+    /// tokens and chunk boundaries correctly.
+    ///
+    /// Returns an error if the buffer exceeds max_buffer_size.
+    fn parse_reasoning_streaming_incremental(
+        &mut self,
+        text: &str,
+    ) -> Result<ParserResult, ParseError>;
+
+    /// Reset the parser state for reuse.
+    ///
+    /// This should clear any buffers and reset flags to initial state.
+    fn reset(&mut self);
+
+    /// Get the model type this parser is designed for.
+    fn model_type(&self) -> &str;
+}
+
+/// Error types for reasoning parsing operations.
+#[derive(Debug, thiserror::Error)]
+pub enum ParseError {
+    #[error("Invalid UTF-8 in stream: {0}")]
+    Utf8Error(#[from] std::str::Utf8Error),
+
+    #[error("Buffer overflow: {0} bytes exceeds maximum")]
+    BufferOverflow(usize),
+
+    #[error("Unknown model type: {0}")]
+    UnknownModel(String),
+
+    #[error("Parser configuration error: {0}")]
+    ConfigError(String),
+}
+
+/// Configuration for parser behavior.
+#[derive(Debug, Clone)]
+pub struct ParserConfig {
+    /// The token that marks the start of reasoning content.
+    pub think_start_token: String,
+
+    /// The token that marks the end of reasoning content.
+    pub think_end_token: String,
+
+    /// Whether to force all text to be treated as reasoning.
+    pub force_reasoning: bool,
+
+    /// Whether to stream reasoning content as it arrives.
+    pub stream_reasoning: bool,
+
+    /// Maximum buffer size in bytes.
+    pub max_buffer_size: usize,
+}
+
+impl Default for ParserConfig {
+    fn default() -> Self {
+        Self {
+            think_start_token: "\u{3c}think>".to_string(), // \u{3c} is '<'
+            think_end_token: "\u{3c}/think>".to_string(),
+            force_reasoning: false,
+            stream_reasoning: true,
+            max_buffer_size: 65536, // 64KB default
+        }
+    }
+}
+
+impl fmt::Display for ParserResult {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "ParserResult {{ normal: {} chars, reasoning: {} chars }}",
+            self.normal_text.len(),
+            self.reasoning_text.len()
+        )
+    }
+}