[router] add tool parser base structure and partial json parser (#9482)
This commit is contained in:
@@ -48,6 +48,7 @@ metrics = "0.24.2"
|
|||||||
metrics-exporter-prometheus = "0.17.0"
|
metrics-exporter-prometheus = "0.17.0"
|
||||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
|
regex = "1.10"
|
||||||
url = "2.5.4"
|
url = "2.5.4"
|
||||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
|
|||||||
@@ -100,7 +100,8 @@ fn bench_encode_throughput(c: &mut Criterion) {
|
|||||||
let tokenizer_clone = tokenizer.clone();
|
let tokenizer_clone = tokenizer.clone();
|
||||||
|
|
||||||
// Get token count once
|
// Get token count once
|
||||||
let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
|
let encoding = tokenizer.encode(prompt).unwrap();
|
||||||
|
let token_count = encoding.token_ids().len();
|
||||||
|
|
||||||
// Track if metrics have been printed for this test case
|
// Track if metrics have been printed for this test case
|
||||||
let printed = Arc::new(AtomicBool::new(false));
|
let printed = Arc::new(AtomicBool::new(false));
|
||||||
@@ -157,7 +158,8 @@ fn bench_batch_encode(c: &mut Criterion) {
|
|||||||
let batch_sizes = vec![1, 8, 16, 32, 64, 128];
|
let batch_sizes = vec![1, 8, 16, 32, 64, 128];
|
||||||
let prompt = MEDIUM_PROMPT;
|
let prompt = MEDIUM_PROMPT;
|
||||||
let prompt_len = prompt.len();
|
let prompt_len = prompt.len();
|
||||||
let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
|
let encoding = tokenizer.encode(prompt).unwrap();
|
||||||
|
let token_count = encoding.token_ids().len();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("batch_encode");
|
let mut group = c.benchmark_group("batch_encode");
|
||||||
|
|
||||||
@@ -303,7 +305,8 @@ fn bench_decode_performance(c: &mut Criterion) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
|
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
|
||||||
let tokens = tokenizer.encode(&test_text).unwrap().token_ids();
|
let encoding = tokenizer.encode(&test_text).unwrap();
|
||||||
|
let tokens = encoding.token_ids();
|
||||||
let num_tokens = tokens.len();
|
let num_tokens = tokens.len();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("decode_performance");
|
let mut group = c.benchmark_group("decode_performance");
|
||||||
@@ -313,12 +316,11 @@ fn bench_decode_performance(c: &mut Criterion) {
|
|||||||
group.bench_function("direct_decode", |b| {
|
group.bench_function("direct_decode", |b| {
|
||||||
let printed = printed_direct.clone();
|
let printed = printed_direct.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _ in 0..iters {
|
for _ in 0..iters {
|
||||||
black_box(tokenizer.decode(&tokens, false).unwrap());
|
black_box(tokenizer.decode(tokens, false).unwrap());
|
||||||
}
|
}
|
||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
|
|
||||||
@@ -344,14 +346,13 @@ fn bench_decode_performance(c: &mut Criterion) {
|
|||||||
group.bench_function("decode_stream", |b| {
|
group.bench_function("decode_stream", |b| {
|
||||||
let printed = printed_stream.clone();
|
let printed = printed_stream.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _ in 0..iters {
|
for _ in 0..iters {
|
||||||
let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
|
let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
for token in &tokens {
|
for token in tokens {
|
||||||
if let Some(text) = decoder.step(*token).unwrap() {
|
if let Some(text) = decoder.step(*token).unwrap() {
|
||||||
output.push_str(&text);
|
output.push_str(&text);
|
||||||
}
|
}
|
||||||
@@ -382,14 +383,13 @@ fn bench_decode_performance(c: &mut Criterion) {
|
|||||||
group.bench_function("sequence_decode", |b| {
|
group.bench_function("sequence_decode", |b| {
|
||||||
let printed = printed_seq.clone();
|
let printed = printed_seq.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _ in 0..iters {
|
for _ in 0..iters {
|
||||||
let mut sequence = Sequence::new(tokenizer.clone());
|
let mut sequence = Sequence::new(tokenizer.clone());
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
for token in &tokens {
|
for token in tokens {
|
||||||
let text = sequence.append_token(*token).unwrap();
|
let text = sequence.append_token(*token).unwrap();
|
||||||
output.push_str(&text);
|
output.push_str(&text);
|
||||||
}
|
}
|
||||||
@@ -424,7 +424,8 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(1000);
|
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(1000);
|
||||||
let all_tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
|
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||||
|
let all_tokens = encoding.token_ids();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("streaming_100k");
|
let mut group = c.benchmark_group("streaming_100k");
|
||||||
group.measurement_time(Duration::from_secs(1));
|
group.measurement_time(Duration::from_secs(1));
|
||||||
@@ -434,7 +435,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
|||||||
group.bench_function("decode_stream_100k", |b| {
|
group.bench_function("decode_stream_100k", |b| {
|
||||||
let printed = printed_stream.clone();
|
let printed = printed_stream.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = all_tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|_iters| {
|
b.iter_custom(|_iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@@ -442,7 +442,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
|||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
let mut tokens_processed = 0u64;
|
let mut tokens_processed = 0u64;
|
||||||
|
|
||||||
for token in tokens.iter().cycle() {
|
for token in all_tokens.iter().cycle() {
|
||||||
if start.elapsed() >= Duration::from_millis(500) {
|
if start.elapsed() >= Duration::from_millis(500) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -486,7 +486,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
|||||||
group.bench_function("sequence_100k", |b| {
|
group.bench_function("sequence_100k", |b| {
|
||||||
let printed = printed_seq.clone();
|
let printed = printed_seq.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = all_tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|_iters| {
|
b.iter_custom(|_iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@@ -494,7 +493,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
|
|||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
let mut tokens_processed = 0u64;
|
let mut tokens_processed = 0u64;
|
||||||
|
|
||||||
for token in tokens.iter().cycle() {
|
for token in all_tokens.iter().cycle() {
|
||||||
if start.elapsed() >= Duration::from_millis(500) {
|
if start.elapsed() >= Duration::from_millis(500) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -693,7 +692,8 @@ fn bench_concurrent_streaming(c: &mut Criterion) {
|
|||||||
let tokens_per_sequence = 10_000;
|
let tokens_per_sequence = 10_000;
|
||||||
|
|
||||||
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
||||||
let token_batch = tokenizer.encode(&sample_text).unwrap().token_ids();
|
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||||
|
let token_batch: Vec<u32> = encoding.token_ids().to_vec();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("concurrent_streaming");
|
let mut group = c.benchmark_group("concurrent_streaming");
|
||||||
group.measurement_time(Duration::from_secs(2));
|
group.measurement_time(Duration::from_secs(2));
|
||||||
@@ -775,7 +775,8 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
|||||||
.with_stop_token(2);
|
.with_stop_token(2);
|
||||||
|
|
||||||
let sample_text = "Hello world! This is a test. ### Stop here. Continue after.".repeat(100);
|
let sample_text = "Hello world! This is a test. ### Stop here. Continue after.".repeat(100);
|
||||||
let tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
|
let encoding = tokenizer.encode(&sample_text).unwrap();
|
||||||
|
let tokens = encoding.token_ids();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("stop_sequences");
|
let mut group = c.benchmark_group("stop_sequences");
|
||||||
|
|
||||||
@@ -784,7 +785,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
|||||||
group.bench_function("no_stops", |b| {
|
group.bench_function("no_stops", |b| {
|
||||||
let printed_clone = printed_no_stop.clone();
|
let printed_clone = printed_no_stop.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = tokens.clone();
|
|
||||||
|
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@@ -796,7 +796,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
|||||||
StopSequenceConfig::default(),
|
StopSequenceConfig::default(),
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
for token in &tokens {
|
for token in tokens {
|
||||||
let _ = decoder.process_token(*token).unwrap();
|
let _ = decoder.process_token(*token).unwrap();
|
||||||
total_tokens += 1;
|
total_tokens += 1;
|
||||||
}
|
}
|
||||||
@@ -826,7 +826,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
|||||||
group.bench_function("with_stops", |b| {
|
group.bench_function("with_stops", |b| {
|
||||||
let printed_clone = printed_with_stops.clone();
|
let printed_clone = printed_with_stops.clone();
|
||||||
let tokenizer = tokenizer.clone();
|
let tokenizer = tokenizer.clone();
|
||||||
let tokens = tokens.clone();
|
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
|
|
||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
@@ -839,7 +838,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
|
|||||||
StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false);
|
StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false);
|
||||||
let mut sequence_tokens = 0u64;
|
let mut sequence_tokens = 0u64;
|
||||||
|
|
||||||
for token in &tokens {
|
for token in tokens {
|
||||||
let result = decoder.process_token(*token).unwrap();
|
let result = decoder.process_token(*token).unwrap();
|
||||||
sequence_tokens += 1;
|
sequence_tokens += 1;
|
||||||
|
|
||||||
@@ -986,7 +985,8 @@ fn bench_multithreaded_decode(c: &mut Criterion) {
|
|||||||
|
|
||||||
// Generate tokens for decoding
|
// Generate tokens for decoding
|
||||||
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
|
||||||
let test_tokens = tokenizer.encode(&test_text).unwrap().token_ids();
|
let encoding = tokenizer.encode(&test_text).unwrap();
|
||||||
|
let test_tokens: Vec<u32> = encoding.token_ids().to_vec();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("multithreaded_decode");
|
let mut group = c.benchmark_group("multithreaded_decode");
|
||||||
group.measurement_time(Duration::from_secs(2));
|
group.measurement_time(Duration::from_secs(2));
|
||||||
@@ -1130,7 +1130,7 @@ fn bench_memory_efficiency(c: &mut Criterion) {
|
|||||||
b.iter_custom(|iters| {
|
b.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _ in 0..iters {
|
for _ in 0..iters {
|
||||||
let _ = black_box(encoding.token_ids_ref());
|
let _ = black_box(encoding.token_ids());
|
||||||
}
|
}
|
||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ pub mod routers;
|
|||||||
pub mod server;
|
pub mod server;
|
||||||
pub mod service_discovery;
|
pub mod service_discovery;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
|
pub mod tool_parser;
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
use crate::metrics::PrometheusConfig;
|
use crate::metrics::PrometheusConfig;
|
||||||
|
|
||||||
|
|||||||
32
sgl-router/src/tool_parser/errors.rs
Normal file
32
sgl-router/src/tool_parser/errors.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// Result type for tool parser operations
|
||||||
|
pub type ToolParserResult<T> = Result<T, ToolParserError>;
|
||||||
|
|
||||||
|
/// Errors that can occur during tool parsing
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ToolParserError {
|
||||||
|
#[error("Parsing failed: {0}")]
|
||||||
|
ParsingFailed(String),
|
||||||
|
|
||||||
|
#[error("Model not supported: {0}")]
|
||||||
|
ModelNotSupported(String),
|
||||||
|
|
||||||
|
#[error("Parse depth exceeded: max {0}")]
|
||||||
|
DepthExceeded(usize),
|
||||||
|
|
||||||
|
#[error("Invalid JSON: {0}")]
|
||||||
|
JsonError(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("Regex error: {0}")]
|
||||||
|
RegexError(#[from] regex::Error),
|
||||||
|
|
||||||
|
#[error("Incomplete tool call")]
|
||||||
|
Incomplete,
|
||||||
|
|
||||||
|
#[error("Invalid tool name: {0}")]
|
||||||
|
InvalidToolName(String),
|
||||||
|
|
||||||
|
#[error("Token not found: {0}")]
|
||||||
|
TokenNotFound(String),
|
||||||
|
}
|
||||||
20
sgl-router/src/tool_parser/mod.rs
Normal file
20
sgl-router/src/tool_parser/mod.rs
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
/// Tool parser module for handling function/tool calls in model outputs
|
||||||
|
///
|
||||||
|
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||||
|
/// Phase 1 focuses on core infrastructure: types, traits, registry, and partial JSON parsing.
|
||||||
|
pub mod errors;
|
||||||
|
pub mod partial_json;
|
||||||
|
pub mod registry;
|
||||||
|
pub mod state;
|
||||||
|
pub mod traits;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
|
// Re-export commonly used types
|
||||||
|
pub use errors::{ToolParserError, ToolParserResult};
|
||||||
|
pub use registry::ParserRegistry;
|
||||||
|
pub use state::{ParsePhase, ParseState};
|
||||||
|
pub use traits::{PartialJsonParser, ToolParser};
|
||||||
|
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
|
||||||
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
527
sgl-router/src/tool_parser/partial_json.rs
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
use crate::tool_parser::{
|
||||||
|
errors::{ToolParserError, ToolParserResult},
|
||||||
|
traits::PartialJsonParser,
|
||||||
|
};
|
||||||
|
use serde_json::{Map, Value};
|
||||||
|
|
||||||
|
/// Parser for incomplete JSON
|
||||||
|
pub struct PartialJson {
|
||||||
|
/// Maximum depth for nested structures
|
||||||
|
max_depth: usize,
|
||||||
|
/// Whether to allow incomplete values
|
||||||
|
allow_incomplete: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialJson {
|
||||||
|
/// Create a new partial JSON parser
|
||||||
|
pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
max_depth,
|
||||||
|
allow_incomplete,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes
|
||||||
|
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||||
|
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
|
||||||
|
let value = parser.parse_value(0)?;
|
||||||
|
Ok((value, parser.position))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PartialJson {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(32, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialJsonParser for PartialJson {
|
||||||
|
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
|
||||||
|
self.parse_value(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_complete(&self, input: &str) -> bool {
|
||||||
|
// Try to parse as complete JSON
|
||||||
|
serde_json::from_str::<Value>(input).is_ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_depth(&self) -> usize {
|
||||||
|
self.max_depth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal parser state
|
||||||
|
struct Parser<'a> {
|
||||||
|
chars: std::iter::Peekable<std::str::Chars<'a>>,
|
||||||
|
position: usize,
|
||||||
|
max_depth: usize,
|
||||||
|
allow_incomplete: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Parser<'a> {
|
||||||
|
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
chars: input.chars().peekable(),
|
||||||
|
position: 0,
|
||||||
|
max_depth,
|
||||||
|
allow_incomplete,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peek(&mut self) -> Option<char> {
|
||||||
|
self.chars.peek().copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn advance(&mut self) {
|
||||||
|
if self.chars.next().is_some() {
|
||||||
|
self.position += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn skip_whitespace(&mut self) {
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_whitespace() {
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||||
|
if depth > self.max_depth {
|
||||||
|
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
match self.peek() {
|
||||||
|
Some('{') => self.parse_object(depth + 1),
|
||||||
|
Some('[') => self.parse_array(depth + 1),
|
||||||
|
Some('"') => self.parse_string(),
|
||||||
|
Some('t') | Some('f') => self.parse_bool(),
|
||||||
|
Some('n') => self.parse_null(),
|
||||||
|
Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
|
||||||
|
_ => {
|
||||||
|
if self.allow_incomplete {
|
||||||
|
Ok(Value::Null)
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed(
|
||||||
|
"Unexpected character".into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||||
|
if depth > self.max_depth {
|
||||||
|
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut object = Map::new();
|
||||||
|
|
||||||
|
// Consume '{'
|
||||||
|
self.advance();
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
// Check for empty object
|
||||||
|
if self.peek() == Some('}') {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Parse key
|
||||||
|
let key = match self.parse_string() {
|
||||||
|
Ok(Value::String(s)) => s,
|
||||||
|
Err(_) if self.allow_incomplete => {
|
||||||
|
// Incomplete object
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
// Expect ':'
|
||||||
|
if self.peek() != Some(':') {
|
||||||
|
if self.allow_incomplete {
|
||||||
|
// Add null value for incomplete pair
|
||||||
|
object.insert(key, Value::Null);
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
|
||||||
|
}
|
||||||
|
self.advance();
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
// Parse value (keep same depth - we already incremented in parse_object)
|
||||||
|
let value = match self.parse_value(depth) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) if self.allow_incomplete => {
|
||||||
|
// Add null for incomplete value
|
||||||
|
object.insert(key, Value::Null);
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
object.insert(key, value);
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
match self.peek() {
|
||||||
|
Some(',') => {
|
||||||
|
self.advance();
|
||||||
|
self.skip_whitespace();
|
||||||
|
// Check for trailing comma
|
||||||
|
if self.peek() == Some('}') {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some('}') => {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
None if self.allow_incomplete => {
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if self.allow_incomplete {
|
||||||
|
return Ok(Value::Object(object));
|
||||||
|
}
|
||||||
|
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
|
||||||
|
if depth > self.max_depth {
|
||||||
|
return Err(ToolParserError::DepthExceeded(self.max_depth));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut array = Vec::new();
|
||||||
|
|
||||||
|
// Consume '['
|
||||||
|
self.advance();
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
// Check for empty array
|
||||||
|
if self.peek() == Some(']') {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Parse value (keep same depth - we already incremented in parse_object)
|
||||||
|
let value = match self.parse_value(depth) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) if self.allow_incomplete => {
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
array.push(value);
|
||||||
|
self.skip_whitespace();
|
||||||
|
|
||||||
|
match self.peek() {
|
||||||
|
Some(',') => {
|
||||||
|
self.advance();
|
||||||
|
self.skip_whitespace();
|
||||||
|
// Check for trailing comma
|
||||||
|
if self.peek() == Some(']') {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(']') => {
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
None if self.allow_incomplete => {
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if self.allow_incomplete {
|
||||||
|
return Ok(Value::Array(array));
|
||||||
|
}
|
||||||
|
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_string(&mut self) -> ToolParserResult<Value> {
|
||||||
|
if self.peek() != Some('"') {
|
||||||
|
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume opening quote
|
||||||
|
self.advance();
|
||||||
|
|
||||||
|
let mut string = String::new();
|
||||||
|
let mut escaped = false;
|
||||||
|
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if escaped {
|
||||||
|
// Handle escape sequences
|
||||||
|
let escaped_char = match ch {
|
||||||
|
'"' | '\\' | '/' => ch,
|
||||||
|
'b' => '\u{0008}',
|
||||||
|
'f' => '\u{000C}',
|
||||||
|
'n' => '\n',
|
||||||
|
'r' => '\r',
|
||||||
|
't' => '\t',
|
||||||
|
'u' => {
|
||||||
|
// Unicode escape
|
||||||
|
self.advance();
|
||||||
|
let hex = self.parse_unicode_escape()?;
|
||||||
|
string.push(hex);
|
||||||
|
escaped = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => ch, // Invalid escape, but be lenient
|
||||||
|
};
|
||||||
|
string.push(escaped_char);
|
||||||
|
escaped = false;
|
||||||
|
} else if ch == '\\' {
|
||||||
|
escaped = true;
|
||||||
|
} else if ch == '"' {
|
||||||
|
// End of string
|
||||||
|
self.advance();
|
||||||
|
return Ok(Value::String(string));
|
||||||
|
} else {
|
||||||
|
string.push(ch);
|
||||||
|
}
|
||||||
|
self.advance();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incomplete string
|
||||||
|
if self.allow_incomplete {
|
||||||
|
Ok(Value::String(string))
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
|
||||||
|
let mut hex = String::new();
|
||||||
|
for _ in 0..4 {
|
||||||
|
if let Some(ch) = self.peek() {
|
||||||
|
if ch.is_ascii_hexdigit() {
|
||||||
|
hex.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hex.len() == 4 {
|
||||||
|
u32::from_str_radix(&hex, 16)
|
||||||
|
.ok()
|
||||||
|
.and_then(char::from_u32)
|
||||||
|
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
|
||||||
|
} else if self.allow_incomplete {
|
||||||
|
Ok('\u{FFFD}') // Replacement character
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed(
|
||||||
|
"Incomplete unicode escape".into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_number(&mut self) -> ToolParserResult<Value> {
|
||||||
|
let mut number = String::new();
|
||||||
|
|
||||||
|
// Handle negative sign
|
||||||
|
if self.peek() == Some('-') {
|
||||||
|
number.push('-');
|
||||||
|
self.advance();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse integer part
|
||||||
|
if self.peek() == Some('0') {
|
||||||
|
number.push('0');
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_ascii_digit() {
|
||||||
|
number.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse decimal part
|
||||||
|
if self.peek() == Some('.') {
|
||||||
|
number.push('.');
|
||||||
|
self.advance();
|
||||||
|
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_ascii_digit() {
|
||||||
|
number.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse exponent
|
||||||
|
if let Some(ch) = self.peek() {
|
||||||
|
if ch == 'e' || ch == 'E' {
|
||||||
|
number.push(ch);
|
||||||
|
self.advance();
|
||||||
|
|
||||||
|
if let Some(sign) = self.peek() {
|
||||||
|
if sign == '+' || sign == '-' {
|
||||||
|
number.push(sign);
|
||||||
|
self.advance();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_ascii_digit() {
|
||||||
|
number.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as integer first, then as float
|
||||||
|
if let Ok(n) = number.parse::<i64>() {
|
||||||
|
Ok(Value::Number(serde_json::Number::from(n)))
|
||||||
|
} else if let Ok(n) = number.parse::<f64>() {
|
||||||
|
Ok(Value::Number(
|
||||||
|
serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
|
||||||
|
))
|
||||||
|
} else if self.allow_incomplete {
|
||||||
|
Ok(Value::Number(serde_json::Number::from(0)))
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed("Invalid number".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_bool(&mut self) -> ToolParserResult<Value> {
|
||||||
|
let mut word = String::new();
|
||||||
|
|
||||||
|
// Peek at upcoming characters to validate it looks like a boolean
|
||||||
|
let mut temp_chars = self.chars.clone();
|
||||||
|
while let Some(&ch) = temp_chars.peek() {
|
||||||
|
if ch.is_alphabetic() && word.len() < 5 {
|
||||||
|
// "false" is 5 chars
|
||||||
|
word.push(ch);
|
||||||
|
temp_chars.next();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a valid boolean prefix
|
||||||
|
let is_valid = word == "true"
|
||||||
|
|| word == "false"
|
||||||
|
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
|
||||||
|
|
||||||
|
if !is_valid {
|
||||||
|
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now actually consume the characters
|
||||||
|
word.clear();
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_alphabetic() {
|
||||||
|
word.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match word.as_str() {
|
||||||
|
"true" => Ok(Value::Bool(true)),
|
||||||
|
"false" => Ok(Value::Bool(false)),
|
||||||
|
partial if self.allow_incomplete => {
|
||||||
|
if "true".starts_with(partial) {
|
||||||
|
Ok(Value::Bool(true))
|
||||||
|
} else if "false".starts_with(partial) {
|
||||||
|
Ok(Value::Bool(false))
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_null(&mut self) -> ToolParserResult<Value> {
|
||||||
|
let mut word = String::new();
|
||||||
|
|
||||||
|
// Peek at upcoming characters to validate it looks like "null"
|
||||||
|
let mut temp_chars = self.chars.clone();
|
||||||
|
while let Some(&ch) = temp_chars.peek() {
|
||||||
|
if ch.is_alphabetic() && word.len() < 4 {
|
||||||
|
// "null" is 4 chars
|
||||||
|
word.push(ch);
|
||||||
|
temp_chars.next();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a valid null prefix
|
||||||
|
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
|
||||||
|
|
||||||
|
if !is_valid {
|
||||||
|
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now actually consume the characters
|
||||||
|
word.clear();
|
||||||
|
while let Some(ch) = self.peek() {
|
||||||
|
if ch.is_alphabetic() {
|
||||||
|
word.push(ch);
|
||||||
|
self.advance();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
|
||||||
|
Ok(Value::Null)
|
||||||
|
} else {
|
||||||
|
Err(ToolParserError::ParsingFailed("Invalid null".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility function to check if a string contains complete JSON
|
||||||
|
pub fn is_complete_json(input: &str) -> bool {
|
||||||
|
serde_json::from_str::<Value>(input).is_ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility function to find common prefix between two strings
|
||||||
|
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
|
||||||
|
s1.chars()
|
||||||
|
.zip(s2.chars())
|
||||||
|
.take_while(|(a, b)| a == b)
|
||||||
|
.count()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility function to compute diff between old and new strings
|
||||||
|
pub fn compute_diff(old: &str, new: &str) -> String {
|
||||||
|
let common_len = find_common_prefix(old, new);
|
||||||
|
// Convert character count to byte offset
|
||||||
|
new.chars().skip(common_len).collect()
|
||||||
|
}
|
||||||
119
sgl-router/src/tool_parser/registry.rs
Normal file
119
sgl-router/src/tool_parser/registry.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
use crate::tool_parser::traits::ToolParser;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Registry for tool parsers and model mappings
|
||||||
|
pub struct ParserRegistry {
|
||||||
|
/// Map of parser name to parser instance
|
||||||
|
parsers: HashMap<String, Arc<dyn ToolParser>>,
|
||||||
|
/// Map of model name/pattern to parser name
|
||||||
|
model_mapping: HashMap<String, String>,
|
||||||
|
/// Default parser to use when no match found
|
||||||
|
default_parser: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParserRegistry {
|
||||||
|
/// Create a new parser registry with default mappings
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let mut registry = Self {
|
||||||
|
parsers: HashMap::new(),
|
||||||
|
model_mapping: HashMap::new(),
|
||||||
|
default_parser: "json".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register default model mappings
|
||||||
|
registry.register_default_mappings();
|
||||||
|
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a parser
|
||||||
|
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
|
||||||
|
self.parsers.insert(name.into(), parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map a model name/pattern to a parser
|
||||||
|
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
|
||||||
|
self.model_mapping.insert(model.into(), parser.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get parser for a specific model
|
||||||
|
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
|
||||||
|
// Try exact match first
|
||||||
|
if let Some(parser_name) = self.model_mapping.get(model) {
|
||||||
|
if let Some(parser) = self.parsers.get(parser_name) {
|
||||||
|
return Some(parser.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try prefix matching (e.g., "gpt-4" matches "gpt-*")
|
||||||
|
for (pattern, parser_name) in &self.model_mapping {
|
||||||
|
if pattern.ends_with('*') {
|
||||||
|
let prefix = &pattern[..pattern.len() - 1];
|
||||||
|
if model.starts_with(prefix) {
|
||||||
|
if let Some(parser) = self.parsers.get(parser_name) {
|
||||||
|
return Some(parser.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to default parser if it exists
|
||||||
|
self.parsers.get(&self.default_parser).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all registered parsers
|
||||||
|
pub fn list_parsers(&self) -> Vec<&str> {
|
||||||
|
self.parsers.keys().map(|s| s.as_str()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all model mappings
|
||||||
|
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
|
||||||
|
self.model_mapping
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register default model mappings
|
||||||
|
fn register_default_mappings(&mut self) {
|
||||||
|
// OpenAI models
|
||||||
|
self.map_model("gpt-4*", "json");
|
||||||
|
self.map_model("gpt-3.5*", "json");
|
||||||
|
self.map_model("gpt-4o*", "json");
|
||||||
|
|
||||||
|
// Anthropic models
|
||||||
|
self.map_model("claude-*", "json");
|
||||||
|
|
||||||
|
// Mistral models
|
||||||
|
self.map_model("mistral-*", "mistral");
|
||||||
|
self.map_model("mixtral-*", "mistral");
|
||||||
|
|
||||||
|
// Qwen models
|
||||||
|
self.map_model("qwen*", "qwen");
|
||||||
|
|
||||||
|
// Llama models
|
||||||
|
self.map_model("llama-*", "llama");
|
||||||
|
self.map_model("meta-llama-*", "llama");
|
||||||
|
|
||||||
|
// Other models default to JSON
|
||||||
|
self.map_model("gemini-*", "json");
|
||||||
|
self.map_model("palm-*", "json");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the default parser
|
||||||
|
pub fn set_default_parser(&mut self, name: impl Into<String>) {
|
||||||
|
self.default_parser = name.into();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a parser is registered
|
||||||
|
pub fn has_parser(&self, name: &str) -> bool {
|
||||||
|
self.parsers.contains_key(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ParserRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
181
sgl-router/src/tool_parser/state.rs
Normal file
181
sgl-router/src/tool_parser/state.rs
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
use crate::tool_parser::types::{PartialToolCall, ToolCall};
|
||||||
|
|
||||||
|
/// Current phase of parsing
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ParsePhase {
|
||||||
|
/// Looking for start of tool call
|
||||||
|
Searching,
|
||||||
|
/// Parsing function name
|
||||||
|
InName,
|
||||||
|
/// Parsing function arguments
|
||||||
|
InArguments,
|
||||||
|
/// Tool call complete
|
||||||
|
Complete,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State for streaming parser
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ParseState {
|
||||||
|
/// Buffer for accumulating input
|
||||||
|
pub buffer: String,
|
||||||
|
/// Position of last consumed character
|
||||||
|
pub consumed: usize,
|
||||||
|
/// Current partial tool being parsed
|
||||||
|
pub partial_tool: Option<PartialToolCall>,
|
||||||
|
/// Completed tool calls
|
||||||
|
pub completed_tools: Vec<ToolCall>,
|
||||||
|
/// Current parsing phase
|
||||||
|
pub phase: ParsePhase,
|
||||||
|
/// Bracket/brace depth for JSON parsing
|
||||||
|
pub bracket_depth: i32,
|
||||||
|
/// Whether currently inside a string literal
|
||||||
|
pub in_string: bool,
|
||||||
|
/// Whether next character should be escaped
|
||||||
|
pub escape_next: bool,
|
||||||
|
/// Current tool index (for streaming)
|
||||||
|
pub tool_index: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParseState {
|
||||||
|
/// Create a new parse state
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: String::new(),
|
||||||
|
consumed: 0,
|
||||||
|
partial_tool: None,
|
||||||
|
completed_tools: Vec::new(),
|
||||||
|
phase: ParsePhase::Searching,
|
||||||
|
bracket_depth: 0,
|
||||||
|
in_string: false,
|
||||||
|
escape_next: false,
|
||||||
|
tool_index: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset state for parsing next tool
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.partial_tool = None;
|
||||||
|
self.phase = ParsePhase::Searching;
|
||||||
|
self.bracket_depth = 0;
|
||||||
|
self.in_string = false;
|
||||||
|
self.escape_next = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a single character for JSON parsing
|
||||||
|
pub fn process_char(&mut self, ch: char) {
|
||||||
|
// Handle escape sequences
|
||||||
|
if self.escape_next {
|
||||||
|
self.escape_next = false;
|
||||||
|
self.buffer.push(ch);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch == '\\' && self.in_string {
|
||||||
|
self.escape_next = true;
|
||||||
|
self.buffer.push(ch);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track string boundaries
|
||||||
|
if ch == '"' && !self.escape_next {
|
||||||
|
self.in_string = !self.in_string;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track bracket depth for JSON
|
||||||
|
if !self.in_string {
|
||||||
|
match ch {
|
||||||
|
'{' | '[' => {
|
||||||
|
self.bracket_depth += 1;
|
||||||
|
}
|
||||||
|
'}' | ']' => {
|
||||||
|
self.bracket_depth -= 1;
|
||||||
|
if self.bracket_depth == 0 && self.partial_tool.is_some() {
|
||||||
|
// Complete tool call found
|
||||||
|
self.phase = ParsePhase::Complete;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.buffer.push(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if we have a complete JSON object/array
|
||||||
|
pub fn has_complete_json(&self) -> bool {
|
||||||
|
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract content from buffer starting at position
|
||||||
|
pub fn extract_from(&self, start: usize) -> &str {
|
||||||
|
if start >= self.buffer.len() {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the nearest character boundary at or after start
|
||||||
|
let mut safe_start = start;
|
||||||
|
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
|
||||||
|
safe_start += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if safe_start < self.buffer.len() {
|
||||||
|
&self.buffer[safe_start..]
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark content as consumed up to position
|
||||||
|
pub fn consume_to(&mut self, position: usize) {
|
||||||
|
if position > self.consumed {
|
||||||
|
self.consumed = position;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get unconsumed content
|
||||||
|
pub fn unconsumed(&self) -> &str {
|
||||||
|
if self.consumed >= self.buffer.len() {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the nearest character boundary at or after consumed
|
||||||
|
let mut safe_consumed = self.consumed;
|
||||||
|
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
|
||||||
|
safe_consumed += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if safe_consumed < self.buffer.len() {
|
||||||
|
&self.buffer[safe_consumed..]
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear consumed content from buffer
|
||||||
|
pub fn clear_consumed(&mut self) {
|
||||||
|
if self.consumed > 0 {
|
||||||
|
// Find the nearest character boundary at or before consumed
|
||||||
|
let mut safe_consumed = self.consumed;
|
||||||
|
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
|
||||||
|
safe_consumed -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if safe_consumed > 0 {
|
||||||
|
self.buffer.drain(..safe_consumed);
|
||||||
|
self.consumed = self.consumed.saturating_sub(safe_consumed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add completed tool
|
||||||
|
pub fn add_completed_tool(&mut self, tool: ToolCall) {
|
||||||
|
self.completed_tools.push(tool);
|
||||||
|
self.tool_index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ParseState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
249
sgl-router/src/tool_parser/tests.rs
Normal file
249
sgl-router/src/tool_parser/tests.rs
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
use super::*;
|
||||||
|
use crate::tool_parser::partial_json::{
|
||||||
|
compute_diff, find_common_prefix, is_complete_json, PartialJson,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_state_new() {
|
||||||
|
let state = ParseState::new();
|
||||||
|
assert_eq!(state.phase, ParsePhase::Searching);
|
||||||
|
assert_eq!(state.buffer, "");
|
||||||
|
assert_eq!(state.consumed, 0);
|
||||||
|
assert_eq!(state.bracket_depth, 0);
|
||||||
|
assert!(!state.in_string);
|
||||||
|
assert!(!state.escape_next);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_state_process_char() {
|
||||||
|
let mut state = ParseState::new();
|
||||||
|
|
||||||
|
// Test bracket tracking
|
||||||
|
state.process_char('{');
|
||||||
|
assert_eq!(state.bracket_depth, 1);
|
||||||
|
|
||||||
|
state.process_char('}');
|
||||||
|
assert_eq!(state.bracket_depth, 0);
|
||||||
|
|
||||||
|
// Test string tracking
|
||||||
|
state.process_char('"');
|
||||||
|
assert!(state.in_string);
|
||||||
|
|
||||||
|
state.process_char('"');
|
||||||
|
assert!(!state.in_string);
|
||||||
|
|
||||||
|
// Test escape handling
|
||||||
|
state.process_char('"');
|
||||||
|
state.process_char('\\');
|
||||||
|
assert!(state.escape_next);
|
||||||
|
|
||||||
|
state.process_char('"');
|
||||||
|
assert!(!state.escape_next);
|
||||||
|
assert!(state.in_string); // Still in string because quote was escaped
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_config() {
|
||||||
|
let config = TokenConfig {
|
||||||
|
start_tokens: vec!["<start>".to_string(), "[".to_string()],
|
||||||
|
end_tokens: vec!["</end>".to_string(), "]".to_string()],
|
||||||
|
separator: ", ".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let pairs: Vec<_> = config.iter_pairs().collect();
|
||||||
|
assert_eq!(pairs.len(), 2);
|
||||||
|
assert_eq!(pairs[0], ("<start>", "</end>"));
|
||||||
|
assert_eq!(pairs[1], ("[", "]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parser_registry() {
|
||||||
|
let registry = ParserRegistry::new();
|
||||||
|
|
||||||
|
// Test has default mappings
|
||||||
|
assert!(!registry.list_mappings().is_empty());
|
||||||
|
|
||||||
|
// Test model pattern matching
|
||||||
|
let mappings = registry.list_mappings();
|
||||||
|
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
||||||
|
assert!(has_gpt);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parser_registry_pattern_matching() {
|
||||||
|
let mut registry = ParserRegistry::new();
|
||||||
|
|
||||||
|
// Test that model mappings work by checking the list
|
||||||
|
registry.map_model("test-model", "json");
|
||||||
|
|
||||||
|
// Verify through list_mappings
|
||||||
|
let mappings = registry.list_mappings();
|
||||||
|
let has_test = mappings
|
||||||
|
.iter()
|
||||||
|
.any(|(m, p)| *m == "test-model" && *p == "json");
|
||||||
|
assert!(has_test);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_call_serialization() {
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
id: "call-123".to_string(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: "search".to_string(),
|
||||||
|
arguments: r#"{"query": "rust programming"}"#.to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&tool_call).unwrap();
|
||||||
|
assert!(json.contains("call-123"));
|
||||||
|
assert!(json.contains("search"));
|
||||||
|
assert!(json.contains("rust programming"));
|
||||||
|
|
||||||
|
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(parsed.id, "call-123");
|
||||||
|
assert_eq!(parsed.function.name, "search");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_partial_json_parser() {
|
||||||
|
let parser = PartialJson::default();
|
||||||
|
|
||||||
|
// Test complete JSON
|
||||||
|
let input = r#"{"name": "test", "value": 42}"#;
|
||||||
|
let (value, consumed) = parser.parse_value(input).unwrap();
|
||||||
|
assert_eq!(value["name"], "test");
|
||||||
|
assert_eq!(value["value"], 42);
|
||||||
|
assert_eq!(consumed, input.len());
|
||||||
|
|
||||||
|
// Test incomplete JSON object
|
||||||
|
let input = r#"{"name": "test", "value": "#;
|
||||||
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
|
assert_eq!(value["name"], "test");
|
||||||
|
assert!(value["value"].is_null());
|
||||||
|
|
||||||
|
// Test incomplete string
|
||||||
|
let input = r#"{"name": "tes"#;
|
||||||
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
|
assert_eq!(value["name"], "tes");
|
||||||
|
|
||||||
|
// Test incomplete array
|
||||||
|
let input = r#"[1, 2, "#;
|
||||||
|
let (value, _consumed) = parser.parse_value(input).unwrap();
|
||||||
|
assert!(value.is_array());
|
||||||
|
assert_eq!(value[0], 1);
|
||||||
|
assert_eq!(value[1], 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_partial_json_depth_limit() {
|
||||||
|
// max_depth of 3 allows nesting up to 3 levels
|
||||||
|
// Set allow_incomplete to false to get errors instead of partial results
|
||||||
|
let parser = PartialJson::new(3, false);
|
||||||
|
|
||||||
|
// This should work (simple object)
|
||||||
|
let input = r#"{"a": 1}"#;
|
||||||
|
let result = parser.parse_value(input);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// This should work (nested to depth 3)
|
||||||
|
let input = r#"{"a": {"b": {"c": 1}}}"#;
|
||||||
|
let result = parser.parse_value(input);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// This should fail (nested to depth 4, exceeds limit)
|
||||||
|
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
|
||||||
|
let result = parser.parse_value(input);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_complete_json() {
|
||||||
|
assert!(is_complete_json(r#"{"name": "test"}"#));
|
||||||
|
assert!(is_complete_json(r#"[1, 2, 3]"#));
|
||||||
|
assert!(is_complete_json(r#""string""#));
|
||||||
|
assert!(is_complete_json("42"));
|
||||||
|
assert!(is_complete_json("true"));
|
||||||
|
assert!(is_complete_json("null"));
|
||||||
|
|
||||||
|
assert!(!is_complete_json(r#"{"name": "#));
|
||||||
|
assert!(!is_complete_json(r#"[1, 2, "#));
|
||||||
|
assert!(!is_complete_json(r#""unclosed"#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_common_prefix() {
|
||||||
|
assert_eq!(find_common_prefix("hello", "hello"), 5);
|
||||||
|
assert_eq!(find_common_prefix("hello", "help"), 3);
|
||||||
|
assert_eq!(find_common_prefix("hello", "world"), 0);
|
||||||
|
assert_eq!(find_common_prefix("", "hello"), 0);
|
||||||
|
assert_eq!(find_common_prefix("hello", ""), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compute_diff() {
|
||||||
|
assert_eq!(compute_diff("hello", "hello world"), " world");
|
||||||
|
assert_eq!(compute_diff("", "hello"), "hello");
|
||||||
|
assert_eq!(compute_diff("hello", "hello"), "");
|
||||||
|
assert_eq!(compute_diff("test", "hello"), "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_result_variants() {
|
||||||
|
// Test Incomplete
|
||||||
|
let result = StreamResult::Incomplete;
|
||||||
|
matches!(result, StreamResult::Incomplete);
|
||||||
|
|
||||||
|
// Test ToolName
|
||||||
|
let result = StreamResult::ToolName {
|
||||||
|
index: 0,
|
||||||
|
name: "test".to_string(),
|
||||||
|
};
|
||||||
|
if let StreamResult::ToolName { index, name } = result {
|
||||||
|
assert_eq!(index, 0);
|
||||||
|
assert_eq!(name, "test");
|
||||||
|
} else {
|
||||||
|
panic!("Expected ToolName variant");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ToolComplete
|
||||||
|
let tool = ToolCall {
|
||||||
|
id: "123".to_string(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: "test".to_string(),
|
||||||
|
arguments: "{}".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let result = StreamResult::ToolComplete(tool.clone());
|
||||||
|
if let StreamResult::ToolComplete(t) = result {
|
||||||
|
assert_eq!(t.id, "123");
|
||||||
|
} else {
|
||||||
|
panic!("Expected ToolComplete variant");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_partial_tool_call() {
|
||||||
|
let mut partial = PartialToolCall {
|
||||||
|
name: None,
|
||||||
|
arguments_buffer: String::new(),
|
||||||
|
start_position: 0,
|
||||||
|
name_sent: false,
|
||||||
|
streamed_args: String::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set name
|
||||||
|
partial.name = Some("test_function".to_string());
|
||||||
|
assert_eq!(partial.name.as_ref().unwrap(), "test_function");
|
||||||
|
|
||||||
|
// Append arguments
|
||||||
|
partial.arguments_buffer.push_str(r#"{"key": "value"}"#);
|
||||||
|
assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#);
|
||||||
|
|
||||||
|
// Update streaming state
|
||||||
|
partial.name_sent = true;
|
||||||
|
partial.streamed_args = r#"{"key": "#.to_string();
|
||||||
|
assert!(partial.name_sent);
|
||||||
|
assert_eq!(partial.streamed_args, r#"{"key": "#);
|
||||||
|
}
|
||||||
35
sgl-router/src/tool_parser/traits.rs
Normal file
35
sgl-router/src/tool_parser/traits.rs
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
use crate::tool_parser::{
|
||||||
|
errors::ToolParserResult,
|
||||||
|
state::ParseState,
|
||||||
|
types::{StreamResult, ToolCall},
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// Core trait for all tool parsers
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ToolParser: Send + Sync {
|
||||||
|
/// Parse complete tool calls from final output
|
||||||
|
async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
|
||||||
|
|
||||||
|
/// Parse tool calls from model output (streaming)
|
||||||
|
async fn parse_incremental(
|
||||||
|
&self,
|
||||||
|
chunk: &str,
|
||||||
|
state: &mut ParseState,
|
||||||
|
) -> ToolParserResult<StreamResult>;
|
||||||
|
|
||||||
|
/// Check if text contains tool calls in this parser's format
|
||||||
|
fn detect_format(&self, text: &str) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for partial JSON parsing
|
||||||
|
pub trait PartialJsonParser: Send + Sync {
|
||||||
|
/// Parse potentially incomplete JSON
|
||||||
|
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
|
||||||
|
|
||||||
|
/// Check if JSON is complete
|
||||||
|
fn is_complete(&self, input: &str) -> bool;
|
||||||
|
|
||||||
|
/// Get the maximum parsing depth
|
||||||
|
fn max_depth(&self) -> usize;
|
||||||
|
}
|
||||||
73
sgl-router/src/tool_parser/types.rs
Normal file
73
sgl-router/src/tool_parser/types.rs
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Parsed tool call from model output (OpenAI format)
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
/// Unique identifier for the tool call
|
||||||
|
pub id: String,
|
||||||
|
/// Type of tool call (currently always "function")
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub r#type: String,
|
||||||
|
/// Function call details
|
||||||
|
pub function: FunctionCall,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Function call within a tool call
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct FunctionCall {
|
||||||
|
/// Name of the function to call
|
||||||
|
pub name: String,
|
||||||
|
/// Arguments as JSON string
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Streaming parse result
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StreamResult {
|
||||||
|
/// Need more data to continue parsing
|
||||||
|
Incomplete,
|
||||||
|
/// Found a tool name (for streaming)
|
||||||
|
ToolName { index: usize, name: String },
|
||||||
|
/// Found incremental arguments (for streaming)
|
||||||
|
ToolArguments { index: usize, arguments: String },
|
||||||
|
/// Completed parsing a tool
|
||||||
|
ToolComplete(ToolCall),
|
||||||
|
/// Normal text (not part of tool call)
|
||||||
|
NormalText(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token configuration for parsing
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TokenConfig {
|
||||||
|
/// Start tokens for tool calls
|
||||||
|
pub start_tokens: Vec<String>,
|
||||||
|
/// End tokens for tool calls
|
||||||
|
pub end_tokens: Vec<String>,
|
||||||
|
/// Separator between multiple tool calls
|
||||||
|
pub separator: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenConfig {
|
||||||
|
/// Iterate over start/end token pairs
|
||||||
|
pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
|
||||||
|
self.start_tokens
|
||||||
|
.iter()
|
||||||
|
.zip(self.end_tokens.iter())
|
||||||
|
.map(|(s, e)| (s.as_str(), e.as_str()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simple partial tool call for streaming
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PartialToolCall {
|
||||||
|
/// Tool name (if parsed)
|
||||||
|
pub name: Option<String>,
|
||||||
|
/// Buffer for accumulating arguments
|
||||||
|
pub arguments_buffer: String,
|
||||||
|
/// Start position in the input buffer
|
||||||
|
pub start_position: usize,
|
||||||
|
/// Whether the name has been sent (for streaming)
|
||||||
|
pub name_sent: bool,
|
||||||
|
/// Arguments already streamed
|
||||||
|
pub streamed_args: String,
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user