[router] add tool parser base structure and partial json parser (#9482)
This commit is contained in:
@@ -48,6 +48,7 @@ metrics = "0.24.2"
|
||||
metrics-exporter-prometheus = "0.17.0"
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
thiserror = "2.0.12"
|
||||
regex = "1.10"
|
||||
url = "2.5.4"
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
|
||||
@@ -100,7 +100,8 @@ fn bench_encode_throughput(c: &mut Criterion) {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
|
||||
// Get token count once
|
||||
let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
|
||||
let encoding = tokenizer.encode(prompt).unwrap();
|
||||
let token_count = encoding.token_ids().len();
|
||||
|
||||
// Track if metrics have been printed for this test case
|
||||
let printed = Arc::new(AtomicBool::new(false));
|
||||
@@ -157,7 +158,8 @@ fn bench_batch_encode(c: &mut Criterion) {
|
||||
let batch_sizes = vec![1, 8, 16, 32, 64, 128];
|
||||
let prompt = MEDIUM_PROMPT;
|
||||
let prompt_len = prompt.len();
|
||||
let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
|
||||
let encoding = tokenizer.encode(prompt).unwrap();
|
||||
let token_count = encoding.token_ids().len();
|
||||
|
||||
let mut group = c.benchmark_group("batch_encode");
|
||||
|
||||
@@ -303,7 +305,8 @@ fn bench_decode_performance(c: &mut Criterion) {
|
||||
);
|
||||
|
||||
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
|
||||
let tokens = tokenizer.encode(&test_text).unwrap().token_ids();
|
||||
let encoding = tokenizer.encode(&test_text).unwrap();
|
||||
let tokens = encoding.token_ids();
|
||||
let num_tokens = tokens.len();
|
||||
|
||||
let mut group = c.benchmark_group("decode_performance");
|
||||
@@ -313,12 +316,11 @@ fn bench_decode_performance(c: &mut Criterion) {
|
||||
group.bench_function("direct_decode", |b| {
|
||||
let printed = printed_direct.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = tokens.clone();
|
||||
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _ in 0..iters {
|
||||
black_box(tokenizer.decode(&tokens, false).unwrap());
|
||||
black_box(tokenizer.decode(tokens, false).unwrap());
|
||||
}
|
||||
let duration = start.elapsed();
|
||||
|
||||
@@ -344,14 +346,13 @@ fn bench_decode_performance(c: &mut Criterion) {
|
||||
group.bench_function("decode_stream", |b| {
|
||||
let printed = printed_stream.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = tokens.clone();
|
||||
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _ in 0..iters {
|
||||
let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
|
||||
let mut output = String::new();
|
||||
for token in &tokens {
|
||||
for token in tokens {
|
||||
if let Some(text) = decoder.step(*token).unwrap() {
|
||||
output.push_str(&text);
|
||||
}
|
||||
@@ -382,14 +383,13 @@ fn bench_decode_performance(c: &mut Criterion) {
|
||||
group.bench_function("sequence_decode", |b| {
|
||||
let printed = printed_seq.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = tokens.clone();
|
||||
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _ in 0..iters {
|
||||
let mut sequence = Sequence::new(tokenizer.clone());
|
||||
let mut output = String::new();
|
||||
for token in &tokens {
|
||||
for token in tokens {
|
||||
let text = sequence.append_token(*token).unwrap();
|
||||
output.push_str(&text);
|
||||
}
|
||||
@@ -424,7 +424,8 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
||||
);
|
||||
|
||||
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(1000);
|
||||
let all_tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
|
||||
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||
let all_tokens = encoding.token_ids();
|
||||
|
||||
let mut group = c.benchmark_group("streaming_100k");
|
||||
group.measurement_time(Duration::from_secs(1));
|
||||
@@ -434,7 +435,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
||||
group.bench_function("decode_stream_100k", |b| {
|
||||
let printed = printed_stream.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = all_tokens.clone();
|
||||
|
||||
b.iter_custom(|_iters| {
|
||||
let start = Instant::now();
|
||||
@@ -442,7 +442,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
||||
let mut output = String::new();
|
||||
let mut tokens_processed = 0u64;
|
||||
|
||||
for token in tokens.iter().cycle() {
|
||||
for token in all_tokens.iter().cycle() {
|
||||
if start.elapsed() >= Duration::from_millis(500) {
|
||||
break;
|
||||
}
|
||||
@@ -486,7 +486,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
||||
group.bench_function("sequence_100k", |b| {
|
||||
let printed = printed_seq.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = all_tokens.clone();
|
||||
|
||||
b.iter_custom(|_iters| {
|
||||
let start = Instant::now();
|
||||
@@ -494,7 +493,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
||||
let mut output = String::new();
|
||||
let mut tokens_processed = 0u64;
|
||||
|
||||
for token in tokens.iter().cycle() {
|
||||
for token in all_tokens.iter().cycle() {
|
||||
if start.elapsed() >= Duration::from_millis(500) {
|
||||
break;
|
||||
}
|
||||
@@ -693,7 +692,8 @@ fn bench_concurrent_streaming(c: &mut Criterion) {
|
||||
let tokens_per_sequence = 10_000;
|
||||
|
||||
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
||||
let token_batch = tokenizer.encode(&sample_text).unwrap().token_ids();
|
||||
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||
let token_batch: Vec<u32> = encoding.token_ids().to_vec();
|
||||
|
||||
let mut group = c.benchmark_group("concurrent_streaming");
|
||||
group.measurement_time(Duration::from_secs(2));
|
||||
@@ -775,7 +775,8 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
||||
.with_stop_token(2);
|
||||
|
||||
let sample_text = "Hello world! This is a test. ### Stop here. Continue after.".repeat(100);
|
||||
let tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
|
||||
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||
let tokens = encoding.token_ids();
|
||||
|
||||
let mut group = c.benchmark_group("stop_sequences");
|
||||
|
||||
@@ -784,7 +785,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
||||
group.bench_function("no_stops", |b| {
|
||||
let printed_clone = printed_no_stop.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = tokens.clone();
|
||||
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
@@ -796,7 +796,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
||||
StopSequenceConfig::default(),
|
||||
false,
|
||||
);
|
||||
for token in &tokens {
|
||||
for token in tokens {
|
||||
let _ = decoder.process_token(*token).unwrap();
|
||||
total_tokens += 1;
|
||||
}
|
||||
@@ -826,7 +826,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
||||
group.bench_function("with_stops", |b| {
|
||||
let printed_clone = printed_with_stops.clone();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let tokens = tokens.clone();
|
||||
let config = config.clone();
|
||||
|
||||
b.iter_custom(|iters| {
|
||||
@@ -839,7 +838,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
||||
StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false);
|
||||
let mut sequence_tokens = 0u64;
|
||||
|
||||
for token in &tokens {
|
||||
for token in tokens {
|
||||
let result = decoder.process_token(*token).unwrap();
|
||||
sequence_tokens += 1;
|
||||
|
||||
@@ -986,7 +985,8 @@ fn bench_multithreaded_decode(c: &mut Criterion) {
|
||||
|
||||
// Generate tokens for decoding
|
||||
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
||||
let test_tokens = tokenizer.encode(&test_text).unwrap().token_ids();
|
||||
let encoding = tokenizer.encode(&test_text).unwrap();
|
||||
let test_tokens: Vec<u32> = encoding.token_ids().to_vec();
|
||||
|
||||
let mut group = c.benchmark_group("multithreaded_decode");
|
||||
group.measurement_time(Duration::from_secs(2));
|
||||
@@ -1130,7 +1130,7 @@ fn bench_memory_efficiency(c: &mut Criterion) {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _ in 0..iters {
|
||||
let _ = black_box(encoding.token_ids_ref());
|
||||
let _ = black_box(encoding.token_ids());
|
||||
}
|
||||
let duration = start.elapsed();
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ pub mod routers;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
pub mod tokenizer;
|
||||
pub mod tool_parser;
|
||||
pub mod tree;
|
||||
use crate::metrics::PrometheusConfig;
|
||||
|
||||
|
||||
32
sgl-router/src/tool_parser/errors.rs
Normal file
32
sgl-router/src/tool_parser/errors.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type for tool parser operations
|
||||
pub type ToolParserResult<T> = Result<T, ToolParserError>;
|
||||
|
||||
/// Errors that can occur during tool parsing
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ToolParserError {
|
||||
#[error("Parsing failed: {0}")]
|
||||
ParsingFailed(String),
|
||||
|
||||
#[error("Model not supported: {0}")]
|
||||
ModelNotSupported(String),
|
||||
|
||||
#[error("Parse depth exceeded: max {0}")]
|
||||
DepthExceeded(usize),
|
||||
|
||||
#[error("Invalid JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Regex error: {0}")]
|
||||
RegexError(#[from] regex::Error),
|
||||
|
||||
#[error("Incomplete tool call")]
|
||||
Incomplete,
|
||||
|
||||
#[error("Invalid tool name: {0}")]
|
||||
InvalidToolName(String),
|
||||
|
||||
#[error("Token not found: {0}")]
|
||||
TokenNotFound(String),
|
||||
}
|
||||
20
sgl-router/src/tool_parser/mod.rs
Normal file
20
sgl-router/src/tool_parser/mod.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
/// Tool parser module for handling function/tool calls in model outputs
|
||||
///
|
||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||
/// Phase 1 focuses on core infrastructure: types, traits, registry, and partial JSON parsing.
|
||||
pub mod errors;
|
||||
pub mod partial_json;
|
||||
pub mod registry;
|
||||
pub mod state;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use errors::{ToolParserError, ToolParserResult};
|
||||
pub use registry::ParserRegistry;
|
||||
pub use state::{ParsePhase, ParseState};
|
||||
pub use traits::{PartialJsonParser, ToolParser};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
|
||||
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
@@ -0,0 +1,527 @@
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
traits::PartialJsonParser,
|
||||
};
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
/// Parser for incomplete JSON
|
||||
pub struct PartialJson {
|
||||
/// Maximum depth for nested structures
|
||||
max_depth: usize,
|
||||
/// Whether to allow incomplete values
|
||||
allow_incomplete: bool,
|
||||
}
|
||||
|
||||
impl PartialJson {
|
||||
/// Create a new partial JSON parser
|
||||
pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
|
||||
Self {
|
||||
max_depth,
|
||||
allow_incomplete,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes
|
||||
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
|
||||
let value = parser.parse_value(0)?;
|
||||
Ok((value, parser.position))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PartialJson {
|
||||
fn default() -> Self {
|
||||
Self::new(32, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialJsonParser for PartialJson {
|
||||
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||
self.parse_value(input)
|
||||
}
|
||||
|
||||
fn is_complete(&self, input: &str) -> bool {
|
||||
// Try to parse as complete JSON
|
||||
serde_json::from_str::<Value>(input).is_ok()
|
||||
}
|
||||
|
||||
fn max_depth(&self) -> usize {
|
||||
self.max_depth
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal parser state
|
||||
struct Parser<'a> {
|
||||
chars: std::iter::Peekable<std::str::Chars<'a>>,
|
||||
position: usize,
|
||||
max_depth: usize,
|
||||
allow_incomplete: bool,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
|
||||
Self {
|
||||
chars: input.chars().peekable(),
|
||||
position: 0,
|
||||
max_depth,
|
||||
allow_incomplete,
|
||||
}
|
||||
}
|
||||
|
||||
fn peek(&mut self) -> Option<char> {
|
||||
self.chars.peek().copied()
|
||||
}
|
||||
|
||||
fn advance(&mut self) {
|
||||
if self.chars.next().is_some() {
|
||||
self.position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn skip_whitespace(&mut self) {
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_whitespace() {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some('{') => self.parse_object(depth + 1),
|
||||
Some('[') => self.parse_array(depth + 1),
|
||||
Some('"') => self.parse_string(),
|
||||
Some('t') | Some('f') => self.parse_bool(),
|
||||
Some('n') => self.parse_null(),
|
||||
Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
Ok(Value::Null)
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(
|
||||
"Unexpected character".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
let mut object = Map::new();
|
||||
|
||||
// Consume '{'
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Check for empty object
|
||||
if self.peek() == Some('}') {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse key
|
||||
let key = match self.parse_string() {
|
||||
Ok(Value::String(s)) => s,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
// Incomplete object
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
|
||||
};
|
||||
|
||||
self.skip_whitespace();
|
||||
|
||||
// Expect ':'
|
||||
if self.peek() != Some(':') {
|
||||
if self.allow_incomplete {
|
||||
// Add null value for incomplete pair
|
||||
object.insert(key, Value::Null);
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
|
||||
}
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Parse value (keep same depth - we already incremented in parse_object)
|
||||
let value = match self.parse_value(depth) {
|
||||
Ok(v) => v,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
// Add null for incomplete value
|
||||
object.insert(key, Value::Null);
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
object.insert(key, value);
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
// Check for trailing comma
|
||||
if self.peek() == Some('}') {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
}
|
||||
Some('}') => {
|
||||
self.advance();
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
None if self.allow_incomplete => {
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
return Ok(Value::Object(object));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||
if depth > self.max_depth {
|
||||
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||
}
|
||||
|
||||
let mut array = Vec::new();
|
||||
|
||||
// Consume '['
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
|
||||
// Check for empty array
|
||||
if self.peek() == Some(']') {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
|
||||
loop {
|
||||
// Parse value (keep same depth - we already incremented in parse_object)
|
||||
let value = match self.parse_value(depth) {
|
||||
Ok(v) => v,
|
||||
Err(_) if self.allow_incomplete => {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
array.push(value);
|
||||
self.skip_whitespace();
|
||||
|
||||
match self.peek() {
|
||||
Some(',') => {
|
||||
self.advance();
|
||||
self.skip_whitespace();
|
||||
// Check for trailing comma
|
||||
if self.peek() == Some(']') {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
}
|
||||
Some(']') => {
|
||||
self.advance();
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
None if self.allow_incomplete => {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
_ => {
|
||||
if self.allow_incomplete {
|
||||
return Ok(Value::Array(array));
|
||||
}
|
||||
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_string(&mut self) -> ToolParserResult<Value> {
|
||||
if self.peek() != Some('"') {
|
||||
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
|
||||
}
|
||||
|
||||
// Consume opening quote
|
||||
self.advance();
|
||||
|
||||
let mut string = String::new();
|
||||
let mut escaped = false;
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if escaped {
|
||||
// Handle escape sequences
|
||||
let escaped_char = match ch {
|
||||
'"' | '\\' | '/' => ch,
|
||||
'b' => '\u{0008}',
|
||||
'f' => '\u{000C}',
|
||||
'n' => '\n',
|
||||
'r' => '\r',
|
||||
't' => '\t',
|
||||
'u' => {
|
||||
// Unicode escape
|
||||
self.advance();
|
||||
let hex = self.parse_unicode_escape()?;
|
||||
string.push(hex);
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
_ => ch, // Invalid escape, but be lenient
|
||||
};
|
||||
string.push(escaped_char);
|
||||
escaped = false;
|
||||
} else if ch == '\\' {
|
||||
escaped = true;
|
||||
} else if ch == '"' {
|
||||
// End of string
|
||||
self.advance();
|
||||
return Ok(Value::String(string));
|
||||
} else {
|
||||
string.push(ch);
|
||||
}
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Incomplete string
|
||||
if self.allow_incomplete {
|
||||
Ok(Value::String(string))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
|
||||
let mut hex = String::new();
|
||||
for _ in 0..4 {
|
||||
if let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_hexdigit() {
|
||||
hex.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if hex.len() == 4 {
|
||||
u32::from_str_radix(&hex, 16)
|
||||
.ok()
|
||||
.and_then(char::from_u32)
|
||||
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
|
||||
} else if self.allow_incomplete {
|
||||
Ok('\u{FFFD}') // Replacement character
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed(
|
||||
"Incomplete unicode escape".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_number(&mut self) -> ToolParserResult<Value> {
|
||||
let mut number = String::new();
|
||||
|
||||
// Handle negative sign
|
||||
if self.peek() == Some('-') {
|
||||
number.push('-');
|
||||
self.advance();
|
||||
}
|
||||
|
||||
// Parse integer part
|
||||
if self.peek() == Some('0') {
|
||||
number.push('0');
|
||||
self.advance();
|
||||
} else {
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse decimal part
|
||||
if self.peek() == Some('.') {
|
||||
number.push('.');
|
||||
self.advance();
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse exponent
|
||||
if let Some(ch) = self.peek() {
|
||||
if ch == 'e' || ch == 'E' {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
|
||||
if let Some(sign) = self.peek() {
|
||||
if sign == '+' || sign == '-' {
|
||||
number.push(sign);
|
||||
self.advance();
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_ascii_digit() {
|
||||
number.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as integer first, then as float
|
||||
if let Ok(n) = number.parse::<i64>() {
|
||||
Ok(Value::Number(serde_json::Number::from(n)))
|
||||
} else if let Ok(n) = number.parse::<f64>() {
|
||||
Ok(Value::Number(
|
||||
serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
|
||||
))
|
||||
} else if self.allow_incomplete {
|
||||
Ok(Value::Number(serde_json::Number::from(0)))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid number".into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bool(&mut self) -> ToolParserResult<Value> {
|
||||
let mut word = String::new();
|
||||
|
||||
// Peek at upcoming characters to validate it looks like a boolean
|
||||
let mut temp_chars = self.chars.clone();
|
||||
while let Some(&ch) = temp_chars.peek() {
|
||||
if ch.is_alphabetic() && word.len() < 5 {
|
||||
// "false" is 5 chars
|
||||
word.push(ch);
|
||||
temp_chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a valid boolean prefix
|
||||
let is_valid = word == "true"
|
||||
|| word == "false"
|
||||
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
|
||||
|
||||
if !is_valid {
|
||||
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
|
||||
}
|
||||
|
||||
// Now actually consume the characters
|
||||
word.clear();
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_alphabetic() {
|
||||
word.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
match word.as_str() {
|
||||
"true" => Ok(Value::Bool(true)),
|
||||
"false" => Ok(Value::Bool(false)),
|
||||
partial if self.allow_incomplete => {
|
||||
if "true".starts_with(partial) {
|
||||
Ok(Value::Bool(true))
|
||||
} else if "false".starts_with(partial) {
|
||||
Ok(Value::Bool(false))
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
|
||||
}
|
||||
}
|
||||
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_null(&mut self) -> ToolParserResult<Value> {
|
||||
let mut word = String::new();
|
||||
|
||||
// Peek at upcoming characters to validate it looks like "null"
|
||||
let mut temp_chars = self.chars.clone();
|
||||
while let Some(&ch) = temp_chars.peek() {
|
||||
if ch.is_alphabetic() && word.len() < 4 {
|
||||
// "null" is 4 chars
|
||||
word.push(ch);
|
||||
temp_chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a valid null prefix
|
||||
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
|
||||
|
||||
if !is_valid {
|
||||
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
|
||||
}
|
||||
|
||||
// Now actually consume the characters
|
||||
word.clear();
|
||||
while let Some(ch) = self.peek() {
|
||||
if ch.is_alphabetic() {
|
||||
word.push(ch);
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
|
||||
Ok(Value::Null)
|
||||
} else {
|
||||
Err(ToolParserError::ParsingFailed("Invalid null".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to check if a string contains complete JSON
|
||||
pub fn is_complete_json(input: &str) -> bool {
|
||||
serde_json::from_str::<Value>(input).is_ok()
|
||||
}
|
||||
|
||||
/// Utility function to find common prefix between two strings
|
||||
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
|
||||
s1.chars()
|
||||
.zip(s2.chars())
|
||||
.take_while(|(a, b)| a == b)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Utility function to compute diff between old and new strings
|
||||
pub fn compute_diff(old: &str, new: &str) -> String {
|
||||
let common_len = find_common_prefix(old, new);
|
||||
// Convert character count to byte offset
|
||||
new.chars().skip(common_len).collect()
|
||||
}
|
||||
119
sgl-router/src/tool_parser/registry.rs
Normal file
119
sgl-router/src/tool_parser/registry.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Registry for tool parsers and model mappings
|
||||
pub struct ParserRegistry {
|
||||
/// Map of parser name to parser instance
|
||||
parsers: HashMap<String, Arc<dyn ToolParser>>,
|
||||
/// Map of model name/pattern to parser name
|
||||
model_mapping: HashMap<String, String>,
|
||||
/// Default parser to use when no match found
|
||||
default_parser: String,
|
||||
}
|
||||
|
||||
impl ParserRegistry {
|
||||
/// Create a new parser registry with default mappings
|
||||
pub fn new() -> Self {
|
||||
let mut registry = Self {
|
||||
parsers: HashMap::new(),
|
||||
model_mapping: HashMap::new(),
|
||||
default_parser: "json".to_string(),
|
||||
};
|
||||
|
||||
// Register default model mappings
|
||||
registry.register_default_mappings();
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
/// Register a parser
|
||||
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
|
||||
self.parsers.insert(name.into(), parser);
|
||||
}
|
||||
|
||||
/// Map a model name/pattern to a parser
|
||||
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
|
||||
self.model_mapping.insert(model.into(), parser.into());
|
||||
}
|
||||
|
||||
/// Get parser for a specific model
|
||||
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
|
||||
// Try exact match first
|
||||
if let Some(parser_name) = self.model_mapping.get(model) {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching (e.g., "gpt-4" matches "gpt-*")
|
||||
for (pattern, parser_name) in &self.model_mapping {
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
if model.starts_with(prefix) {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default parser if it exists
|
||||
self.parsers.get(&self.default_parser).cloned()
|
||||
}
|
||||
|
||||
/// List all registered parsers
|
||||
pub fn list_parsers(&self) -> Vec<&str> {
|
||||
self.parsers.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// List all model mappings
|
||||
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
|
||||
self.model_mapping
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Register default model mappings
|
||||
fn register_default_mappings(&mut self) {
|
||||
// OpenAI models
|
||||
self.map_model("gpt-4*", "json");
|
||||
self.map_model("gpt-3.5*", "json");
|
||||
self.map_model("gpt-4o*", "json");
|
||||
|
||||
// Anthropic models
|
||||
self.map_model("claude-*", "json");
|
||||
|
||||
// Mistral models
|
||||
self.map_model("mistral-*", "mistral");
|
||||
self.map_model("mixtral-*", "mistral");
|
||||
|
||||
// Qwen models
|
||||
self.map_model("qwen*", "qwen");
|
||||
|
||||
// Llama models
|
||||
self.map_model("llama-*", "llama");
|
||||
self.map_model("meta-llama-*", "llama");
|
||||
|
||||
// Other models default to JSON
|
||||
self.map_model("gemini-*", "json");
|
||||
self.map_model("palm-*", "json");
|
||||
}
|
||||
|
||||
/// Set the default parser
|
||||
pub fn set_default_parser(&mut self, name: impl Into<String>) {
|
||||
self.default_parser = name.into();
|
||||
}
|
||||
|
||||
/// Check if a parser is registered
|
||||
pub fn has_parser(&self, name: &str) -> bool {
|
||||
self.parsers.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParserRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
181
sgl-router/src/tool_parser/state.rs
Normal file
181
sgl-router/src/tool_parser/state.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use crate::tool_parser::types::{PartialToolCall, ToolCall};
|
||||
|
||||
/// Current phase of parsing
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ParsePhase {
|
||||
/// Looking for start of tool call
|
||||
Searching,
|
||||
/// Parsing function name
|
||||
InName,
|
||||
/// Parsing function arguments
|
||||
InArguments,
|
||||
/// Tool call complete
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// State for streaming parser
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParseState {
|
||||
/// Buffer for accumulating input
|
||||
pub buffer: String,
|
||||
/// Position of last consumed character
|
||||
pub consumed: usize,
|
||||
/// Current partial tool being parsed
|
||||
pub partial_tool: Option<PartialToolCall>,
|
||||
/// Completed tool calls
|
||||
pub completed_tools: Vec<ToolCall>,
|
||||
/// Current parsing phase
|
||||
pub phase: ParsePhase,
|
||||
/// Bracket/brace depth for JSON parsing
|
||||
pub bracket_depth: i32,
|
||||
/// Whether currently inside a string literal
|
||||
pub in_string: bool,
|
||||
/// Whether next character should be escaped
|
||||
pub escape_next: bool,
|
||||
/// Current tool index (for streaming)
|
||||
pub tool_index: usize,
|
||||
}
|
||||
|
||||
impl ParseState {
|
||||
/// Create a new parse state
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
consumed: 0,
|
||||
partial_tool: None,
|
||||
completed_tools: Vec::new(),
|
||||
phase: ParsePhase::Searching,
|
||||
bracket_depth: 0,
|
||||
in_string: false,
|
||||
escape_next: false,
|
||||
tool_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset state for parsing next tool
|
||||
pub fn reset(&mut self) {
|
||||
self.partial_tool = None;
|
||||
self.phase = ParsePhase::Searching;
|
||||
self.bracket_depth = 0;
|
||||
self.in_string = false;
|
||||
self.escape_next = false;
|
||||
}
|
||||
|
||||
/// Process a single character for JSON parsing
|
||||
pub fn process_char(&mut self, ch: char) {
|
||||
// Handle escape sequences
|
||||
if self.escape_next {
|
||||
self.escape_next = false;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
if ch == '\\' && self.in_string {
|
||||
self.escape_next = true;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
// Track string boundaries
|
||||
if ch == '"' && !self.escape_next {
|
||||
self.in_string = !self.in_string;
|
||||
}
|
||||
|
||||
// Track bracket depth for JSON
|
||||
if !self.in_string {
|
||||
match ch {
|
||||
'{' | '[' => {
|
||||
self.bracket_depth += 1;
|
||||
}
|
||||
'}' | ']' => {
|
||||
self.bracket_depth -= 1;
|
||||
if self.bracket_depth == 0 && self.partial_tool.is_some() {
|
||||
// Complete tool call found
|
||||
self.phase = ParsePhase::Complete;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.push(ch);
|
||||
}
|
||||
|
||||
/// Check if we have a complete JSON object/array
|
||||
pub fn has_complete_json(&self) -> bool {
|
||||
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Extract content from buffer starting at position
|
||||
pub fn extract_from(&self, start: usize) -> &str {
|
||||
if start >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after start
|
||||
let mut safe_start = start;
|
||||
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
|
||||
safe_start += 1;
|
||||
}
|
||||
|
||||
if safe_start < self.buffer.len() {
|
||||
&self.buffer[safe_start..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark content as consumed up to position
|
||||
pub fn consume_to(&mut self, position: usize) {
|
||||
if position > self.consumed {
|
||||
self.consumed = position;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get unconsumed content
|
||||
pub fn unconsumed(&self) -> &str {
|
||||
if self.consumed >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed += 1;
|
||||
}
|
||||
|
||||
if safe_consumed < self.buffer.len() {
|
||||
&self.buffer[safe_consumed..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear consumed content from buffer
|
||||
pub fn clear_consumed(&mut self) {
|
||||
if self.consumed > 0 {
|
||||
// Find the nearest character boundary at or before consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed -= 1;
|
||||
}
|
||||
|
||||
if safe_consumed > 0 {
|
||||
self.buffer.drain(..safe_consumed);
|
||||
self.consumed = self.consumed.saturating_sub(safe_consumed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add completed tool
|
||||
pub fn add_completed_tool(&mut self, tool: ToolCall) {
|
||||
self.completed_tools.push(tool);
|
||||
self.tool_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParseState {
    /// Equivalent to [`ParseState::new`]: empty buffer, `Searching` phase.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
249
sgl-router/src/tool_parser/tests.rs
Normal file
249
sgl-router/src/tool_parser/tests.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use super::*;
|
||||
use crate::tool_parser::partial_json::{
|
||||
compute_diff, find_common_prefix, is_complete_json, PartialJson,
|
||||
};
|
||||
|
||||
#[test]
fn test_parse_state_new() {
    // A freshly created state starts in `Searching` with every buffer,
    // counter, and flag at its zero value.
    let st = ParseState::new();
    assert_eq!(st.phase, ParsePhase::Searching);
    assert_eq!(st.buffer, "");
    assert_eq!(st.consumed, 0);
    assert_eq!(st.bracket_depth, 0);
    assert!(!st.in_string);
    assert!(!st.escape_next);
}
|
||||
|
||||
#[test]
fn test_parse_state_process_char() {
    let mut state = ParseState::new();

    // Brackets outside strings adjust depth: '{' opens, '}' closes.
    state.process_char('{');
    assert_eq!(state.bracket_depth, 1);

    state.process_char('}');
    assert_eq!(state.bracket_depth, 0);

    // Unescaped quotes toggle string mode on and off.
    state.process_char('"');
    assert!(state.in_string);

    state.process_char('"');
    assert!(!state.in_string);

    // A backslash inside a string arms escape mode for the next char.
    state.process_char('"');
    state.process_char('\\');
    assert!(state.escape_next);

    state.process_char('"');
    assert!(!state.escape_next);
    assert!(state.in_string); // Still in string because quote was escaped
}
|
||||
|
||||
#[test]
fn test_token_config() {
    // Two start/end markers should pair up positionally via iter_pairs.
    let cfg = TokenConfig {
        start_tokens: vec!["<start>".to_string(), "[".to_string()],
        end_tokens: vec!["</end>".to_string(), "]".to_string()],
        separator: ", ".to_string(),
    };

    let collected: Vec<_> = cfg.iter_pairs().collect();
    assert_eq!(collected.len(), 2);
    assert_eq!(collected[0], ("<start>", "</end>"));
    assert_eq!(collected[1], ("[", "]"));
}
|
||||
|
||||
#[test]
fn test_parser_registry() {
    let reg = ParserRegistry::new();

    // A fresh registry must ship with built-in default mappings.
    assert!(!reg.list_mappings().is_empty());

    // At least one of those defaults targets a gpt-prefixed model.
    let mappings = reg.list_mappings();
    let found_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
    assert!(found_gpt);
}
|
||||
|
||||
#[test]
fn test_parser_registry_pattern_matching() {
    let mut reg = ParserRegistry::new();

    // Register a custom model → parser mapping, then confirm it is
    // observable through the public listing.
    reg.map_model("test-model", "json");

    let mappings = reg.list_mappings();
    let found = mappings
        .iter()
        .any(|(m, p)| *m == "test-model" && *p == "json");
    assert!(found);
}
|
||||
|
||||
#[test]
fn test_tool_call_serialization() {
    // Round-trip a ToolCall through serde_json and check that the key
    // fields survive in both directions.
    let call = ToolCall {
        id: "call-123".to_string(),
        r#type: "function".to_string(),
        function: FunctionCall {
            name: "search".to_string(),
            arguments: r#"{"query": "rust programming"}"#.to_string(),
        },
    };

    let encoded = serde_json::to_string(&call).unwrap();
    for needle in ["call-123", "search", "rust programming"] {
        assert!(encoded.contains(needle));
    }

    let decoded: ToolCall = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.id, "call-123");
    assert_eq!(decoded.function.name, "search");
}
|
||||
|
||||
#[test]
fn test_partial_json_parser() {
    let pj = PartialJson::default();

    // Fully-formed JSON parses completely and consumes every byte.
    let complete = r#"{"name": "test", "value": 42}"#;
    let (parsed, used) = pj.parse_value(complete).unwrap();
    assert_eq!(parsed["name"], "test");
    assert_eq!(parsed["value"], 42);
    assert_eq!(used, complete.len());

    // An object cut off before a value: the missing value reads as null.
    let (parsed, _) = pj.parse_value(r#"{"name": "test", "value": "#).unwrap();
    assert_eq!(parsed["name"], "test");
    assert!(parsed["value"].is_null());

    // A string cut off mid-way keeps the prefix seen so far.
    let (parsed, _) = pj.parse_value(r#"{"name": "tes"#).unwrap();
    assert_eq!(parsed["name"], "tes");

    // An array cut off after a comma keeps the completed elements.
    let (parsed, _) = pj.parse_value(r#"[1, 2, "#).unwrap();
    assert!(parsed.is_array());
    assert_eq!(parsed[0], 1);
    assert_eq!(parsed[1], 2);
}
|
||||
|
||||
#[test]
fn test_partial_json_depth_limit() {
    // max_depth = 3 permits up to three nesting levels; with
    // allow_incomplete = false, exceeding the limit is a hard error
    // rather than a partial result.
    let pj = PartialJson::new(3, false);

    // Flat object: depth 1, well within the limit.
    assert!(pj.parse_value(r#"{"a": 1}"#).is_ok());

    // Exactly at the limit (three nested objects).
    assert!(pj.parse_value(r#"{"a": {"b": {"c": 1}}}"#).is_ok());

    // One level past the limit must be rejected.
    assert!(pj.parse_value(r#"{"a": {"b": {"c": {"d": 1}}}}"#).is_err());
}
|
||||
|
||||
#[test]
fn test_is_complete_json() {
    // Every JSON value kind, fully closed, counts as complete.
    let complete = [
        r#"{"name": "test"}"#,
        r#"[1, 2, 3]"#,
        r#""string""#,
        "42",
        "true",
        "null",
    ];
    for s in complete {
        assert!(is_complete_json(s));
    }

    // Truncated objects, arrays, and strings are incomplete.
    let incomplete = [r#"{"name": "#, r#"[1, 2, "#, r#""unclosed"#];
    for s in incomplete {
        assert!(!is_complete_json(s));
    }
}
|
||||
|
||||
#[test]
fn test_find_common_prefix() {
    // (a, b, expected shared-prefix length)
    let cases = [
        ("hello", "hello", 5),
        ("hello", "help", 3),
        ("hello", "world", 0),
        ("", "hello", 0),
        ("hello", "", 0),
    ];
    for (a, b, want) in cases {
        assert_eq!(find_common_prefix(a, b), want);
    }
}
|
||||
|
||||
#[test]
fn test_compute_diff() {
    // (old, new, expected diff text)
    let cases = [
        ("hello", "hello world", " world"),
        ("", "hello", "hello"),
        ("hello", "hello", ""),
        ("test", "hello", "hello"),
    ];
    for (old, new, want) in cases {
        assert_eq!(compute_diff(old, new), want);
    }
}
|
||||
|
||||
#[test]
fn test_stream_result_variants() {
    // Test Incomplete.
    // Fix: the original wrote `matches!(result, …);` and discarded the
    // boolean, so the check could never fail. Wrap it in `assert!`.
    let result = StreamResult::Incomplete;
    assert!(matches!(result, StreamResult::Incomplete));

    // Test ToolName: carries the tool's stream index and its name.
    let result = StreamResult::ToolName {
        index: 0,
        name: "test".to_string(),
    };
    if let StreamResult::ToolName { index, name } = result {
        assert_eq!(index, 0);
        assert_eq!(name, "test");
    } else {
        panic!("Expected ToolName variant");
    }

    // Test ToolComplete: wraps the fully parsed ToolCall.
    let tool = ToolCall {
        id: "123".to_string(),
        r#type: "function".to_string(),
        function: FunctionCall {
            name: "test".to_string(),
            arguments: "{}".to_string(),
        },
    };
    let result = StreamResult::ToolComplete(tool.clone());
    if let StreamResult::ToolComplete(t) = result {
        assert_eq!(t.id, "123");
    } else {
        panic!("Expected ToolComplete variant");
    }
}
|
||||
|
||||
#[test]
fn test_partial_tool_call() {
    // Start from an empty partial call and drive it through the same
    // mutations a streaming parser would perform.
    let mut pending = PartialToolCall {
        name: None,
        arguments_buffer: String::new(),
        start_position: 0,
        name_sent: false,
        streamed_args: String::new(),
    };

    // The tool name arrives first.
    pending.name = Some("test_function".to_string());
    assert_eq!(pending.name.as_deref(), Some("test_function"));

    // Argument bytes accumulate in the buffer.
    pending.arguments_buffer.push_str(r#"{"key": "value"}"#);
    assert_eq!(pending.arguments_buffer, r#"{"key": "value"}"#);

    // Streaming bookkeeping: name emitted, args partially streamed.
    pending.name_sent = true;
    pending.streamed_args = r#"{"key": "#.to_string();
    assert!(pending.name_sent);
    assert_eq!(pending.streamed_args, r#"{"key": "#);
}
|
||||
35
sgl-router/src/tool_parser/traits.rs
Normal file
35
sgl-router/src/tool_parser/traits.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
types::{StreamResult, ToolCall},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Core trait for all tool parsers
///
/// Each implementation handles one model-specific tool-call format and
/// must be shareable across async tasks (`Send + Sync`).
#[async_trait]
pub trait ToolParser: Send + Sync {
    /// Parse complete tool calls from final output
    ///
    /// Takes the full (non-streaming) model output and returns every
    /// tool call found, or a parser error.
    async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;

    /// Parse tool calls from model output (streaming)
    ///
    /// `chunk` is the next piece of streamed output; accumulated
    /// progress lives in `state`. Returns a [`StreamResult`] describing
    /// what, if anything, became available with this chunk.
    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult>;

    /// Check if text contains tool calls in this parser's format
    fn detect_format(&self, text: &str) -> bool;
}
|
||||
|
||||
/// Trait for partial JSON parsing
///
/// Abstraction over parsers that can produce a best-effort value from
/// possibly-truncated JSON input.
pub trait PartialJsonParser: Send + Sync {
    /// Parse potentially incomplete JSON
    ///
    /// Returns the (possibly partial) value together with the number of
    /// input bytes consumed.
    fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;

    /// Check if JSON is complete
    fn is_complete(&self, input: &str) -> bool;

    /// Get the maximum parsing depth
    fn max_depth(&self) -> usize;
}
|
||||
73
sgl-router/src/tool_parser/types.rs
Normal file
73
sgl-router/src/tool_parser/types.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Parsed tool call from model output (OpenAI format)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for the tool call
    pub id: String,
    /// Type of tool call (currently always "function")
    /// Serialized as `"type"`; the raw identifier is needed because
    /// `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Function call details
    pub function: FunctionCall,
}
|
||||
|
||||
/// Function call within a tool call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call
    pub name: String,
    /// Arguments as JSON string
    /// (kept as raw text rather than a parsed value, matching the
    /// OpenAI wire format)
    pub arguments: String,
}
|
||||
|
||||
/// Streaming parse result
///
/// One increment of progress returned by `parse_incremental`.
#[derive(Debug, Clone)]
pub enum StreamResult {
    /// Need more data to continue parsing
    Incomplete,
    /// Found a tool name (for streaming)
    ToolName { index: usize, name: String },
    /// Found incremental arguments (for streaming)
    /// NOTE(review): presumably `arguments` holds only the newly seen
    /// fragment, not the full argument string so far — confirm against
    /// parser implementations.
    ToolArguments { index: usize, arguments: String },
    /// Completed parsing a tool
    ToolComplete(ToolCall),
    /// Normal text (not part of tool call)
    NormalText(String),
}
|
||||
|
||||
/// Token configuration for parsing
///
/// Start/end markers that delimit tool calls in raw model output.
/// `start_tokens[i]` pairs positionally with `end_tokens[i]`
/// (see `iter_pairs`).
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Start tokens for tool calls
    pub start_tokens: Vec<String>,
    /// End tokens for tool calls
    pub end_tokens: Vec<String>,
    /// Separator between multiple tool calls
    pub separator: String,
}
|
||||
|
||||
impl TokenConfig {
|
||||
/// Iterate over start/end token pairs
|
||||
pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
|
||||
self.start_tokens
|
||||
.iter()
|
||||
.zip(self.end_tokens.iter())
|
||||
.map(|(s, e)| (s.as_str(), e.as_str()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple partial tool call for streaming
///
/// Mutable scratch record for the tool call currently being assembled
/// from streamed chunks (held in `ParseState::partial_tool`).
#[derive(Debug, Clone)]
pub struct PartialToolCall {
    /// Tool name (if parsed)
    pub name: Option<String>,
    /// Buffer for accumulating arguments
    pub arguments_buffer: String,
    /// Start position in the input buffer
    pub start_position: usize,
    /// Whether the name has been sent (for streaming)
    pub name_sent: bool,
    /// Arguments already streamed
    /// (the portion of the accumulated arguments already emitted
    /// downstream)
    pub streamed_args: String,
}
|
||||
Reference in New Issue
Block a user