Files
sglang/sgl-router/src/reasoning_parser/factory.rs

199 lines
6.6 KiB
Rust

// Factory and registry for creating model-specific reasoning parsers.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::reasoning_parser::parsers::{
BaseReasoningParser, DeepSeekR1Parser, KimiParser, Qwen3Parser, QwenThinkingParser,
};
use crate::reasoning_parser::traits::{ParseError, ParserConfig, ReasoningParser};
/// Type alias for parser creator functions.
///
/// A creator is a reference-counted (`Arc`), `Send + Sync` closure that takes no
/// arguments and returns an owned, boxed [`ReasoningParser`] each time it is called.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
/// Registry for model-specific parsers.
///
/// Holds two tables: parser name -> creator closure, and an ordered list of
/// model-ID patterns that point at parser names. Both tables live behind
/// `Arc<RwLock<..>>`, so the derived `Clone` produces another handle onto the
/// same tables rather than an independent copy.
#[derive(Clone)]
pub struct ParserRegistry {
/// Creator closures keyed by parser name.
parsers: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pattern table in registration order; lookups scan it front to back.
patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
impl ParserRegistry {
/// Create a new empty registry.
pub fn new() -> Self {
Self {
parsers: Arc::new(RwLock::new(HashMap::new())),
patterns: Arc::new(RwLock::new(Vec::new())),
}
}
/// Register a parser creator for a given parser type.
pub fn register_parser<F>(&self, name: &str, creator: F)
where
F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
{
let mut parsers = self.parsers.write().unwrap();
parsers.insert(name.to_string(), Arc::new(creator));
}
/// Register a model pattern to parser mapping.
/// Patterns are checked in order, first match wins.
pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
let mut patterns = self.patterns.write().unwrap();
patterns.push((pattern.to_string(), parser_name.to_string()));
}
/// Get a parser by exact name.
pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
let parsers = self.parsers.read().unwrap();
parsers.get(name).map(|creator| creator())
}
/// Find a parser for a given model ID by pattern matching.
pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase();
for (pattern, parser_name) in patterns.iter() {
if model_lower.contains(&pattern.to_lowercase()) {
return self.get_parser(parser_name);
}
}
None
}
}
impl Default for ParserRegistry {
fn default() -> Self {
Self::new()
}
}
/// Factory for creating reasoning parsers based on model type.
///
/// Wraps a [`ParserRegistry`]: `ParserFactory::new` fills it with the built-in
/// parsers and model patterns, and `create` resolves a model ID against it.
pub struct ParserFactory {
/// Registry consulted by `create` and handed out by `registry()`.
registry: ParserRegistry,
}
impl ParserFactory {
    /// Build a factory whose registry already contains the built-in parsers
    /// and the model-ID patterns that select them.
    pub fn new() -> Self {
        // Model-ID substring -> parser name. The registry takes the first
        // matching pattern, so specific entries must stay ahead of general ones.
        const MODEL_PATTERNS: [(&str, &str); 8] = [
            ("deepseek-r1", "deepseek_r1"),
            ("qwen3-thinking", "qwen3_thinking"),
            ("qwen-thinking", "qwen3_thinking"),
            ("qwen3", "qwen3"),
            ("qwen", "qwen3"),
            ("glm45", "qwen3"), // GLM45 uses same format as Qwen3
            ("kimi", "kimi"),
            ("step3", "deepseek_r1"), // Step3 alias for DeepSeek-R1
        ];

        let registry = ParserRegistry::new();

        // Base parser with the default configuration.
        registry.register_parser("base", || {
            Box::new(BaseReasoningParser::new(ParserConfig::default()))
        });
        // DeepSeek-R1 parser (starts with in_reasoning=true).
        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
        // Qwen3 parser (starts with in_reasoning=false).
        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
        // Qwen3-thinking parser (starts with in_reasoning=true).
        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
        // Kimi parser with Unicode tokens (starts with in_reasoning=false).
        registry.register_parser("kimi", || Box::new(KimiParser::new()));

        for &(pattern, parser_name) in MODEL_PATTERNS.iter() {
            registry.register_pattern(pattern, parser_name);
        }

        Self { registry }
    }

    /// Create a parser for the given model ID.
    ///
    /// The model ID is first resolved through the registry's patterns. When
    /// that yields nothing, a base parser tagged `"passthrough"` and configured
    /// with empty think tokens is returned instead, intended as a no-op parser
    /// for unrecognized models. As written, this function always returns `Ok`.
    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
        match self.registry.find_parser_for_model(model_id) {
            Some(parser) => Ok(parser),
            None => {
                // No pattern produced a parser: fall back to the passthrough setup.
                let passthrough = ParserConfig {
                    think_start_token: String::new(),
                    think_end_token: String::new(),
                    stream_reasoning: true,
                    max_buffer_size: 65536,
                    initial_in_reasoning: false,
                };
                let parser = BaseReasoningParser::new(passthrough)
                    .with_model_type("passthrough".to_string());
                Ok(Box::new(parser))
            }
        }
    }

    /// Borrow the internal registry so callers can register additional
    /// parsers or patterns.
    pub fn registry(&self) -> &ParserRegistry {
        &self.registry
    }
}
impl Default for ParserFactory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolve every `(model_id, expected_model_type)` pair through one
    /// default factory and check the model type the created parser reports.
    fn assert_model_types(cases: &[(&str, &str)]) {
        let factory = ParserFactory::new();
        for &(model_id, expected) in cases {
            let parser = factory.create(model_id).unwrap();
            assert_eq!(parser.model_type(), expected);
        }
    }

    #[test]
    fn test_factory_creates_deepseek_r1() {
        assert_model_types(&[("deepseek-r1-distill", "deepseek_r1")]);
    }

    #[test]
    fn test_factory_creates_qwen3() {
        assert_model_types(&[("qwen3-7b", "qwen3")]);
    }

    #[test]
    fn test_factory_creates_kimi() {
        assert_model_types(&[("kimi-chat", "kimi")]);
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        assert_model_types(&[("unknown-model", "passthrough")]);
    }

    #[test]
    fn test_case_insensitive_matching() {
        assert_model_types(&[
            ("DeepSeek-R1", "deepseek_r1"),
            ("QWEN3", "qwen3"),
            ("Kimi", "kimi"),
        ]);
    }

    #[test]
    fn test_alias_models() {
        assert_model_types(&[("step3-model", "deepseek_r1"), ("glm45-v2", "qwen3")]);
    }
}