diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs index ae0e66ca4..dad9c23b5 100644 --- a/sgl-router/src/tool_parser/mod.rs +++ b/sgl-router/src/tool_parser/mod.rs @@ -24,4 +24,6 @@ pub use traits::{PartialJsonParser, ToolParser}; pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall}; // Re-export parsers for convenience -pub use parsers::{JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser}; +pub use parsers::{ + DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, +}; diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs new file mode 100644 index 000000000..5e467bf2b --- /dev/null +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -0,0 +1,277 @@ +use async_trait::async_trait; +use regex::Regex; +use serde_json::Value; + +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + partial_json::PartialJson, + state::ParseState, + traits::ToolParser, + types::{FunctionCall, StreamResult, ToolCall}, +}; + +/// DeepSeek V3 format parser for tool calls +/// +/// Handles the DeepSeek V3 specific format that uses Unicode tokens: +/// `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|><|tool▁calls▁end|>` +/// +/// Features: +/// - Unicode token delimiters +/// - JSON arguments in code blocks +/// - Support for multiple sequential tool calls +pub struct DeepSeekParser { + /// Parser for handling incomplete JSON during streaming + partial_json: PartialJson, + /// Regex for extracting complete tool calls + tool_call_extractor: Regex, + /// Regex for extracting function details + func_detail_extractor: Regex, +} + +impl DeepSeekParser { + /// Create a new DeepSeek parser + pub fn new() -> Self { + // Use (?s) flag for DOTALL mode to handle newlines + let tool_call_pattern = 
r"(?s)<\|tool▁call▁begin\|>.*?<\|tool▁call▁end\|>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        // NOTE: the `|` characters inside the marker tokens must be escaped — an
        // unescaped `|` is regex alternation, so the pattern would match a lone `<`
        // instead of the literal `<|tool▁call▁begin|>` marker.
        // NOTE(review): DeepSeek's real tokenizer may use the fullwidth bar `｜`
        // (U+FF5C) rather than ASCII `|` — confirm against the model tokenizer; the
        // literal markers used by `has_tool_markers` and the tests here are ASCII.
        let func_detail_pattern = r"(?s)<\|tool▁call▁begin\|>(.*?)<\|tool▁sep\|>(.*?)\n```json\n(.*?)\n```<\|tool▁call▁end\|>";
        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            tool_call_extractor,
            func_detail_extractor,
        }
    }

    /// Check if text contains the DeepSeek tool-calls begin marker.
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<|tool▁calls▁begin|>")
    }

    /// Extract all complete `<|tool▁call▁begin|>…<|tool▁call▁end|>` blocks from text.
    fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
        self.tool_call_extractor
            .find_iter(text)
            .map(|m| m.as_str())
            .collect()
    }

    /// Parse a single tool call block.
    ///
    /// Returns `Ok(None)` when the block does not match the expected shape
    /// (wrong call type, or malformed JSON arguments) so callers can skip it.
    fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
        if let Some(captures) = self.func_detail_extractor.captures(block) {
            // Get function type (should be "function")
            let func_type = captures.get(1).map_or("", |m| m.as_str());
            if func_type != "function" {
                return Ok(None);
            }

            // Get function name
            let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();

            // Get JSON arguments
            let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();

            // Parse JSON arguments; malformed JSON skips the call rather than erroring
            match serde_json::from_str::<Value>(json_args) {
                Ok(value) => {
                    // Create arguments object
                    let args = if value.is_object() {
                        value
                    } else {
                        // If not an object, wrap it
                        serde_json::json!({ "value": value })
                    };

                    let arguments = serde_json::to_string(&args)
                        .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

                    // Generate ID
                    let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());

                    Ok(Some(ToolCall {
                        id,
                        r#type: "function".to_string(),
                        function: FunctionCall {
                            name: func_name.to_string(),
                            arguments,
                        },
                    }))
                }
                Err(_) => Ok(None),
            }
        } else {
            Ok(None)
        }
    }
}

impl Default
for DeepSeekParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for DeepSeekParser {
    /// Parse all complete tool calls from a finished model response.
    ///
    /// Returns an empty vector when the DeepSeek markers are absent; blocks
    /// that fail to parse (wrong type, bad JSON) are silently skipped.
    async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
        // Check if text contains DeepSeek format
        if !self.has_tool_markers(text) {
            return Ok(vec![]);
        }

        // Extract all tool call blocks
        let tool_blocks = self.extract_tool_calls(text);
        let mut tools = Vec::new();

        for block in tool_blocks {
            if let Some(tool) = self.parse_tool_call(block)? {
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Incrementally parse a streamed chunk, buffering in `state`.
    ///
    /// Emits `ToolName` once the function name is known, `ToolArguments`
    /// as partial JSON becomes parseable, `ToolComplete` when a full
    /// `<|tool▁call▁end|>` block has arrived, and `Incomplete` otherwise.
    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        state.buffer.push_str(chunk);

        // Check for tool markers
        if !self.has_tool_markers(&state.buffer) {
            // No markers found, return as incomplete
            return Ok(StreamResult::Incomplete);
        }

        // Look for start of tool calls
        if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
            // Look for individual tool call start
            let search_from = start_pos + "<|tool▁calls▁begin|>".len();
            if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
            {
                let call_start_abs = search_from + call_start;

                // Look for the end of this tool call
                let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
                if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
                {
                    let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();

                    // Extract and parse the complete tool call
                    let tool_call_text = &state.buffer[call_start_abs..call_end_abs];

                    if let Some(tool) = self.parse_tool_call(tool_call_text)?
{ + // Remove the processed part from buffer + state.buffer.drain(..call_end_abs); + + return Ok(StreamResult::ToolComplete(tool)); + } + } else { + // Tool call not complete yet, try to extract partial info + let partial = &state.buffer[search_end_from..]; + + // Try to extract function name + if let Some(sep_pos) = partial.find("<|tool▁sep|>") { + if let Some(_func_start) = partial[..sep_pos].rfind("function") { + // We have the function type marker + let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..]; + + // Look for function name (ends at newline before ```json) + if let Some(name_end) = after_sep.find("\n```json\n") { + let func_name = after_sep[..name_end].trim(); + + if !state.in_string { + state.in_string = true; // Mark name as sent + return Ok(StreamResult::ToolName { + index: 0, + name: func_name.to_string(), + }); + } + + // Try to extract partial arguments + let args_start = name_end + "\n```json\n".len(); + let partial_args = &after_sep[args_start..]; + + // Check if we can parse partial JSON + if !partial_args.is_empty() { + match self.partial_json.parse_value(partial_args) { + Ok((value, _consumed)) => { + let args_str = serde_json::to_string(&value) + .unwrap_or_else(|_| "{}".to_string()); + + return Ok(StreamResult::ToolArguments { + index: 0, + arguments: args_str, + }); + } + Err(_) => { + // Can't parse yet, keep buffering + } + } + } + } + } + } + } + } + } + + Ok(StreamResult::Incomplete) + } + + fn detect_format(&self, text: &str) -> bool { + self.has_tool_markers(text) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse_deepseek_single_tool() { + let parser = DeepSeekParser::new(); + let input = r#"Some text +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Tokyo", "units": "celsius"} +```<|tool▁call▁end|><|tool▁calls▁end|>More text"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + 
assert_eq!(result[0].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Tokyo")); + } + + #[tokio::test] + async fn test_parse_deepseek_multiple_tools() { + let parser = DeepSeekParser::new(); + let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Tokyo"} +```<|tool▁call▁end|> +<|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Paris"} +```<|tool▁call▁end|><|tool▁calls▁end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "get_weather"); + assert!(result[0].function.arguments.contains("Tokyo")); + assert!(result[1].function.arguments.contains("Paris")); + } + + #[test] + fn test_detect_format() { + let parser = DeepSeekParser::new(); + assert!(parser.detect_format("<|tool▁calls▁begin|>")); + assert!(!parser.detect_format("plain text")); + assert!(!parser.detect_format("[TOOL_CALLS]")); + } +} diff --git a/sgl-router/src/tool_parser/parsers/mod.rs b/sgl-router/src/tool_parser/parsers/mod.rs index a5c2b0c28..1166b70d1 100644 --- a/sgl-router/src/tool_parser/parsers/mod.rs +++ b/sgl-router/src/tool_parser/parsers/mod.rs @@ -3,12 +3,16 @@ /// This module contains concrete parser implementations for various model-specific /// tool/function call formats. 
// Individual parser modules +pub mod deepseek_parser; pub mod json_parser; pub mod llama_parser; pub mod mistral_parser; pub mod pythonic_parser; pub mod qwen_parser; +// Re-export parser types for convenience +pub use deepseek_parser::DeepSeekParser; + pub use json_parser::JsonParser; pub use llama_parser::LlamaParser; pub use mistral_parser::MistralParser; diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs index 078d1c49d..e29c6c136 100644 --- a/sgl-router/src/tool_parser/registry.rs +++ b/sgl-router/src/tool_parser/registry.rs @@ -1,5 +1,5 @@ use crate::tool_parser::parsers::{ - JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, + DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, }; use crate::tool_parser::traits::ToolParser; use std::collections::HashMap; @@ -110,6 +110,9 @@ impl ParserRegistry { // Llama parser - <|python_tag|>{...} or plain JSON format self.register_parser("llama", Arc::new(LlamaParser::new())); + + // DeepSeek V3 parser - Unicode tokens with JSON blocks + self.register_parser("deepseek", Arc::new(DeepSeekParser::new())); } /// Register default model mappings @@ -141,7 +144,11 @@ impl ParserRegistry { self.map_model("llama-*", "json"); self.map_model("meta-llama-*", "json"); - // DeepSeek models - DeepSeek v3 would need custom parser, v2 uses pythonic + // DeepSeek models + // DeepSeek V3 uses custom Unicode token format + self.map_model("deepseek-v3*", "deepseek"); + self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek"); + // DeepSeek V2 uses pythonic format self.map_model("deepseek-*", "pythonic"); // Other models default to JSON diff --git a/sgl-router/tests/tool_parser_deepseek.rs b/sgl-router/tests/tool_parser_deepseek.rs new file mode 100644 index 000000000..45168c13e --- /dev/null +++ b/sgl-router/tests/tool_parser_deepseek.rs @@ -0,0 +1,183 @@ +//! 
DeepSeek V3 Parser Integration Tests + +use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, ToolParser}; + +#[tokio::test] +async fn test_deepseek_complete_parsing() { + let parser = DeepSeekParser::new(); + + // Test single tool call + let input = r#"Let me help you with that. +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Tokyo", "units": "celsius"} +```<|tool▁call▁end|><|tool▁calls▁end|> +The weather in Tokyo is..."#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + // Verify arguments + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["location"], "Tokyo"); + assert_eq!(args["units"], "celsius"); +} + +#[tokio::test] +async fn test_deepseek_multiple_tools() { + let parser = DeepSeekParser::new(); + + let input = r#"<|tool▁calls▁begin|> +<|tool▁call▁begin|>function<|tool▁sep|>search +```json +{"query": "rust programming"} +```<|tool▁call▁end|> +<|tool▁call▁begin|>function<|tool▁sep|>translate +```json +{"text": "Hello World", "to": "ja"} +```<|tool▁call▁end|> +<|tool▁calls▁end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "search"); + assert_eq!(result[1].function.name, "translate"); +} + +#[tokio::test] +async fn test_deepseek_streaming() { + let parser = DeepSeekParser::new(); + let mut state = ParseState::new(); + + // Simulate streaming chunks + let chunks = vec![ + "<|tool▁calls▁begin|><|tool▁call▁begin|>", + "function<|tool▁sep|>get_weather\n", + "```json\n", + r#"{"location": "#, + r#""Beijing", "#, + r#""units": "metric"}"#, + "\n```<|tool▁call▁end|><|tool▁calls▁end|>", + ]; + + let mut found_name = false; + let mut found_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut 
state).await.unwrap(); + + match result { + StreamResult::ToolName { name, .. } => { + assert_eq!(name, "get_weather"); + found_name = true; + } + StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "get_weather"); + found_complete = true; + } + _ => {} + } + } + + assert!(found_name || found_complete); +} + +#[tokio::test] +async fn test_deepseek_nested_json() { + let parser = DeepSeekParser::new(); + + let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process +```json +{ + "data": { + "nested": { + "deep": [1, 2, 3] + } + } +} +```<|tool▁call▁end|><|tool▁calls▁end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "process"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert!(args["data"]["nested"]["deep"].is_array()); +} + +#[test] +fn test_deepseek_format_detection() { + let parser = DeepSeekParser::new(); + + // Should detect DeepSeek format + assert!(parser.detect_format("<|tool▁calls▁begin|>")); + assert!(parser.detect_format("text with <|tool▁calls▁begin|> marker")); + + // Should not detect other formats + assert!(!parser.detect_format("[TOOL_CALLS]")); + assert!(!parser.detect_format("")); + assert!(!parser.detect_format("plain text")); +} + +#[tokio::test] +async fn test_deepseek_malformed_json_handling() { + let parser = DeepSeekParser::new(); + + // Malformed JSON should be skipped + let input = r#"<|tool▁calls▁begin|> +<|tool▁call▁begin|>function<|tool▁sep|>broken +```json +{invalid json} +```<|tool▁call▁end|> +<|tool▁call▁begin|>function<|tool▁sep|>valid +```json +{"key": "value"} +```<|tool▁call▁end|> +<|tool▁calls▁end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + // Only the valid tool call should be parsed + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "valid"); +} + +#[tokio::test] +async fn test_normal_text_extraction() { 
+ let parser = DeepSeekParser::new(); + + // Python extracts text before tool calls as normal_text + let input = r#"Let me help you with that. +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Tokyo"} +```<|tool▁call▁end|><|tool▁calls▁end|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + // TODO: Verify normal text extraction when parser returns it + // In Python: normal_text = "Let me help you with that." +} + +#[tokio::test] +async fn test_multiple_tool_calls() { + let parser = DeepSeekParser::new(); + + let input = r#"<|tool▁calls▁begin|> +<|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Tokyo"} +```<|tool▁call▁end|> +<|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Paris"} +```<|tool▁call▁end|> +<|tool▁calls▁end|><|end▁of▁sentence|>"#; + + let result = parser.parse_complete(input).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].function.name, "get_weather"); + assert_eq!(result[1].function.name, "get_weather"); +}