diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 63bcbc9eb..b751174fc 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -48,6 +48,7 @@ metrics = "0.24.2" metrics-exporter-prometheus = "0.17.0" uuid = { version = "1.10", features = ["v4", "serde"] } thiserror = "2.0.12" +regex = "1.10" url = "2.5.4" tokio-stream = { version = "0.1", features = ["sync"] } anyhow = "1.0" diff --git a/sgl-router/benches/tokenizer_benchmark.rs b/sgl-router/benches/tokenizer_benchmark.rs index c9f82f607..a40abcc4e 100644 --- a/sgl-router/benches/tokenizer_benchmark.rs +++ b/sgl-router/benches/tokenizer_benchmark.rs @@ -100,7 +100,8 @@ fn bench_encode_throughput(c: &mut Criterion) { let tokenizer_clone = tokenizer.clone(); // Get token count once - let token_count = tokenizer.encode(prompt).unwrap().token_ids().len(); + let encoding = tokenizer.encode(prompt).unwrap(); + let token_count = encoding.token_ids().len(); // Track if metrics have been printed for this test case let printed = Arc::new(AtomicBool::new(false)); @@ -157,7 +158,8 @@ fn bench_batch_encode(c: &mut Criterion) { let batch_sizes = vec![1, 8, 16, 32, 64, 128]; let prompt = MEDIUM_PROMPT; let prompt_len = prompt.len(); - let token_count = tokenizer.encode(prompt).unwrap().token_ids().len(); + let encoding = tokenizer.encode(prompt).unwrap(); + let token_count = encoding.token_ids().len(); let mut group = c.benchmark_group("batch_encode"); @@ -303,7 +305,8 @@ fn bench_decode_performance(c: &mut Criterion) { ); let test_text = "The quick brown fox jumps over the lazy dog. 
".repeat(10); - let tokens = tokenizer.encode(&test_text).unwrap().token_ids(); + let encoding = tokenizer.encode(&test_text).unwrap(); + let tokens = encoding.token_ids(); let num_tokens = tokens.len(); let mut group = c.benchmark_group("decode_performance"); @@ -313,12 +316,11 @@ fn bench_decode_performance(c: &mut Criterion) { group.bench_function("direct_decode", |b| { let printed = printed_direct.clone(); let tokenizer = tokenizer.clone(); - let tokens = tokens.clone(); b.iter_custom(|iters| { let start = Instant::now(); for _ in 0..iters { - black_box(tokenizer.decode(&tokens, false).unwrap()); + black_box(tokenizer.decode(tokens, false).unwrap()); } let duration = start.elapsed(); @@ -344,14 +346,13 @@ fn bench_decode_performance(c: &mut Criterion) { group.bench_function("decode_stream", |b| { let printed = printed_stream.clone(); let tokenizer = tokenizer.clone(); - let tokens = tokens.clone(); b.iter_custom(|iters| { let start = Instant::now(); for _ in 0..iters { let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false); let mut output = String::new(); - for token in &tokens { + for token in tokens { if let Some(text) = decoder.step(*token).unwrap() { output.push_str(&text); } @@ -382,14 +383,13 @@ fn bench_decode_performance(c: &mut Criterion) { group.bench_function("sequence_decode", |b| { let printed = printed_seq.clone(); let tokenizer = tokenizer.clone(); - let tokens = tokens.clone(); b.iter_custom(|iters| { let start = Instant::now(); for _ in 0..iters { let mut sequence = Sequence::new(tokenizer.clone()); let mut output = String::new(); - for token in &tokens { + for token in tokens { let text = sequence.append_token(*token).unwrap(); output.push_str(&text); } @@ -424,7 +424,8 @@ fn bench_streaming_decode_100k(c: &mut Criterion) { ); let sample_text = "The quick brown fox jumps over the lazy dog. 
".repeat(1000); - let all_tokens = tokenizer.encode(&sample_text).unwrap().token_ids(); + let encoding = tokenizer.encode(&sample_text).unwrap(); + let all_tokens = encoding.token_ids(); let mut group = c.benchmark_group("streaming_100k"); group.measurement_time(Duration::from_secs(1)); @@ -434,7 +435,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) { group.bench_function("decode_stream_100k", |b| { let printed = printed_stream.clone(); let tokenizer = tokenizer.clone(); - let tokens = all_tokens.clone(); b.iter_custom(|_iters| { let start = Instant::now(); @@ -442,7 +442,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) { let mut output = String::new(); let mut tokens_processed = 0u64; - for token in tokens.iter().cycle() { + for token in all_tokens.iter().cycle() { if start.elapsed() >= Duration::from_millis(500) { break; } @@ -486,7 +486,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) { group.bench_function("sequence_100k", |b| { let printed = printed_seq.clone(); let tokenizer = tokenizer.clone(); - let tokens = all_tokens.clone(); b.iter_custom(|_iters| { let start = Instant::now(); @@ -494,7 +493,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) { let mut output = String::new(); let mut tokens_processed = 0u64; - for token in tokens.iter().cycle() { + for token in all_tokens.iter().cycle() { if start.elapsed() >= Duration::from_millis(500) { break; } @@ -693,7 +692,8 @@ fn bench_concurrent_streaming(c: &mut Criterion) { let tokens_per_sequence = 10_000; let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100); - let token_batch = tokenizer.encode(&sample_text).unwrap().token_ids(); + let encoding = tokenizer.encode(&sample_text).unwrap(); + let token_batch: Vec = encoding.token_ids().to_vec(); let mut group = c.benchmark_group("concurrent_streaming"); group.measurement_time(Duration::from_secs(2)); @@ -775,7 +775,8 @@ fn bench_stop_sequences(c: &mut Criterion) { .with_stop_token(2); let sample_text = "Hello world! 
This is a test. ### Stop here. Continue after.".repeat(100); - let tokens = tokenizer.encode(&sample_text).unwrap().token_ids(); + let encoding = tokenizer.encode(&sample_text).unwrap(); + let tokens = encoding.token_ids(); let mut group = c.benchmark_group("stop_sequences"); @@ -784,7 +785,6 @@ fn bench_stop_sequences(c: &mut Criterion) { group.bench_function("no_stops", |b| { let printed_clone = printed_no_stop.clone(); let tokenizer = tokenizer.clone(); - let tokens = tokens.clone(); b.iter_custom(|iters| { let start = Instant::now(); @@ -796,7 +796,7 @@ fn bench_stop_sequences(c: &mut Criterion) { StopSequenceConfig::default(), false, ); - for token in &tokens { + for token in tokens { let _ = decoder.process_token(*token).unwrap(); total_tokens += 1; } @@ -826,7 +826,6 @@ fn bench_stop_sequences(c: &mut Criterion) { group.bench_function("with_stops", |b| { let printed_clone = printed_with_stops.clone(); let tokenizer = tokenizer.clone(); - let tokens = tokens.clone(); let config = config.clone(); b.iter_custom(|iters| { @@ -839,7 +838,7 @@ fn bench_stop_sequences(c: &mut Criterion) { StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false); let mut sequence_tokens = 0u64; - for token in &tokens { + for token in tokens { let result = decoder.process_token(*token).unwrap(); sequence_tokens += 1; @@ -986,7 +985,8 @@ fn bench_multithreaded_decode(c: &mut Criterion) { // Generate tokens for decoding let test_text = "The quick brown fox jumps over the lazy dog. 
".repeat(100); - let test_tokens = tokenizer.encode(&test_text).unwrap().token_ids(); + let encoding = tokenizer.encode(&test_text).unwrap(); + let test_tokens: Vec = encoding.token_ids().to_vec(); let mut group = c.benchmark_group("multithreaded_decode"); group.measurement_time(Duration::from_secs(2)); @@ -1130,7 +1130,7 @@ fn bench_memory_efficiency(c: &mut Criterion) { b.iter_custom(|iters| { let start = Instant::now(); for _ in 0..iters { - let _ = black_box(encoding.token_ids_ref()); + let _ = black_box(encoding.token_ids()); } let duration = start.elapsed(); diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 4644ea257..40d8ee162 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -14,6 +14,7 @@ pub mod routers; pub mod server; pub mod service_discovery; pub mod tokenizer; +pub mod tool_parser; pub mod tree; use crate::metrics::PrometheusConfig; diff --git a/sgl-router/src/tool_parser/errors.rs b/sgl-router/src/tool_parser/errors.rs new file mode 100644 index 000000000..30129a596 --- /dev/null +++ b/sgl-router/src/tool_parser/errors.rs @@ -0,0 +1,32 @@ +use thiserror::Error; + +/// Result type for tool parser operations +pub type ToolParserResult = Result; + +/// Errors that can occur during tool parsing +#[derive(Debug, Error)] +pub enum ToolParserError { + #[error("Parsing failed: {0}")] + ParsingFailed(String), + + #[error("Model not supported: {0}")] + ModelNotSupported(String), + + #[error("Parse depth exceeded: max {0}")] + DepthExceeded(usize), + + #[error("Invalid JSON: {0}")] + JsonError(#[from] serde_json::Error), + + #[error("Regex error: {0}")] + RegexError(#[from] regex::Error), + + #[error("Incomplete tool call")] + Incomplete, + + #[error("Invalid tool name: {0}")] + InvalidToolName(String), + + #[error("Token not found: {0}")] + TokenNotFound(String), +} diff --git a/sgl-router/src/tool_parser/mod.rs b/sgl-router/src/tool_parser/mod.rs new file mode 100644 index 000000000..9545e4de0 --- /dev/null +++ 
b/sgl-router/src/tool_parser/mod.rs @@ -0,0 +1,20 @@ +/// Tool parser module for handling function/tool calls in model outputs +/// +/// This module provides infrastructure for parsing tool calls from various model formats. +/// Phase 1 focuses on core infrastructure: types, traits, registry, and partial JSON parsing. +pub mod errors; +pub mod partial_json; +pub mod registry; +pub mod state; +pub mod traits; +pub mod types; + +#[cfg(test)] +mod tests; + +// Re-export commonly used types +pub use errors::{ToolParserError, ToolParserResult}; +pub use registry::ParserRegistry; +pub use state::{ParsePhase, ParseState}; +pub use traits::{PartialJsonParser, ToolParser}; +pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall}; diff --git a/sgl-router/src/tool_parser/partial_json.rs b/sgl-router/src/tool_parser/partial_json.rs new file mode 100644 index 000000000..4a4504fe0 --- /dev/null +++ b/sgl-router/src/tool_parser/partial_json.rs @@ -0,0 +1,527 @@ +use crate::tool_parser::{ + errors::{ToolParserError, ToolParserResult}, + traits::PartialJsonParser, +}; +use serde_json::{Map, Value}; + +/// Parser for incomplete JSON +pub struct PartialJson { + /// Maximum depth for nested structures + max_depth: usize, + /// Whether to allow incomplete values + allow_incomplete: bool, +} + +impl PartialJson { + /// Create a new partial JSON parser + pub fn new(max_depth: usize, allow_incomplete: bool) -> Self { + Self { + max_depth, + allow_incomplete, + } + } + + /// Parse potentially incomplete JSON, returning parsed value and consumed bytes + pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> { + let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete); + let value = parser.parse_value(0)?; + Ok((value, parser.position)) + } +} + +impl Default for PartialJson { + fn default() -> Self { + Self::new(32, true) + } +} + +impl PartialJsonParser for PartialJson { + fn parse(&self, input: &str) -> 
ToolParserResult<(Value, usize)> { + self.parse_value(input) + } + + fn is_complete(&self, input: &str) -> bool { + // Try to parse as complete JSON + serde_json::from_str::(input).is_ok() + } + + fn max_depth(&self) -> usize { + self.max_depth + } +} + +/// Internal parser state +struct Parser<'a> { + chars: std::iter::Peekable>, + position: usize, + max_depth: usize, + allow_incomplete: bool, +} + +impl<'a> Parser<'a> { + fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self { + Self { + chars: input.chars().peekable(), + position: 0, + max_depth, + allow_incomplete, + } + } + + fn peek(&mut self) -> Option { + self.chars.peek().copied() + } + + fn advance(&mut self) { + if self.chars.next().is_some() { + self.position += 1; + } + } + + fn skip_whitespace(&mut self) { + while let Some(ch) = self.peek() { + if ch.is_whitespace() { + self.advance(); + } else { + break; + } + } + } + + fn parse_value(&mut self, depth: usize) -> ToolParserResult { + if depth > self.max_depth { + return Err(ToolParserError::DepthExceeded(self.max_depth)); + } + + self.skip_whitespace(); + + match self.peek() { + Some('{') => self.parse_object(depth + 1), + Some('[') => self.parse_array(depth + 1), + Some('"') => self.parse_string(), + Some('t') | Some('f') => self.parse_bool(), + Some('n') => self.parse_null(), + Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(), + _ => { + if self.allow_incomplete { + Ok(Value::Null) + } else { + Err(ToolParserError::ParsingFailed( + "Unexpected character".into(), + )) + } + } + } + } + + fn parse_object(&mut self, depth: usize) -> ToolParserResult { + if depth > self.max_depth { + return Err(ToolParserError::DepthExceeded(self.max_depth)); + } + + let mut object = Map::new(); + + // Consume '{' + self.advance(); + self.skip_whitespace(); + + // Check for empty object + if self.peek() == Some('}') { + self.advance(); + return Ok(Value::Object(object)); + } + + loop { + // Parse key + let key = match 
self.parse_string() { + Ok(Value::String(s)) => s, + Err(_) if self.allow_incomplete => { + // Incomplete object + return Ok(Value::Object(object)); + } + Err(e) => return Err(e), + _ => return Err(ToolParserError::ParsingFailed("Expected string key".into())), + }; + + self.skip_whitespace(); + + // Expect ':' + if self.peek() != Some(':') { + if self.allow_incomplete { + // Add null value for incomplete pair + object.insert(key, Value::Null); + return Ok(Value::Object(object)); + } + return Err(ToolParserError::ParsingFailed("Expected ':'".into())); + } + self.advance(); + self.skip_whitespace(); + + // Parse value (keep same depth - we already incremented in parse_object) + let value = match self.parse_value(depth) { + Ok(v) => v, + Err(_) if self.allow_incomplete => { + // Add null for incomplete value + object.insert(key, Value::Null); + return Ok(Value::Object(object)); + } + Err(e) => return Err(e), + }; + + object.insert(key, value); + self.skip_whitespace(); + + match self.peek() { + Some(',') => { + self.advance(); + self.skip_whitespace(); + // Check for trailing comma + if self.peek() == Some('}') { + self.advance(); + return Ok(Value::Object(object)); + } + } + Some('}') => { + self.advance(); + return Ok(Value::Object(object)); + } + None if self.allow_incomplete => { + return Ok(Value::Object(object)); + } + _ => { + if self.allow_incomplete { + return Ok(Value::Object(object)); + } + return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into())); + } + } + } + } + + fn parse_array(&mut self, depth: usize) -> ToolParserResult { + if depth > self.max_depth { + return Err(ToolParserError::DepthExceeded(self.max_depth)); + } + + let mut array = Vec::new(); + + // Consume '[' + self.advance(); + self.skip_whitespace(); + + // Check for empty array + if self.peek() == Some(']') { + self.advance(); + return Ok(Value::Array(array)); + } + + loop { + // Parse value (keep same depth - we already incremented in parse_array) + let value = match
self.parse_value(depth) { + Ok(v) => v, + Err(_) if self.allow_incomplete => { + return Ok(Value::Array(array)); + } + Err(e) => return Err(e), + }; + + array.push(value); + self.skip_whitespace(); + + match self.peek() { + Some(',') => { + self.advance(); + self.skip_whitespace(); + // Check for trailing comma + if self.peek() == Some(']') { + self.advance(); + return Ok(Value::Array(array)); + } + } + Some(']') => { + self.advance(); + return Ok(Value::Array(array)); + } + None if self.allow_incomplete => { + return Ok(Value::Array(array)); + } + _ => { + if self.allow_incomplete { + return Ok(Value::Array(array)); + } + return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into())); + } + } + } + } + + fn parse_string(&mut self) -> ToolParserResult { + if self.peek() != Some('"') { + return Err(ToolParserError::ParsingFailed("Expected '\"'".into())); + } + + // Consume opening quote + self.advance(); + + let mut string = String::new(); + let mut escaped = false; + + while let Some(ch) = self.peek() { + if escaped { + // Handle escape sequences + let escaped_char = match ch { + '"' | '\\' | '/' => ch, + 'b' => '\u{0008}', + 'f' => '\u{000C}', + 'n' => '\n', + 'r' => '\r', + 't' => '\t', + 'u' => { + // Unicode escape + self.advance(); + let hex = self.parse_unicode_escape()?; + string.push(hex); + escaped = false; + continue; + } + _ => ch, // Invalid escape, but be lenient + }; + string.push(escaped_char); + escaped = false; + } else if ch == '\\' { + escaped = true; + } else if ch == '"' { + // End of string + self.advance(); + return Ok(Value::String(string)); + } else { + string.push(ch); + } + self.advance(); + } + + // Incomplete string + if self.allow_incomplete { + Ok(Value::String(string)) + } else { + Err(ToolParserError::ParsingFailed("Unterminated string".into())) + } + } + + fn parse_unicode_escape(&mut self) -> ToolParserResult { + let mut hex = String::new(); + for _ in 0..4 { + if let Some(ch) = self.peek() { + if ch.is_ascii_hexdigit() 
{ + hex.push(ch); + self.advance(); + } else { + break; + } + } else { + break; + } + } + + if hex.len() == 4 { + u32::from_str_radix(&hex, 16) + .ok() + .and_then(char::from_u32) + .ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into())) + } else if self.allow_incomplete { + Ok('\u{FFFD}') // Replacement character + } else { + Err(ToolParserError::ParsingFailed( + "Incomplete unicode escape".into(), + )) + } + } + + fn parse_number(&mut self) -> ToolParserResult { + let mut number = String::new(); + + // Handle negative sign + if self.peek() == Some('-') { + number.push('-'); + self.advance(); + } + + // Parse integer part + if self.peek() == Some('0') { + number.push('0'); + self.advance(); + } else { + while let Some(ch) = self.peek() { + if ch.is_ascii_digit() { + number.push(ch); + self.advance(); + } else { + break; + } + } + } + + // Parse decimal part + if self.peek() == Some('.') { + number.push('.'); + self.advance(); + + while let Some(ch) = self.peek() { + if ch.is_ascii_digit() { + number.push(ch); + self.advance(); + } else { + break; + } + } + } + + // Parse exponent + if let Some(ch) = self.peek() { + if ch == 'e' || ch == 'E' { + number.push(ch); + self.advance(); + + if let Some(sign) = self.peek() { + if sign == '+' || sign == '-' { + number.push(sign); + self.advance(); + } + } + + while let Some(ch) = self.peek() { + if ch.is_ascii_digit() { + number.push(ch); + self.advance(); + } else { + break; + } + } + } + } + + // Try to parse as integer first, then as float + if let Ok(n) = number.parse::() { + Ok(Value::Number(serde_json::Number::from(n))) + } else if let Ok(n) = number.parse::() { + Ok(Value::Number( + serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)), + )) + } else if self.allow_incomplete { + Ok(Value::Number(serde_json::Number::from(0))) + } else { + Err(ToolParserError::ParsingFailed("Invalid number".into())) + } + } + + fn parse_bool(&mut self) -> ToolParserResult { + let mut 
word = String::new(); + + // Peek at upcoming characters to validate it looks like a boolean + let mut temp_chars = self.chars.clone(); + while let Some(&ch) = temp_chars.peek() { + if ch.is_alphabetic() && word.len() < 5 { + // "false" is 5 chars + word.push(ch); + temp_chars.next(); + } else { + break; + } + } + + // Check if it's a valid boolean prefix + let is_valid = word == "true" + || word == "false" + || (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word))); + + if !is_valid { + return Err(ToolParserError::ParsingFailed("Invalid boolean".into())); + } + + // Now actually consume the characters + word.clear(); + while let Some(ch) = self.peek() { + if ch.is_alphabetic() { + word.push(ch); + self.advance(); + } else { + break; + } + } + + match word.as_str() { + "true" => Ok(Value::Bool(true)), + "false" => Ok(Value::Bool(false)), + partial if self.allow_incomplete => { + if "true".starts_with(partial) { + Ok(Value::Bool(true)) + } else if "false".starts_with(partial) { + Ok(Value::Bool(false)) + } else { + Err(ToolParserError::ParsingFailed("Invalid boolean".into())) + } + } + _ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())), + } + } + + fn parse_null(&mut self) -> ToolParserResult { + let mut word = String::new(); + + // Peek at upcoming characters to validate it looks like "null" + let mut temp_chars = self.chars.clone(); + while let Some(&ch) = temp_chars.peek() { + if ch.is_alphabetic() && word.len() < 4 { + // "null" is 4 chars + word.push(ch); + temp_chars.next(); + } else { + break; + } + } + + // Check if it's a valid null prefix + let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word)); + + if !is_valid { + return Err(ToolParserError::ParsingFailed("Invalid null".into())); + } + + // Now actually consume the characters + word.clear(); + while let Some(ch) = self.peek() { + if ch.is_alphabetic() { + word.push(ch); + self.advance(); + } else { + break; + } + } + + if word 
== "null" || (self.allow_incomplete && "null".starts_with(&word)) { + Ok(Value::Null) + } else { + Err(ToolParserError::ParsingFailed("Invalid null".into())) + } + } +} + +/// Utility function to check if a string contains complete JSON +pub fn is_complete_json(input: &str) -> bool { + serde_json::from_str::(input).is_ok() +} + +/// Utility function to find common prefix between two strings +pub fn find_common_prefix(s1: &str, s2: &str) -> usize { + s1.chars() + .zip(s2.chars()) + .take_while(|(a, b)| a == b) + .count() +} + +/// Utility function to compute diff between old and new strings +pub fn compute_diff(old: &str, new: &str) -> String { + let common_len = find_common_prefix(old, new); + // Convert character count to byte offset + new.chars().skip(common_len).collect() +} diff --git a/sgl-router/src/tool_parser/registry.rs b/sgl-router/src/tool_parser/registry.rs new file mode 100644 index 000000000..aca354e7c --- /dev/null +++ b/sgl-router/src/tool_parser/registry.rs @@ -0,0 +1,119 @@ +use crate::tool_parser::traits::ToolParser; +use std::collections::HashMap; +use std::sync::Arc; + +/// Registry for tool parsers and model mappings +pub struct ParserRegistry { + /// Map of parser name to parser instance + parsers: HashMap>, + /// Map of model name/pattern to parser name + model_mapping: HashMap, + /// Default parser to use when no match found + default_parser: String, +} + +impl ParserRegistry { + /// Create a new parser registry with default mappings + pub fn new() -> Self { + let mut registry = Self { + parsers: HashMap::new(), + model_mapping: HashMap::new(), + default_parser: "json".to_string(), + }; + + // Register default model mappings + registry.register_default_mappings(); + + registry + } + + /// Register a parser + pub fn register_parser(&mut self, name: impl Into, parser: Arc) { + self.parsers.insert(name.into(), parser); + } + + /// Map a model name/pattern to a parser + pub fn map_model(&mut self, model: impl Into, parser: impl Into) { + 
self.model_mapping.insert(model.into(), parser.into()); + } + + /// Get parser for a specific model + pub fn get_parser(&self, model: &str) -> Option> { + // Try exact match first + if let Some(parser_name) = self.model_mapping.get(model) { + if let Some(parser) = self.parsers.get(parser_name) { + return Some(parser.clone()); + } + } + + // Try prefix matching (e.g., "gpt-4" matches "gpt-*") + for (pattern, parser_name) in &self.model_mapping { + if pattern.ends_with('*') { + let prefix = &pattern[..pattern.len() - 1]; + if model.starts_with(prefix) { + if let Some(parser) = self.parsers.get(parser_name) { + return Some(parser.clone()); + } + } + } + } + + // Fall back to default parser if it exists + self.parsers.get(&self.default_parser).cloned() + } + + /// List all registered parsers + pub fn list_parsers(&self) -> Vec<&str> { + self.parsers.keys().map(|s| s.as_str()).collect() + } + + /// List all model mappings + pub fn list_mappings(&self) -> Vec<(&str, &str)> { + self.model_mapping + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect() + } + + /// Register default model mappings + fn register_default_mappings(&mut self) { + // OpenAI models + self.map_model("gpt-4*", "json"); + self.map_model("gpt-3.5*", "json"); + self.map_model("gpt-4o*", "json"); + + // Anthropic models + self.map_model("claude-*", "json"); + + // Mistral models + self.map_model("mistral-*", "mistral"); + self.map_model("mixtral-*", "mistral"); + + // Qwen models + self.map_model("qwen*", "qwen"); + + // Llama models + self.map_model("llama-*", "llama"); + self.map_model("meta-llama-*", "llama"); + + // Other models default to JSON + self.map_model("gemini-*", "json"); + self.map_model("palm-*", "json"); + } + + /// Set the default parser + pub fn set_default_parser(&mut self, name: impl Into) { + self.default_parser = name.into(); + } + + /// Check if a parser is registered + pub fn has_parser(&self, name: &str) -> bool { + self.parsers.contains_key(name) + } +} + +impl 
Default for ParserRegistry { + fn default() -> Self { + Self::new() + } +} diff --git a/sgl-router/src/tool_parser/state.rs b/sgl-router/src/tool_parser/state.rs new file mode 100644 index 000000000..096a9352f --- /dev/null +++ b/sgl-router/src/tool_parser/state.rs @@ -0,0 +1,181 @@ +use crate::tool_parser::types::{PartialToolCall, ToolCall}; + +/// Current phase of parsing +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParsePhase { + /// Looking for start of tool call + Searching, + /// Parsing function name + InName, + /// Parsing function arguments + InArguments, + /// Tool call complete + Complete, +} + +/// State for streaming parser +#[derive(Debug, Clone)] +pub struct ParseState { + /// Buffer for accumulating input + pub buffer: String, + /// Position of last consumed character + pub consumed: usize, + /// Current partial tool being parsed + pub partial_tool: Option, + /// Completed tool calls + pub completed_tools: Vec, + /// Current parsing phase + pub phase: ParsePhase, + /// Bracket/brace depth for JSON parsing + pub bracket_depth: i32, + /// Whether currently inside a string literal + pub in_string: bool, + /// Whether next character should be escaped + pub escape_next: bool, + /// Current tool index (for streaming) + pub tool_index: usize, +} + +impl ParseState { + /// Create a new parse state + pub fn new() -> Self { + Self { + buffer: String::new(), + consumed: 0, + partial_tool: None, + completed_tools: Vec::new(), + phase: ParsePhase::Searching, + bracket_depth: 0, + in_string: false, + escape_next: false, + tool_index: 0, + } + } + + /// Reset state for parsing next tool + pub fn reset(&mut self) { + self.partial_tool = None; + self.phase = ParsePhase::Searching; + self.bracket_depth = 0; + self.in_string = false; + self.escape_next = false; + } + + /// Process a single character for JSON parsing + pub fn process_char(&mut self, ch: char) { + // Handle escape sequences + if self.escape_next { + self.escape_next = false; + 
self.buffer.push(ch); + return; + } + + if ch == '\\' && self.in_string { + self.escape_next = true; + self.buffer.push(ch); + return; + } + + // Track string boundaries + if ch == '"' && !self.escape_next { + self.in_string = !self.in_string; + } + + // Track bracket depth for JSON + if !self.in_string { + match ch { + '{' | '[' => { + self.bracket_depth += 1; + } + '}' | ']' => { + self.bracket_depth -= 1; + if self.bracket_depth == 0 && self.partial_tool.is_some() { + // Complete tool call found + self.phase = ParsePhase::Complete; + } + } + _ => {} + } + } + + self.buffer.push(ch); + } + + /// Check if we have a complete JSON object/array + pub fn has_complete_json(&self) -> bool { + self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty() + } + + /// Extract content from buffer starting at position + pub fn extract_from(&self, start: usize) -> &str { + if start >= self.buffer.len() { + return ""; + } + + // Find the nearest character boundary at or after start + let mut safe_start = start; + while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) { + safe_start += 1; + } + + if safe_start < self.buffer.len() { + &self.buffer[safe_start..] + } else { + "" + } + } + + /// Mark content as consumed up to position + pub fn consume_to(&mut self, position: usize) { + if position > self.consumed { + self.consumed = position; + } + } + + /// Get unconsumed content + pub fn unconsumed(&self) -> &str { + if self.consumed >= self.buffer.len() { + return ""; + } + + // Find the nearest character boundary at or after consumed + let mut safe_consumed = self.consumed; + while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) { + safe_consumed += 1; + } + + if safe_consumed < self.buffer.len() { + &self.buffer[safe_consumed..] 
+ } else { + "" + } + } + + /// Clear consumed content from buffer + pub fn clear_consumed(&mut self) { + if self.consumed > 0 { + // Find the nearest character boundary at or before consumed + let mut safe_consumed = self.consumed; + while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) { + safe_consumed -= 1; + } + + if safe_consumed > 0 { + self.buffer.drain(..safe_consumed); + self.consumed = self.consumed.saturating_sub(safe_consumed); + } + } + } + + /// Add completed tool + pub fn add_completed_tool(&mut self, tool: ToolCall) { + self.completed_tools.push(tool); + self.tool_index += 1; + } +} + +impl Default for ParseState { + fn default() -> Self { + Self::new() + } +} diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs new file mode 100644 index 000000000..e13c614a0 --- /dev/null +++ b/sgl-router/src/tool_parser/tests.rs @@ -0,0 +1,249 @@ +use super::*; +use crate::tool_parser::partial_json::{ + compute_diff, find_common_prefix, is_complete_json, PartialJson, +}; + +#[test] +fn test_parse_state_new() { + let state = ParseState::new(); + assert_eq!(state.phase, ParsePhase::Searching); + assert_eq!(state.buffer, ""); + assert_eq!(state.consumed, 0); + assert_eq!(state.bracket_depth, 0); + assert!(!state.in_string); + assert!(!state.escape_next); +} + +#[test] +fn test_parse_state_process_char() { + let mut state = ParseState::new(); + + // Test bracket tracking + state.process_char('{'); + assert_eq!(state.bracket_depth, 1); + + state.process_char('}'); + assert_eq!(state.bracket_depth, 0); + + // Test string tracking + state.process_char('"'); + assert!(state.in_string); + + state.process_char('"'); + assert!(!state.in_string); + + // Test escape handling + state.process_char('"'); + state.process_char('\\'); + assert!(state.escape_next); + + state.process_char('"'); + assert!(!state.escape_next); + assert!(state.in_string); // Still in string because quote was escaped +} + +#[test] +fn 
test_token_config() { + let config = TokenConfig { + start_tokens: vec!["".to_string(), "[".to_string()], + end_tokens: vec!["".to_string(), "]".to_string()], + separator: ", ".to_string(), + }; + + let pairs: Vec<_> = config.iter_pairs().collect(); + assert_eq!(pairs.len(), 2); + assert_eq!(pairs[0], ("", "")); + assert_eq!(pairs[1], ("[", "]")); +} + +#[test] +fn test_parser_registry() { + let registry = ParserRegistry::new(); + + // Test has default mappings + assert!(!registry.list_mappings().is_empty()); + + // Test model pattern matching + let mappings = registry.list_mappings(); + let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt")); + assert!(has_gpt); +} + +#[test] +fn test_parser_registry_pattern_matching() { + let mut registry = ParserRegistry::new(); + + // Test that model mappings work by checking the list + registry.map_model("test-model", "json"); + + // Verify through list_mappings + let mappings = registry.list_mappings(); + let has_test = mappings + .iter() + .any(|(m, p)| *m == "test-model" && *p == "json"); + assert!(has_test); +} + +#[test] +fn test_tool_call_serialization() { + let tool_call = ToolCall { + id: "call-123".to_string(), + r#type: "function".to_string(), + function: FunctionCall { + name: "search".to_string(), + arguments: r#"{"query": "rust programming"}"#.to_string(), + }, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains("call-123")); + assert!(json.contains("search")); + assert!(json.contains("rust programming")); + + let parsed: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.id, "call-123"); + assert_eq!(parsed.function.name, "search"); +} + +#[test] +fn test_partial_json_parser() { + let parser = PartialJson::default(); + + // Test complete JSON + let input = r#"{"name": "test", "value": 42}"#; + let (value, consumed) = parser.parse_value(input).unwrap(); + assert_eq!(value["name"], "test"); + assert_eq!(value["value"], 42); + assert_eq!(consumed, 
input.len()); + + // Test incomplete JSON object + let input = r#"{"name": "test", "value": "#; + let (value, _consumed) = parser.parse_value(input).unwrap(); + assert_eq!(value["name"], "test"); + assert!(value["value"].is_null()); + + // Test incomplete string + let input = r#"{"name": "tes"#; + let (value, _consumed) = parser.parse_value(input).unwrap(); + assert_eq!(value["name"], "tes"); + + // Test incomplete array + let input = r#"[1, 2, "#; + let (value, _consumed) = parser.parse_value(input).unwrap(); + assert!(value.is_array()); + assert_eq!(value[0], 1); + assert_eq!(value[1], 2); +} + +#[test] +fn test_partial_json_depth_limit() { + // max_depth of 3 allows nesting up to 3 levels + // Set allow_incomplete to false to get errors instead of partial results + let parser = PartialJson::new(3, false); + + // This should work (simple object) + let input = r#"{"a": 1}"#; + let result = parser.parse_value(input); + assert!(result.is_ok()); + + // This should work (nested to depth 3) + let input = r#"{"a": {"b": {"c": 1}}}"#; + let result = parser.parse_value(input); + assert!(result.is_ok()); + + // This should fail (nested to depth 4, exceeds limit) + let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#; + let result = parser.parse_value(input); + assert!(result.is_err()); +} + +#[test] +fn test_is_complete_json() { + assert!(is_complete_json(r#"{"name": "test"}"#)); + assert!(is_complete_json(r#"[1, 2, 3]"#)); + assert!(is_complete_json(r#""string""#)); + assert!(is_complete_json("42")); + assert!(is_complete_json("true")); + assert!(is_complete_json("null")); + + assert!(!is_complete_json(r#"{"name": "#)); + assert!(!is_complete_json(r#"[1, 2, "#)); + assert!(!is_complete_json(r#""unclosed"#)); +} + +#[test] +fn test_find_common_prefix() { + assert_eq!(find_common_prefix("hello", "hello"), 5); + assert_eq!(find_common_prefix("hello", "help"), 3); + assert_eq!(find_common_prefix("hello", "world"), 0); + assert_eq!(find_common_prefix("", "hello"), 0); + 
assert_eq!(find_common_prefix("hello", ""), 0); +} + +#[test] +fn test_compute_diff() { + assert_eq!(compute_diff("hello", "hello world"), " world"); + assert_eq!(compute_diff("", "hello"), "hello"); + assert_eq!(compute_diff("hello", "hello"), ""); + assert_eq!(compute_diff("test", "hello"), "hello"); +} + +#[test] +fn test_stream_result_variants() { + // Test Incomplete + let result = StreamResult::Incomplete; + matches!(result, StreamResult::Incomplete); + + // Test ToolName + let result = StreamResult::ToolName { + index: 0, + name: "test".to_string(), + }; + if let StreamResult::ToolName { index, name } = result { + assert_eq!(index, 0); + assert_eq!(name, "test"); + } else { + panic!("Expected ToolName variant"); + } + + // Test ToolComplete + let tool = ToolCall { + id: "123".to_string(), + r#type: "function".to_string(), + function: FunctionCall { + name: "test".to_string(), + arguments: "{}".to_string(), + }, + }; + let result = StreamResult::ToolComplete(tool.clone()); + if let StreamResult::ToolComplete(t) = result { + assert_eq!(t.id, "123"); + } else { + panic!("Expected ToolComplete variant"); + } +} + +#[test] +fn test_partial_tool_call() { + let mut partial = PartialToolCall { + name: None, + arguments_buffer: String::new(), + start_position: 0, + name_sent: false, + streamed_args: String::new(), + }; + + // Set name + partial.name = Some("test_function".to_string()); + assert_eq!(partial.name.as_ref().unwrap(), "test_function"); + + // Append arguments + partial.arguments_buffer.push_str(r#"{"key": "value"}"#); + assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#); + + // Update streaming state + partial.name_sent = true; + partial.streamed_args = r#"{"key": "#.to_string(); + assert!(partial.name_sent); + assert_eq!(partial.streamed_args, r#"{"key": "#); +} diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs new file mode 100644 index 000000000..19263688d --- /dev/null +++ 
b/sgl-router/src/tool_parser/traits.rs
@@ -0,0 +1,35 @@
+use crate::tool_parser::{
+    errors::ToolParserResult,
+    state::ParseState,
+    types::{StreamResult, ToolCall},
+};
+use async_trait::async_trait;
+
+/// Core trait for all tool parsers
+#[async_trait]
+pub trait ToolParser: Send + Sync {
+    /// Parse complete tool calls from final output
+    async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
+
+    /// Parse tool calls from model output (streaming)
+    async fn parse_incremental(
+        &self,
+        chunk: &str,
+        state: &mut ParseState,
+    ) -> ToolParserResult<StreamResult>;
+
+    /// Check if text contains tool calls in this parser's format
+    fn detect_format(&self, text: &str) -> bool;
+}
+
+/// Trait for partial JSON parsing
+pub trait PartialJsonParser: Send + Sync {
+    /// Parse potentially incomplete JSON
+    fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
+
+    /// Check if JSON is complete
+    fn is_complete(&self, input: &str) -> bool;
+
+    /// Get the maximum parsing depth
+    fn max_depth(&self) -> usize;
+}
diff --git a/sgl-router/src/tool_parser/types.rs b/sgl-router/src/tool_parser/types.rs
new file mode 100644
index 000000000..0638d1c2a
--- /dev/null
+++ b/sgl-router/src/tool_parser/types.rs
@@ -0,0 +1,73 @@
+use serde::{Deserialize, Serialize};
+
+/// Parsed tool call from model output (OpenAI format)
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct ToolCall {
+    /// Unique identifier for the tool call
+    pub id: String,
+    /// Type of tool call (currently always "function")
+    #[serde(rename = "type")]
+    pub r#type: String,
+    /// Function call details
+    pub function: FunctionCall,
+}
+
+/// Function call within a tool call
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct FunctionCall {
+    /// Name of the function to call
+    pub name: String,
+    /// Arguments as JSON string
+    pub arguments: String,
+}
+
+/// Streaming parse result
+#[derive(Debug, Clone)]
+pub enum StreamResult {
+    /// Need more data to continue parsing
+    Incomplete,
+    /// Found a tool name (for streaming)
+    ToolName { index: usize, name: String },
+    /// Found incremental arguments (for streaming)
+    ToolArguments { index: usize, arguments: String },
+    /// Completed parsing a tool
+    ToolComplete(ToolCall),
+    /// Normal text (not part of tool call)
+    NormalText(String),
+}
+
+/// Token configuration for parsing
+#[derive(Debug, Clone)]
+pub struct TokenConfig {
+    /// Start tokens for tool calls
+    pub start_tokens: Vec<String>,
+    /// End tokens for tool calls
+    pub end_tokens: Vec<String>,
+    /// Separator between multiple tool calls
+    pub separator: String,
+}
+
+impl TokenConfig {
+    /// Iterate over start/end token pairs
+    pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
+        self.start_tokens
+            .iter()
+            .zip(self.end_tokens.iter())
+            .map(|(s, e)| (s.as_str(), e.as_str()))
+    }
+}
+
+/// Simple partial tool call for streaming
+#[derive(Debug, Clone)]
+pub struct PartialToolCall {
+    /// Tool name (if parsed)
+    pub name: Option<String>,
+    /// Buffer for accumulating arguments
+    pub arguments_buffer: String,
+    /// Start position in the input buffer
+    pub start_position: usize,
+    /// Whether the name has been sent (for streaming)
+    pub name_sent: bool,
+    /// Arguments already streamed
+    pub streamed_args: String,
+}