[router][grpc] Support tool call parser in streaming (#11160)
This commit is contained in:
319
sgl-router/src/tool_parser/factory.rs
Normal file
319
sgl-router/src/tool_parser/factory.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
// Factory and pool for creating model-specific tool parsers with pooling support.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tool_parser::parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
|
||||
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
|
||||
|
||||
/// Type alias for parser creator functions.
|
||||
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
|
||||
|
||||
/// Registry for model-specific tool parsers with pooling support.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolParserRegistry {
|
||||
/// Creator functions for parsers (used when pool is empty)
|
||||
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
|
||||
/// Pooled parser instances for reuse
|
||||
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
|
||||
/// Model pattern to parser name mappings
|
||||
model_mapping: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Default parser name
|
||||
default_parser: Arc<RwLock<String>>,
|
||||
}
|
||||
|
||||
impl ToolParserRegistry {
|
||||
/// Create a new empty registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
creators: Arc::new(RwLock::new(HashMap::new())),
|
||||
pool: Arc::new(RwLock::new(HashMap::new())),
|
||||
model_mapping: Arc::new(RwLock::new(HashMap::new())),
|
||||
default_parser: Arc::new(RwLock::new("json".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a parser creator for a given parser type.
|
||||
pub fn register_parser<F>(&self, name: &str, creator: F)
|
||||
where
|
||||
F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
|
||||
{
|
||||
let mut creators = self.creators.write().unwrap();
|
||||
creators.insert(name.to_string(), Arc::new(creator));
|
||||
}
|
||||
|
||||
/// Map a model name/pattern to a parser
|
||||
pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
|
||||
let mut mapping = self.model_mapping.write().unwrap();
|
||||
mapping.insert(model.into(), parser.into());
|
||||
}
|
||||
|
||||
/// Get a pooled parser by exact name.
|
||||
/// Returns a shared parser instance from the pool, creating one if needed.
|
||||
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
|
||||
// First check if we have a pooled instance
|
||||
{
|
||||
let pool = self.pool.read().unwrap();
|
||||
if let Some(parser) = pool.get(name) {
|
||||
return Some(Arc::clone(parser));
|
||||
}
|
||||
}
|
||||
|
||||
// If not in pool, create one and add to pool
|
||||
let creators = self.creators.read().unwrap();
|
||||
if let Some(creator) = creators.get(name) {
|
||||
let parser = Arc::new(Mutex::new(creator()));
|
||||
|
||||
// Add to pool for future use
|
||||
let mut pool = self.pool.write().unwrap();
|
||||
pool.insert(name.to_string(), Arc::clone(&parser));
|
||||
|
||||
Some(parser)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get parser for a specific model
|
||||
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
|
||||
// Try exact match first
|
||||
{
|
||||
let mapping = self.model_mapping.read().unwrap();
|
||||
if let Some(parser_name) = mapping.get(model) {
|
||||
if let Some(parser) = self.get_pooled_parser(parser_name) {
|
||||
return Some(parser);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching with more specific patterns first
|
||||
let model_mapping = self.model_mapping.read().unwrap();
|
||||
let best_match = model_mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
|
||||
})
|
||||
.max_by_key(|(pattern, _)| pattern.len());
|
||||
|
||||
// Return the best matching parser
|
||||
if let Some((_, parser_name)) = best_match {
|
||||
if let Some(parser) = self.get_pooled_parser(parser_name) {
|
||||
return Some(parser);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default parser
|
||||
let default = self.default_parser.read().unwrap().clone();
|
||||
self.get_pooled_parser(&default)
|
||||
}
|
||||
|
||||
/// Clear the parser pool, forcing new instances to be created.
|
||||
pub fn clear_pool(&self) {
|
||||
let mut pool = self.pool.write().unwrap();
|
||||
pool.clear();
|
||||
}
|
||||
|
||||
/// Set the default parser
|
||||
pub fn set_default_parser(&self, name: impl Into<String>) {
|
||||
let mut default = self.default_parser.write().unwrap();
|
||||
*default = name.into();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolParserRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory for creating tool parsers based on model type.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolParserFactory {
|
||||
registry: ToolParserRegistry,
|
||||
}
|
||||
|
||||
impl ToolParserFactory {
|
||||
/// Create a new factory with default parsers registered.
|
||||
pub fn new() -> Self {
|
||||
let registry = ToolParserRegistry::new();
|
||||
|
||||
// Register default parsers
|
||||
registry.register_parser("json", || Box::new(JsonParser::new()));
|
||||
registry.register_parser("mistral", || Box::new(MistralParser::new()));
|
||||
registry.register_parser("qwen", || Box::new(QwenParser::new()));
|
||||
registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
|
||||
registry.register_parser("llama", || Box::new(LlamaParser::new()));
|
||||
registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
|
||||
registry.register_parser("glm4_moe", || Box::new(Glm4MoeParser::new()));
|
||||
registry.register_parser("step3", || Box::new(Step3Parser::new()));
|
||||
registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));
|
||||
|
||||
// Register GPT-OSS parsers
|
||||
registry.register_parser("gpt_oss_legacy", || Box::new(GptOssParser::new()));
|
||||
registry.register_parser("gpt_oss_harmony", || Box::new(GptOssHarmonyParser::new()));
|
||||
|
||||
// Choose which GPT-OSS variant to use as default
|
||||
if use_harmony_gpt_oss() {
|
||||
registry.register_parser("gpt_oss", || Box::new(GptOssHarmonyParser::new()));
|
||||
} else {
|
||||
registry.register_parser("gpt_oss", || Box::new(GptOssParser::new()));
|
||||
}
|
||||
|
||||
// Register default model mappings
|
||||
Self::register_default_mappings(®istry);
|
||||
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
fn register_default_mappings(registry: &ToolParserRegistry) {
|
||||
// OpenAI models
|
||||
registry.map_model("gpt-4*", "json");
|
||||
registry.map_model("gpt-3.5*", "json");
|
||||
registry.map_model("gpt-4o*", "json");
|
||||
|
||||
// Anthropic models
|
||||
registry.map_model("claude-*", "json");
|
||||
|
||||
// Mistral models
|
||||
registry.map_model("mistral-*", "mistral");
|
||||
registry.map_model("mixtral-*", "mistral");
|
||||
|
||||
// Qwen models
|
||||
registry.map_model("qwen*", "qwen");
|
||||
registry.map_model("Qwen*", "qwen");
|
||||
|
||||
// Llama models
|
||||
registry.map_model("llama-4*", "pythonic");
|
||||
registry.map_model("meta-llama-4*", "pythonic");
|
||||
registry.map_model("llama-3.2*", "llama");
|
||||
registry.map_model("meta-llama-3.2*", "llama");
|
||||
registry.map_model("llama-*", "json");
|
||||
registry.map_model("meta-llama-*", "json");
|
||||
|
||||
// DeepSeek models
|
||||
registry.map_model("deepseek-v3*", "deepseek");
|
||||
registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
|
||||
registry.map_model("deepseek-*", "pythonic");
|
||||
|
||||
// GLM models
|
||||
registry.map_model("glm-4.5*", "glm4_moe");
|
||||
registry.map_model("glm-4.6*", "glm4_moe");
|
||||
registry.map_model("glm-*", "json");
|
||||
|
||||
// Step3 models
|
||||
registry.map_model("step3*", "step3");
|
||||
registry.map_model("Step-3*", "step3");
|
||||
|
||||
// Kimi models
|
||||
registry.map_model("kimi-k2*", "kimik2");
|
||||
registry.map_model("Kimi-K2*", "kimik2");
|
||||
registry.map_model("moonshot*/Kimi-K2*", "kimik2");
|
||||
|
||||
// GPT-OSS models
|
||||
registry.map_model("gpt-oss*", "gpt_oss");
|
||||
registry.map_model("t4-*", "gpt_oss");
|
||||
|
||||
// Other models
|
||||
registry.map_model("gemini-*", "json");
|
||||
registry.map_model("palm-*", "json");
|
||||
registry.map_model("gemma-*", "json");
|
||||
}
|
||||
|
||||
/// Get a pooled parser for the given model ID.
|
||||
/// Returns a shared instance that can be used concurrently.
|
||||
/// Falls back to JSON parser if model is not recognized.
|
||||
pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
|
||||
self.registry
|
||||
.get_pooled_for_model(model_id)
|
||||
.unwrap_or_else(|| {
|
||||
// Fallback to JSON parser
|
||||
self.registry
|
||||
.get_pooled_parser("json")
|
||||
.expect("JSON parser should always be registered")
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the internal registry for custom registration.
|
||||
pub fn registry(&self) -> &ToolParserRegistry {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
/// Clear the parser pool.
|
||||
pub fn clear_pool(&self) {
|
||||
self.registry.clear_pool();
|
||||
}
|
||||
|
||||
/// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
|
||||
/// This is useful for benchmarks and testing where you want independent parser instances.
|
||||
pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
|
||||
// Determine which parser type to use
|
||||
let parser_type = {
|
||||
let mapping = self.registry.model_mapping.read().unwrap();
|
||||
|
||||
// Try exact match first
|
||||
if let Some(parser_name) = mapping.get(model_id) {
|
||||
parser_name.clone()
|
||||
} else {
|
||||
// Try prefix matching
|
||||
let best_match = mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
pattern.ends_with('*')
|
||||
&& model_id.starts_with(&pattern[..pattern.len() - 1])
|
||||
})
|
||||
.max_by_key(|(pattern, _)| pattern.len());
|
||||
|
||||
if let Some((_, parser_name)) = best_match {
|
||||
parser_name.clone()
|
||||
} else {
|
||||
// Fall back to default
|
||||
self.registry.default_parser.read().unwrap().clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let creators = self.registry.creators.read().unwrap();
|
||||
creators.get(&parser_type).map(|creator| {
|
||||
// Call the creator to get a Box<dyn ToolParser>, then convert to Arc
|
||||
let boxed_parser = creator();
|
||||
Arc::from(boxed_parser)
|
||||
})
|
||||
}
|
||||
|
||||
/// List all registered parsers (for compatibility with old API).
|
||||
pub fn list_parsers(&self) -> Vec<String> {
|
||||
self.registry
|
||||
.creators
|
||||
.read()
|
||||
.unwrap()
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolParserFactory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn use_harmony_gpt_oss() -> bool {
|
||||
std::env::var("ROUTER_USE_HARMONY_GPT_OSS")
|
||||
.ok()
|
||||
.map(|value| {
|
||||
let normalized = value.trim();
|
||||
matches!(
|
||||
normalized,
|
||||
"1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes" | "on" | "ON" | "On"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
/// This module provides infrastructure for parsing tool calls from various model formats.
|
||||
// Core modules
|
||||
pub mod errors;
|
||||
pub mod factory;
|
||||
pub mod partial_json;
|
||||
pub mod registry;
|
||||
pub mod state;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
@@ -17,10 +17,9 @@ mod tests;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use errors::{ToolParserError, ToolParserResult};
|
||||
pub use registry::ParserRegistry;
|
||||
pub use state::{ParsePhase, ParseState};
|
||||
pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry};
|
||||
pub use traits::{PartialJsonParser, ToolParser};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamResult, ToolCall};
|
||||
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
|
||||
|
||||
// Re-export parsers for convenience
|
||||
pub use parsers::{
|
||||
|
||||
@@ -2,12 +2,13 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
/// DeepSeek V3 format parser for tool calls
|
||||
@@ -20,12 +21,29 @@ use crate::tool_parser::{
|
||||
/// - JSON arguments in code blocks
|
||||
/// - Support for multiple sequential tool calls
|
||||
pub struct DeepSeekParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting complete tool calls
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting function details
|
||||
func_detail_extractor: Regex,
|
||||
/// Regex for matching partial tool calls during streaming
|
||||
partial_tool_call_regex: Regex,
|
||||
/// Regex pattern for removing completed tool calls from buffer
|
||||
tool_call_end_pattern: Regex,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
}
|
||||
|
||||
impl DeepSeekParser {
|
||||
@@ -38,10 +56,24 @@ impl DeepSeekParser {
|
||||
let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
|
||||
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Partial pattern for streaming - uses .* (greedy) not .*? to match all partial content
|
||||
let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
|
||||
let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for removing completed tool calls
|
||||
let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
|
||||
let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
tool_call_extractor,
|
||||
func_detail_extractor,
|
||||
partial_tool_call_regex,
|
||||
tool_call_end_pattern,
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,107 +175,146 @@ impl ToolParser for DeepSeekParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// Check if we have a tool call (either the start token or individual tool call)
|
||||
let has_tool_call =
|
||||
self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
|
||||
|
||||
if !has_tool_call {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
// Strip out end tokens if present
|
||||
let mut normal_text = std::mem::take(&mut self.buffer);
|
||||
for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
|
||||
normal_text = normal_text.replace(e_token, "");
|
||||
}
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Look for start of tool calls
|
||||
if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
|
||||
// Look for individual tool call start
|
||||
let search_from = start_pos + "<|tool▁calls▁begin|>".len();
|
||||
if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
|
||||
{
|
||||
let call_start_abs = search_from + call_start;
|
||||
// Build tool indices for validation
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
// Look for the end of this tool call
|
||||
let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
|
||||
if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
|
||||
{
|
||||
let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();
|
||||
let mut calls: Vec<ToolCallItem> = Vec::new();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
|
||||
// Try to match the partial tool call pattern
|
||||
if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
|
||||
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
|
||||
let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
match self.parse_tool_call(tool_call_text) {
|
||||
Ok(tool) => {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
Err(_) => {
|
||||
// Parsing failed, skip this tool call
|
||||
state.buffer.drain(..call_end_abs);
|
||||
}
|
||||
// Validate tool name
|
||||
if !tool_indices.contains_key(func_name) {
|
||||
// Invalid tool name - skip this tool, preserve indexing for next tool
|
||||
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
|
||||
helpers::reset_current_tool_state(
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&self.prev_tool_call_arr,
|
||||
);
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
// Initialize state if this is the first tool call
|
||||
if self.current_tool_id == -1 {
|
||||
self.current_tool_id = 0;
|
||||
self.prev_tool_call_arr = Vec::new();
|
||||
self.streamed_args_for_tool = vec![String::new()];
|
||||
}
|
||||
|
||||
// Ensure we have enough entries in our tracking arrays
|
||||
helpers::ensure_capacity(
|
||||
self.current_tool_id,
|
||||
&mut self.prev_tool_call_arr,
|
||||
&mut self.streamed_args_for_tool,
|
||||
);
|
||||
|
||||
// Send tool name if not sent yet
|
||||
if !self.current_tool_name_sent {
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: Some(func_name.to_string()),
|
||||
parameters: String::new(),
|
||||
});
|
||||
self.current_tool_name_sent = true;
|
||||
|
||||
// Store the tool call info for serving layer completions endpoint
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if self.prev_tool_call_arr.len() <= tool_id {
|
||||
self.prev_tool_call_arr
|
||||
.resize_with(tool_id + 1, || Value::Null);
|
||||
}
|
||||
self.prev_tool_call_arr[tool_id] = serde_json::json!({
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
});
|
||||
} else {
|
||||
// Compute incremental diff
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
let last_sent = self
|
||||
.streamed_args_for_tool
|
||||
.get(tool_id)
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let argument_diff = func_args_raw
|
||||
.strip_prefix(last_sent)
|
||||
.unwrap_or(func_args_raw);
|
||||
|
||||
if !argument_diff.is_empty() {
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: tool_id,
|
||||
name: None,
|
||||
parameters: argument_diff.to_string(),
|
||||
});
|
||||
if tool_id < self.streamed_args_for_tool.len() {
|
||||
self.streamed_args_for_tool[tool_id].push_str(argument_diff);
|
||||
}
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_end_from..];
|
||||
}
|
||||
|
||||
// Try to extract function name
|
||||
if let Some(sep_pos) = partial.find("<|tool▁sep|>") {
|
||||
if let Some(_func_start) = partial[..sep_pos].rfind("function") {
|
||||
// We have the function type marker
|
||||
let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..];
|
||||
|
||||
// Look for function name (ends at newline before ```json)
|
||||
if let Some(name_end) = after_sep.find("\n```json\n") {
|
||||
let func_name = after_sep[..name_end].trim();
|
||||
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial arguments
|
||||
let args_start = name_end + "\n```json\n".len();
|
||||
let partial_args = &after_sep[args_start..];
|
||||
|
||||
// Check if we can parse partial JSON
|
||||
if !partial_args.is_empty() {
|
||||
match self.partial_json.parse_value(partial_args) {
|
||||
Ok((value, _consumed)) => {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Can't parse yet, continue waiting for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check if JSON is complete
|
||||
if helpers::is_complete_json(func_args_raw) {
|
||||
// Update the stored arguments
|
||||
if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.prev_tool_call_arr.len() {
|
||||
if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
|
||||
obj.insert("arguments".to_string(), parsed_args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the end of the current tool call and remove only that part from buffer
|
||||
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
|
||||
// Remove the completed tool call from buffer, keep any remaining content
|
||||
self.buffer = current_text[mat.end()..].to_string();
|
||||
} else {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
let result = StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
};
|
||||
|
||||
self.current_tool_id += 1;
|
||||
self.current_tool_name_sent = false;
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -2,11 +2,13 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
/// GLM-4 MoE format parser for tool calls
|
||||
@@ -25,6 +27,22 @@ pub struct Glm4MoeParser {
|
||||
func_detail_extractor: Regex,
|
||||
/// Regex for extracting argument key-value pairs
|
||||
arg_extractor: Regex,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Token configuration
|
||||
bot_token: &'static str,
|
||||
eot_token: &'static str,
|
||||
}
|
||||
|
||||
impl Glm4MoeParser {
|
||||
@@ -44,12 +62,18 @@ impl Glm4MoeParser {
|
||||
tool_call_extractor,
|
||||
func_detail_extractor,
|
||||
arg_extractor,
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
bot_token: "<tool_call>",
|
||||
eot_token: "</tool_call>",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains GLM-4 MoE tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
text.contains(self.bot_token)
|
||||
}
|
||||
|
||||
/// Parse arguments from key-value pairs
|
||||
@@ -120,6 +144,25 @@ impl Glm4MoeParser {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
|
||||
/// Parse all tool calls from text (shared logic for complete and incremental parsing)
|
||||
fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
match self.parse_tool_call(mat.as_str()) {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Glm4MoeParser {
|
||||
@@ -140,18 +183,8 @@ impl ToolParser for Glm4MoeParser {
|
||||
let idx = text.find("<tool_call>").unwrap();
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
match self.parse_tool_call(mat.as_str()) {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Parse all tool calls using shared helper
|
||||
let tools = self.parse_tool_calls_from_text(text)?;
|
||||
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
@@ -162,78 +195,127 @@ impl ToolParser for Glm4MoeParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Python logic: Wait for complete tool call, then parse it all at once
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("<tool_call>") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
}
|
||||
|
||||
// Look for start of tool call
|
||||
if let Some(start_pos) = state.buffer.find("<tool_call>") {
|
||||
// Look for the end of this tool call
|
||||
let search_from = start_pos + "<tool_call>".len();
|
||||
if let Some(end_pos) = state.buffer[search_from..].find("</tool_call>") {
|
||||
let end_abs = search_from + end_pos + "</tool_call>".len();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[start_pos..end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
// Check if we have bot_token
|
||||
let start = current_text.find(self.bot_token);
|
||||
if start.is_none() {
|
||||
self.buffer.clear();
|
||||
// If we're in the middle of streaming (current_tool_id > 0), don't return text
|
||||
let normal_text = if self.current_tool_id > 0 {
|
||||
String::new()
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_from..];
|
||||
|
||||
// Try to extract function name (first line after <tool_call>)
|
||||
if let Some(name_end) = partial.find('\n') {
|
||||
let func_name = partial[..name_end].trim();
|
||||
|
||||
if !func_name.is_empty() && !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial arguments
|
||||
let args_text = &partial[name_end + 1..];
|
||||
let partial_args = self.parse_arguments(args_text)?;
|
||||
|
||||
if !partial_args.is_empty() {
|
||||
let args_str = serde_json::to_string(&partial_args)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
current_text.clone()
|
||||
};
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
// Check if we have eot_token (end of tool call)
|
||||
let end = current_text.find(self.eot_token);
|
||||
if let Some(end_pos) = end {
|
||||
// We have a complete tool call!
|
||||
|
||||
// Initialize state if this is the first tool call
|
||||
if self.current_tool_id == -1 {
|
||||
self.current_tool_id = 0;
|
||||
self.prev_tool_call_arr = Vec::new();
|
||||
self.streamed_args_for_tool = vec![String::new()];
|
||||
}
|
||||
|
||||
// Ensure we have enough entries in our tracking arrays
|
||||
helpers::ensure_capacity(
|
||||
self.current_tool_id,
|
||||
&mut self.prev_tool_call_arr,
|
||||
&mut self.streamed_args_for_tool,
|
||||
);
|
||||
|
||||
// Parse the complete block using shared helper
|
||||
let block_end = end_pos + self.eot_token.len();
|
||||
let parsed_tools = self.parse_tool_calls_from_text(¤t_text[..block_end])?;
|
||||
|
||||
// Extract normal text before tool calls
|
||||
let idx = current_text.find(self.bot_token);
|
||||
let normal_text = if let Some(pos) = idx {
|
||||
current_text[..pos].trim().to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Build tool indices for validation
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
let mut calls = Vec::new();
|
||||
|
||||
if !parsed_tools.is_empty() {
|
||||
// Take the first tool and convert to ToolCallItem
|
||||
let tool_call = &parsed_tools[0];
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
|
||||
// Validate tool name
|
||||
if !tool_indices.contains_key(&tool_call.function.name) {
|
||||
// Invalid tool name - skip this tool, preserve indexing for next tool
|
||||
tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name);
|
||||
helpers::reset_current_tool_state(
|
||||
&mut self.buffer,
|
||||
&mut false, // glm4_moe doesn't track name_sent per tool
|
||||
&mut self.streamed_args_for_tool,
|
||||
&self.prev_tool_call_arr,
|
||||
);
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: tool_id,
|
||||
name: Some(tool_call.function.name.clone()),
|
||||
parameters: tool_call.function.arguments.clone(),
|
||||
});
|
||||
|
||||
// Store in tracking arrays
|
||||
if self.prev_tool_call_arr.len() <= tool_id {
|
||||
self.prev_tool_call_arr
|
||||
.resize_with(tool_id + 1, || Value::Null);
|
||||
}
|
||||
|
||||
// Parse parameters as JSON and store
|
||||
if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
|
||||
self.prev_tool_call_arr[tool_id] = serde_json::json!({
|
||||
"name": tool_call.function.name,
|
||||
"arguments": args,
|
||||
});
|
||||
}
|
||||
|
||||
if self.streamed_args_for_tool.len() <= tool_id {
|
||||
self.streamed_args_for_tool
|
||||
.resize_with(tool_id + 1, String::new);
|
||||
}
|
||||
self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();
|
||||
|
||||
self.current_tool_id += 1;
|
||||
}
|
||||
|
||||
// Remove processed portion from buffer
|
||||
self.buffer = current_text[block_end..].to_string();
|
||||
return Ok(StreamingParseResult { normal_text, calls });
|
||||
}
|
||||
|
||||
// No complete tool call yet - return normal text before start token
|
||||
let start_pos = start.unwrap();
|
||||
let normal_text = current_text[..start_pos].to_string();
|
||||
self.buffer = current_text[start_pos..].to_string();
|
||||
|
||||
Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
traits::{TokenToolParser, ToolParser},
|
||||
types::{StreamResult, ToolCall},
|
||||
types::{StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Placeholder for the Harmony-backed GPT-OSS parser.
|
||||
@@ -29,12 +30,12 @@ impl ToolParser for GptOssHarmonyParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
_chunk: &str,
|
||||
_state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
_tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Temporary stub until the Harmony streaming pipeline is implemented.
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(StreamingParseResult::default())
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
@@ -61,10 +62,10 @@ impl TokenToolParser for GptOssHarmonyParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental_tokens(
|
||||
&self,
|
||||
&mut self,
|
||||
_tokens: &[u32],
|
||||
_state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
Ok(StreamResult::Incomplete)
|
||||
_tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
Ok(StreamingParseResult::default())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
/// GPT-OSS format parser for tool calls
|
||||
@@ -26,6 +28,11 @@ pub struct GptOssParser {
|
||||
function_call_extractor: Regex,
|
||||
/// Regex for extracting streaming function calls
|
||||
streaming_extractor: Regex,
|
||||
|
||||
/// Buffer for accumulating chunks
|
||||
buffer: String,
|
||||
/// Whether the tool name has been sent (for streaming)
|
||||
name_sent: bool,
|
||||
}
|
||||
|
||||
impl GptOssParser {
|
||||
@@ -45,6 +52,9 @@ impl GptOssParser {
|
||||
partial_json: PartialJson::default(),
|
||||
function_call_extractor,
|
||||
streaming_extractor,
|
||||
|
||||
buffer: String::new(),
|
||||
name_sent: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,21 +133,21 @@ impl ToolParser for GptOssParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
if !self.has_tool_markers(&self.buffer) {
|
||||
// No markers found, clear buffer and return
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::Incomplete);
|
||||
self.buffer.clear();
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
// Try to match streaming pattern
|
||||
if let Some(captures) = self.streaming_extractor.captures(&state.buffer) {
|
||||
if let Some(captures) = self.streaming_extractor.captures(&self.buffer) {
|
||||
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
|
||||
let full_function_name = name_match.as_str();
|
||||
let partial_args = args_match.as_str();
|
||||
@@ -146,16 +156,30 @@ impl ToolParser for GptOssParser {
|
||||
let function_name = self.extract_function_name(full_function_name);
|
||||
|
||||
// Send function name if not sent yet
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: function_name.clone(),
|
||||
if !self.name_sent {
|
||||
// Validate tool name
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
if !tool_indices.contains_key(&function_name) {
|
||||
// Invalid tool name - skip
|
||||
tracing::warn!("Invalid tool name '{}' - skipping", function_name);
|
||||
self.buffer.clear();
|
||||
self.name_sent = false;
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
self.name_sent = true; // Mark name as sent
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls: vec![ToolCallItem {
|
||||
tool_index: 0,
|
||||
name: Some(function_name.clone()),
|
||||
parameters: String::new(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
|
||||
// Check if we have a complete function call
|
||||
if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) {
|
||||
if let Some(complete_match) = self.function_call_extractor.captures(&self.buffer) {
|
||||
if let Some(args_match) = complete_match.get(2) {
|
||||
let args_content = args_match.as_str().trim();
|
||||
|
||||
@@ -170,26 +194,22 @@ impl ToolParser for GptOssParser {
|
||||
}
|
||||
};
|
||||
|
||||
// Generate unique ID
|
||||
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
let tool = ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: function_name,
|
||||
arguments,
|
||||
},
|
||||
};
|
||||
|
||||
// Remove the processed part from buffer
|
||||
let complete_end = complete_match.get(0).unwrap().end();
|
||||
state.buffer.drain(..complete_end);
|
||||
self.buffer.drain(..complete_end);
|
||||
|
||||
// Reset state for next tool
|
||||
state.in_string = false;
|
||||
self.name_sent = false;
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
// Return final arguments
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls: vec![ToolCallItem {
|
||||
tool_index: 0,
|
||||
name: None,
|
||||
parameters: arguments,
|
||||
}],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Try to parse partial JSON for streaming arguments
|
||||
@@ -206,9 +226,13 @@ impl ToolParser for GptOssParser {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls: vec![ToolCallItem {
|
||||
tool_index: 0,
|
||||
name: None,
|
||||
parameters: args_str,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
@@ -220,7 +244,7 @@ impl ToolParser for GptOssParser {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(StreamingParseResult::default())
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
398
sgl-router/src/tool_parser/parsers/helpers.rs
Normal file
398
sgl-router/src/tool_parser/parsers/helpers.rs
Normal file
@@ -0,0 +1,398 @@
|
||||
use crate::protocols::spec::Tool;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
|
||||
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
|
||||
|
||||
/// Get a mapping of tool names to their indices
|
||||
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
|
||||
tools
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, tool)| (tool.function.name.clone(), i))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a buffer ends with a partial occurrence of a token
|
||||
/// Returns Some(length) if there's a partial match, None otherwise
|
||||
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
|
||||
if buffer.is_empty() || token.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
(1..token.len()).find(|&i| buffer.ends_with(&token[..i]))
|
||||
}
|
||||
|
||||
/// Reset state for the current tool being parsed (used when skipping invalid tools).
|
||||
/// This preserves the parser's overall state (current_tool_id, prev_tool_call_arr)
|
||||
/// but clears the state specific to the current incomplete tool.
|
||||
pub fn reset_current_tool_state(
|
||||
buffer: &mut String,
|
||||
current_tool_name_sent: &mut bool,
|
||||
streamed_args_for_tool: &mut Vec<String>,
|
||||
prev_tool_call_arr: &[Value],
|
||||
) {
|
||||
buffer.clear();
|
||||
*current_tool_name_sent = false;
|
||||
|
||||
// Only pop if we added an entry for the current (invalid) tool
|
||||
// streamed_args_for_tool should match prev_tool_call_arr length for completed tools
|
||||
if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
|
||||
streamed_args_for_tool.pop();
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the entire parser state (used at the start of a new request).
|
||||
/// Clears all accumulated tool calls and resets all state to initial values.
|
||||
pub fn reset_parser_state(
|
||||
buffer: &mut String,
|
||||
prev_tool_call_arr: &mut Vec<Value>,
|
||||
current_tool_id: &mut i32,
|
||||
current_tool_name_sent: &mut bool,
|
||||
streamed_args_for_tool: &mut Vec<String>,
|
||||
) {
|
||||
buffer.clear();
|
||||
prev_tool_call_arr.clear();
|
||||
*current_tool_id = 0;
|
||||
*current_tool_name_sent = false;
|
||||
streamed_args_for_tool.clear();
|
||||
}
|
||||
|
||||
/// Ensure arrays have capacity for the given tool ID
|
||||
pub fn ensure_capacity(
|
||||
current_tool_id: i32,
|
||||
prev_tool_call_arr: &mut Vec<Value>,
|
||||
streamed_args_for_tool: &mut Vec<String>,
|
||||
) {
|
||||
if current_tool_id < 0 {
|
||||
return;
|
||||
}
|
||||
let needed = (current_tool_id + 1) as usize;
|
||||
|
||||
if prev_tool_call_arr.len() < needed {
|
||||
prev_tool_call_arr.resize_with(needed, || Value::Null);
|
||||
}
|
||||
if streamed_args_for_tool.len() < needed {
|
||||
streamed_args_for_tool.resize_with(needed, String::new);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a string contains complete, valid JSON
|
||||
pub fn is_complete_json(input: &str) -> bool {
|
||||
serde_json::from_str::<Value>(input).is_ok()
|
||||
}
|
||||
|
||||
/// Normalize the arguments/parameters field in a tool call object.
|
||||
/// If the object has "parameters" but not "arguments", copy parameters to arguments.
|
||||
///
|
||||
/// # Background
|
||||
/// Different LLM formats use different field names:
|
||||
/// - Llama and JSON parsers use "parameters" (correct per JSON Schema spec)
|
||||
/// - Mistral and Qwen use "arguments"
|
||||
///
|
||||
/// This function normalizes to "arguments" for consistent downstream processing.
|
||||
pub fn normalize_arguments_field(mut obj: Value) -> Value {
|
||||
if obj.get("arguments").is_none() {
|
||||
if let Some(params) = obj.get("parameters").cloned() {
|
||||
if let Value::Object(ref mut map) = obj {
|
||||
map.insert("arguments".to_string(), params);
|
||||
}
|
||||
}
|
||||
}
|
||||
obj
|
||||
}
|
||||
|
||||
/// Handle the entire JSON tool call streaming process for JSON-based parsers.
|
||||
///
|
||||
/// This unified function handles all aspects of streaming tool calls:
|
||||
/// - Parsing partial JSON from the buffer
|
||||
/// - Validating tool names against available tools
|
||||
/// - Streaming tool names (Case 1)
|
||||
/// - Streaming tool arguments (Case 2)
|
||||
/// - Managing parser state and buffer updates
|
||||
///
|
||||
/// Used by JSON, Llama, Mistral, and Qwen parsers.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `current_text`: The current buffered text being parsed
|
||||
/// - `start_idx`: Start index of JSON content in current_text
|
||||
/// - `partial_json`: Mutable reference to partial JSON parser
|
||||
/// - `tool_indices`: Map of valid tool names to their indices
|
||||
/// - `buffer`: Mutable parser buffer
|
||||
/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
|
||||
/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
|
||||
/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
|
||||
/// - `prev_tool_call_arr`: Mutable array of previous tool call states
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(StreamingParseResult)` with any tool call items to stream
|
||||
/// - `Err(ToolParserError)` if JSON parsing or serialization fails
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn handle_json_tool_streaming(
|
||||
current_text: &str,
|
||||
start_idx: usize,
|
||||
partial_json: &mut crate::tool_parser::partial_json::PartialJson,
|
||||
tool_indices: &HashMap<String, usize>,
|
||||
buffer: &mut String,
|
||||
current_tool_id: &mut i32,
|
||||
current_tool_name_sent: &mut bool,
|
||||
streamed_args_for_tool: &mut Vec<String>,
|
||||
prev_tool_call_arr: &mut Vec<Value>,
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Check if we have content to parse
|
||||
if start_idx >= current_text.len() {
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
// Extract JSON string from current position
|
||||
let json_str = ¤t_text[start_idx..];
|
||||
|
||||
// Parse partial JSON
|
||||
let (obj, end_idx) = match partial_json.parse_value(json_str) {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
};
|
||||
|
||||
// Check if JSON is complete
|
||||
let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();
|
||||
|
||||
// Validate tool name if present
|
||||
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
|
||||
if !tool_indices.contains_key(name) {
|
||||
// Invalid tool name - skip this tool, preserve indexing for next tool
|
||||
tracing::warn!("Invalid tool name '{}' - skipping", name);
|
||||
reset_current_tool_state(
|
||||
buffer,
|
||||
current_tool_name_sent,
|
||||
streamed_args_for_tool,
|
||||
prev_tool_call_arr,
|
||||
);
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize parameters/arguments field
|
||||
let current_tool_call = normalize_arguments_field(obj);
|
||||
|
||||
let mut result = StreamingParseResult::default();
|
||||
|
||||
// Case 1: Handle tool name streaming
|
||||
if !*current_tool_name_sent {
|
||||
if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
|
||||
if tool_indices.contains_key(function_name) {
|
||||
// Initialize if first tool
|
||||
if *current_tool_id == -1 {
|
||||
*current_tool_id = 0;
|
||||
streamed_args_for_tool.push(String::new());
|
||||
} else if *current_tool_id as usize >= streamed_args_for_tool.len() {
|
||||
// Ensure capacity for subsequent tools
|
||||
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
|
||||
}
|
||||
|
||||
// Send tool name with empty parameters
|
||||
*current_tool_name_sent = true;
|
||||
result.calls.push(ToolCallItem {
|
||||
tool_index: *current_tool_id as usize,
|
||||
name: Some(function_name.to_string()),
|
||||
parameters: String::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Case 2: Handle streaming arguments
|
||||
else if let Some(cur_arguments) = current_tool_call.get("arguments") {
|
||||
let tool_id = *current_tool_id as usize;
|
||||
let sent = streamed_args_for_tool
|
||||
.get(tool_id)
|
||||
.map(|s| s.len())
|
||||
.unwrap_or(0);
|
||||
let cur_args_json = serde_json::to_string(cur_arguments)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Compute diff: everything after what we've already sent
|
||||
let diff = cur_args_json[sent..].to_string();
|
||||
|
||||
// Send diff if there's new content
|
||||
if !diff.is_empty() {
|
||||
// Only accumulate if not complete
|
||||
if !is_complete && tool_id < streamed_args_for_tool.len() {
|
||||
streamed_args_for_tool[tool_id].push_str(&diff);
|
||||
}
|
||||
|
||||
result.calls.push(ToolCallItem {
|
||||
tool_index: tool_id,
|
||||
name: None,
|
||||
parameters: diff,
|
||||
});
|
||||
}
|
||||
|
||||
// If JSON is complete, advance to next tool
|
||||
if is_complete {
|
||||
// Remove processed portion, keep unprocessed content
|
||||
*buffer = current_text[start_idx + end_idx..].to_string();
|
||||
|
||||
// Clear completed tool data
|
||||
if tool_id < prev_tool_call_arr.len() {
|
||||
prev_tool_call_arr[tool_id] = Value::Null;
|
||||
}
|
||||
*current_tool_name_sent = false;
|
||||
if tool_id < streamed_args_for_tool.len() {
|
||||
streamed_args_for_tool[tool_id].clear();
|
||||
}
|
||||
*current_tool_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Update prev_tool_call_arr with current state
|
||||
if *current_tool_id >= 0 {
|
||||
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
|
||||
let tool_id = *current_tool_id as usize;
|
||||
|
||||
if tool_id < prev_tool_call_arr.len() {
|
||||
prev_tool_call_arr[tool_id] = current_tool_call;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ends_with_partial_token() {
|
||||
assert!(ends_with_partial_token("hello <|py", "<|python_tag|>").is_some());
|
||||
assert!(ends_with_partial_token("hello <|python_tag", "<|python_tag|>").is_some());
|
||||
assert!(ends_with_partial_token("hello <|python_tag|>", "<|python_tag|>").is_none());
|
||||
assert!(ends_with_partial_token("", "<|python_tag|>").is_none());
|
||||
assert!(ends_with_partial_token("hello world", "<|python_tag|>").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_current_tool_state() {
|
||||
let mut buffer = String::from("partial json");
|
||||
let mut current_tool_name_sent = true;
|
||||
let mut streamed_args = vec!["tool0_args".to_string(), "tool1_partial".to_string()];
|
||||
let prev_tools = vec![serde_json::json!({"name": "tool0"})];
|
||||
|
||||
reset_current_tool_state(
|
||||
&mut buffer,
|
||||
&mut current_tool_name_sent,
|
||||
&mut streamed_args,
|
||||
&prev_tools,
|
||||
);
|
||||
|
||||
assert_eq!(buffer, "");
|
||||
assert!(!current_tool_name_sent);
|
||||
assert_eq!(streamed_args.len(), 1); // Popped the partial tool1 args
|
||||
assert_eq!(streamed_args[0], "tool0_args");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_current_tool_state_no_pop_when_synced() {
|
||||
let mut buffer = String::from("partial json");
|
||||
let mut current_tool_name_sent = true;
|
||||
let mut streamed_args = vec!["tool0_args".to_string()];
|
||||
let prev_tools = vec![serde_json::json!({"name": "tool0"})];
|
||||
|
||||
reset_current_tool_state(
|
||||
&mut buffer,
|
||||
&mut current_tool_name_sent,
|
||||
&mut streamed_args,
|
||||
&prev_tools,
|
||||
);
|
||||
|
||||
assert_eq!(buffer, "");
|
||||
assert!(!current_tool_name_sent);
|
||||
assert_eq!(streamed_args.len(), 1); // No pop, lengths matched
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_parser_state() {
|
||||
let mut buffer = String::from("some buffer");
|
||||
let mut prev_tools = vec![serde_json::json!({"name": "tool0"})];
|
||||
let mut current_tool_id = 5;
|
||||
let mut current_tool_name_sent = true;
|
||||
let mut streamed_args = vec!["args".to_string()];
|
||||
|
||||
reset_parser_state(
|
||||
&mut buffer,
|
||||
&mut prev_tools,
|
||||
&mut current_tool_id,
|
||||
&mut current_tool_name_sent,
|
||||
&mut streamed_args,
|
||||
);
|
||||
|
||||
assert_eq!(buffer, "");
|
||||
assert_eq!(prev_tools.len(), 0);
|
||||
assert_eq!(current_tool_id, 0);
|
||||
assert!(!current_tool_name_sent);
|
||||
assert_eq!(streamed_args.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ensure_capacity() {
|
||||
let mut prev_tools = vec![];
|
||||
let mut streamed_args = vec![];
|
||||
|
||||
ensure_capacity(2, &mut prev_tools, &mut streamed_args);
|
||||
|
||||
assert_eq!(prev_tools.len(), 3);
|
||||
assert_eq!(streamed_args.len(), 3);
|
||||
assert_eq!(prev_tools[0], Value::Null);
|
||||
assert_eq!(streamed_args[0], "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ensure_capacity_negative_id() {
|
||||
let mut prev_tools = vec![];
|
||||
let mut streamed_args = vec![];
|
||||
|
||||
ensure_capacity(-1, &mut prev_tools, &mut streamed_args);
|
||||
|
||||
// Should not resize for negative ID
|
||||
assert_eq!(prev_tools.len(), 0);
|
||||
assert_eq!(streamed_args.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_complete_json() {
|
||||
assert!(is_complete_json(r#"{"name": "test"}"#));
|
||||
assert!(is_complete_json("[1, 2, 3]"));
|
||||
assert!(is_complete_json("42"));
|
||||
assert!(is_complete_json("true"));
|
||||
assert!(!is_complete_json(r#"{"name": "#));
|
||||
assert!(!is_complete_json("[1, 2,"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_arguments_field() {
|
||||
// Case 1: Has parameters, no arguments
|
||||
let obj = serde_json::json!({
|
||||
"name": "test",
|
||||
"parameters": {"key": "value"}
|
||||
});
|
||||
let normalized = normalize_arguments_field(obj);
|
||||
assert_eq!(
|
||||
normalized.get("arguments").unwrap(),
|
||||
&serde_json::json!({"key": "value"})
|
||||
);
|
||||
|
||||
// Case 2: Already has arguments
|
||||
let obj = serde_json::json!({
|
||||
"name": "test",
|
||||
"arguments": {"key": "value"}
|
||||
});
|
||||
let normalized = normalize_arguments_field(obj.clone());
|
||||
assert_eq!(normalized, obj);
|
||||
|
||||
// Case 3: No parameters or arguments
|
||||
let obj = serde_json::json!({"name": "test"});
|
||||
let normalized = normalize_arguments_field(obj.clone());
|
||||
assert_eq!(normalized, obj);
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// JSON format parser for tool calls
|
||||
@@ -18,6 +20,24 @@ use crate::tool_parser::{
|
||||
pub struct JsonParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Separator between multiple tool calls
|
||||
tool_call_separator: &'static str,
|
||||
}
|
||||
|
||||
impl JsonParser {
|
||||
@@ -25,6 +45,12 @@ impl JsonParser {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
tool_call_separator: ",",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,25 +184,9 @@ impl JsonParser {
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Check if text contains JSON tool call markers (complete markers)
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
(text.contains('{') || text.contains('[')) && text.contains("name")
|
||||
}
|
||||
|
||||
/// Check if buffer could be building toward a tool call pattern
|
||||
fn has_partial_start_token(&self, buffer: &str) -> bool {
|
||||
// Check if buffer ends with a partial match of tool call patterns
|
||||
let patterns = [r#"{"name""#, r#"[{"name""#];
|
||||
|
||||
for pattern in &patterns {
|
||||
// Check if buffer ends with any partial of this pattern
|
||||
for i in 1..=buffer.len().min(pattern.len()) {
|
||||
if pattern.starts_with(&buffer[buffer.len() - i..]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
/// Check if text contains tool calls
|
||||
fn has_tool_call(&self, text: &str) -> bool {
|
||||
text.contains('[') || text.contains('{')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,79 +216,62 @@ impl ToolParser for JsonParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
let trimmed = state.buffer.trim();
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Append new text to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// If no tool markers and not a partial token, return as normal text │ │
|
||||
if !self.has_tool_markers(trimmed) && !self.has_partial_start_token(trimmed) {
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
// Check if current_text has tool_call
|
||||
let has_tool_start = self.has_tool_call(current_text)
|
||||
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|
||||
|
||||
if !has_tool_start {
|
||||
let normal_text = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(trimmed) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == trimmed.len() {
|
||||
// Check if this is truly complete
|
||||
let looks_complete = trimmed.ends_with('}') || trimmed.ends_with(']');
|
||||
// Build tool indices
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
if looks_complete {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool as complete
|
||||
// TODO simplified version, address more complex version
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON, try to extract tool name
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// TODO simplified version, address more complex version
|
||||
// Just return the tool name once we see it
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as a flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for complete arguments
|
||||
if let Some(args) =
|
||||
value.get("arguments").or_else(|| value.get("parameters"))
|
||||
{
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
// Return arguments as a single update
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Determine start index for JSON parsing
|
||||
// JSON can start with [ (array) or { (single object)
|
||||
let start_idx = if let Some(bracket_pos) = current_text.find('[') {
|
||||
let brace_pos = current_text.find('{');
|
||||
match brace_pos {
|
||||
Some(bp) if bp < bracket_pos => bp,
|
||||
_ => bracket_pos,
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Continue waiting for more data
|
||||
}
|
||||
}
|
||||
} else if let Some(brace_pos) = current_text.find('{') {
|
||||
brace_pos
|
||||
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
|
||||
self.tool_call_separator.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
helpers::handle_json_tool_streaming(
|
||||
current_text,
|
||||
start_idx,
|
||||
&mut self.partial_json,
|
||||
&tool_indices,
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_id,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&mut self.prev_tool_call_arr,
|
||||
)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
self.has_tool_markers(text)
|
||||
let trimmed = text.trim();
|
||||
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
/// Kimi K2 format parser for tool calls
|
||||
@@ -19,12 +21,32 @@ use crate::tool_parser::{
|
||||
/// - Function calls with explicit indexing
|
||||
/// - JSON arguments
|
||||
pub struct KimiK2Parser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting complete tool calls
|
||||
tool_call_extractor: Regex,
|
||||
/// Regex for extracting partial tool calls (streaming)
|
||||
stream_tool_call_extractor: Regex,
|
||||
/// Regex pattern for removing completed tool calls from buffer
|
||||
tool_call_end_pattern: Regex,
|
||||
/// Robust parser for ids like "functions.search:0" or fallback "search:0"
|
||||
tool_call_id_regex: Regex,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Tracks the last arguments sent for incremental diffing
|
||||
last_arguments: String,
|
||||
}
|
||||
|
||||
impl KimiK2Parser {
|
||||
@@ -38,10 +60,25 @@ impl KimiK2Parser {
|
||||
let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
|
||||
let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Pattern for removing completed tool calls
|
||||
let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
|
||||
let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
|
||||
|
||||
// Robust parser for ids like "functions.search:0" or fallback "search:0"
|
||||
let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
|
||||
let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
|
||||
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
tool_call_extractor,
|
||||
stream_tool_call_extractor,
|
||||
tool_call_end_pattern,
|
||||
tool_call_id_regex,
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
last_arguments: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,22 +89,13 @@ impl KimiK2Parser {
|
||||
|
||||
/// Parse function ID to extract name and index
|
||||
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
|
||||
// Format: functions.{name}:{index} or namespace.functions.{name}:{index}
|
||||
// Extract everything after the last dot before the colon as the function name
|
||||
if let Some(colon_pos) = id.rfind(':') {
|
||||
let before_colon = &id[..colon_pos];
|
||||
let index_str = &id[colon_pos + 1..];
|
||||
|
||||
// Find the last dot to extract the function name
|
||||
if let Some(dot_pos) = before_colon.rfind('.') {
|
||||
let func_name = &before_colon[dot_pos + 1..];
|
||||
|
||||
if let Ok(index) = index_str.parse::<usize>() {
|
||||
return Some((func_name.to_string(), index));
|
||||
}
|
||||
}
|
||||
if let Some(captures) = self.tool_call_id_regex.captures(id) {
|
||||
let name = captures.name("name")?.as_str().to_string();
|
||||
let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
|
||||
Some((name, index))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,107 +168,172 @@ impl ToolParser for KimiK2Parser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check for tool markers
|
||||
// Check if we have a tool call (either the start token or individual tool call)
|
||||
let has_tool_call =
|
||||
self.has_tool_markers(&state.buffer) || state.buffer.contains("<|tool_call_begin|>");
|
||||
self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
|
||||
|
||||
if !has_tool_call {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
let marker1_pos = state.buffer.find("<|tool_calls_section_begin|>");
|
||||
let marker2_pos = state.buffer.find("<|tool_call_begin|>");
|
||||
let marker_pos = marker1_pos.iter().chain(marker2_pos.iter()).min().copied();
|
||||
|
||||
if let Some(pos) = marker_pos {
|
||||
if pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
let mut normal_text = std::mem::take(&mut self.buffer);
|
||||
// Remove end tokens if present
|
||||
for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
|
||||
normal_text = normal_text.replace(e_token, "");
|
||||
}
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Build tool indices for validation
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
let mut calls: Vec<ToolCallItem> = Vec::new();
|
||||
|
||||
// Try to match streaming pattern
|
||||
if let Some(captures) = self.stream_tool_call_extractor.captures(&state.buffer) {
|
||||
if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
|
||||
if let (Some(id_match), Some(args_match)) = (
|
||||
captures.name("tool_call_id"),
|
||||
captures.name("function_arguments"),
|
||||
) {
|
||||
let function_id = id_match.as_str();
|
||||
let partial_args = args_match.as_str();
|
||||
let function_args = args_match.as_str();
|
||||
|
||||
// Parse function ID
|
||||
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
|
||||
// Send function name if not sent yet
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.clone(),
|
||||
});
|
||||
// Validate tool name
|
||||
if !tool_indices.contains_key(&func_name) {
|
||||
// Invalid tool name - skip this tool, preserve indexing for next tool
|
||||
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
|
||||
helpers::reset_current_tool_state(
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&self.prev_tool_call_arr,
|
||||
);
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
|
||||
// Check if we have a complete tool call
|
||||
if let Some(end_pos) = partial_args.find("<|tool_call_end|>") {
|
||||
// Extract just the JSON part
|
||||
let json_args = &partial_args[..end_pos];
|
||||
// Initialize state if this is the first tool call
|
||||
if self.current_tool_id == -1 {
|
||||
self.current_tool_id = 0;
|
||||
self.prev_tool_call_arr = Vec::new();
|
||||
self.streamed_args_for_tool = vec![String::new()];
|
||||
}
|
||||
|
||||
// Validate and parse JSON
|
||||
if serde_json::from_str::<serde_json::Value>(json_args).is_ok() {
|
||||
// Generate unique ID
|
||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
||||
// Ensure we have enough entries in our tracking arrays
|
||||
helpers::ensure_capacity(
|
||||
self.current_tool_id,
|
||||
&mut self.prev_tool_call_arr,
|
||||
&mut self.streamed_args_for_tool,
|
||||
);
|
||||
|
||||
let tool = ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name,
|
||||
arguments: json_args.to_string(),
|
||||
},
|
||||
// Send tool name if not sent yet
|
||||
if !self.current_tool_name_sent {
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: Some(func_name.clone()),
|
||||
parameters: String::new(),
|
||||
});
|
||||
self.current_tool_name_sent = true;
|
||||
|
||||
// Store the tool call info for serving layer completions endpoint
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if self.prev_tool_call_arr.len() <= tool_id {
|
||||
self.prev_tool_call_arr
|
||||
.resize_with(tool_id + 1, || Value::Null);
|
||||
}
|
||||
self.prev_tool_call_arr[tool_id] = serde_json::json!({
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
});
|
||||
} else {
|
||||
// Compute incremental diff
|
||||
let argument_diff = if function_args.starts_with(&self.last_arguments) {
|
||||
&function_args[self.last_arguments.len()..]
|
||||
} else {
|
||||
function_args
|
||||
};
|
||||
|
||||
// Split by end token before sending (like Python does)
|
||||
let parsed_args_diff =
|
||||
if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
|
||||
&argument_diff[..pos]
|
||||
} else {
|
||||
argument_diff
|
||||
};
|
||||
|
||||
// Find where this tool call ends in the buffer
|
||||
if let Some(tool_end) = state.buffer.find("<|tool_call_end|>") {
|
||||
let end_pos = tool_end + "<|tool_call_end|>".len();
|
||||
state.buffer.drain(..end_pos);
|
||||
if !parsed_args_diff.is_empty() {
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: None,
|
||||
parameters: parsed_args_diff.to_string(),
|
||||
});
|
||||
// Note: Python adds full diff to _last_arguments, not just parsed part
|
||||
self.last_arguments.push_str(argument_diff);
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.streamed_args_for_tool.len() {
|
||||
self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
|
||||
}
|
||||
|
||||
// Reset state for next tool
|
||||
state.in_string = false;
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
} else {
|
||||
// Try to parse partial JSON for streaming arguments
|
||||
match self.partial_json.parse_value(partial_args) {
|
||||
Ok((value, _consumed)) => {
|
||||
let args_str = serde_json::to_string(&value)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
// Check completeness - split by end token first
|
||||
let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
|
||||
{
|
||||
&function_args[..pos]
|
||||
} else {
|
||||
function_args
|
||||
};
|
||||
|
||||
if helpers::is_complete_json(parsed_args) {
|
||||
// Update the stored arguments
|
||||
if let Ok(parsed_args_value) =
|
||||
serde_json::from_str::<Value>(parsed_args)
|
||||
{
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.prev_tool_call_arr.len() {
|
||||
if let Some(obj) =
|
||||
self.prev_tool_call_arr[tool_id].as_object_mut()
|
||||
{
|
||||
obj.insert("arguments".to_string(), parsed_args_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Can't parse yet, keep buffering
|
||||
|
||||
// Find the end of the current tool call and remove only that part from buffer
|
||||
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
|
||||
// Remove the completed tool call from buffer, keep any remaining content
|
||||
self.buffer = current_text[mat.end()..].to_string();
|
||||
} else {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
let result = StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
};
|
||||
|
||||
self.current_tool_id += 1;
|
||||
self.last_arguments.clear();
|
||||
self.current_tool_name_sent = false;
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -2,23 +2,44 @@ use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use uuid;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Llama 3.2 format parser for tool calls
|
||||
///
|
||||
/// Handles the Llama 3.2 specific format:
|
||||
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
|
||||
/// `<|python_tag|>{"name": "func", "parameters": {...}}`
|
||||
///
|
||||
/// Also supports plain JSON without the python_tag prefix
|
||||
pub struct LlamaParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Token configuration
|
||||
bot_token: &'static str,
|
||||
tool_call_separator: &'static str,
|
||||
}
|
||||
|
||||
impl LlamaParser {
|
||||
@@ -26,6 +47,13 @@ impl LlamaParser {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
bot_token: "<|python_tag|>",
|
||||
tool_call_separator: ";",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,39 +104,6 @@ impl LlamaParser {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse JSON value(s) into tool calls
|
||||
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
match value {
|
||||
Value::Array(arr) => {
|
||||
// Parse each element in the array
|
||||
for item in arr {
|
||||
if let Some(tool) = self.parse_single_object(item)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::Object(_) => {
|
||||
// Single tool call
|
||||
if let Some(tool) = self.parse_single_object(value)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Not a valid tool call format
|
||||
return Ok(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Check if text contains potential tool call markers
|
||||
fn has_python_tag(&self, text: &str) -> bool {
|
||||
text.contains("<|python_tag|>")
|
||||
}
|
||||
|
||||
/// Parse semicolon-separated JSON objects
|
||||
fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
let mut all_tools = Vec::new();
|
||||
@@ -136,6 +131,11 @@ impl LlamaParser {
|
||||
|
||||
Ok(all_tools)
|
||||
}
|
||||
|
||||
/// Check if text has tool call
|
||||
fn has_tool_call(&self, text: &str) -> bool {
|
||||
text.contains("<|python_tag|>") || text.contains('{')
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LlamaParser {
|
||||
@@ -185,137 +185,57 @@ impl ToolParser for LlamaParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Append new text to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// In streaming mode, be more lenient - check for potential JSON start
|
||||
let has_potential_json = state.buffer.contains('{');
|
||||
let has_tag = self.has_python_tag(&state.buffer);
|
||||
// Check if current_text has tool_call
|
||||
let has_tool_start = self.has_tool_call(current_text)
|
||||
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|
||||
|
||||
// If we have neither python_tag nor potential JSON structure, return as normal text
|
||||
if !has_tag && !has_potential_json {
|
||||
// No relevant markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
if !has_tool_start {
|
||||
// Only clear buffer if we're sure no tool call is starting
|
||||
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
|
||||
let normal_text = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
// If we only have '{' without more content, wait for more data
|
||||
let trimmed = state.buffer.trim();
|
||||
if (trimmed == "{") && !has_tag {
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
|
||||
// Check for text before python_tag and extract it as normal text
|
||||
if let Some(tag_pos) = state.buffer.find("<|python_tag|>") {
|
||||
if tag_pos > 0 {
|
||||
// We have text before the python_tag - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..tag_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
} else {
|
||||
// For JSON without python_tag, look for the start of JSON structure
|
||||
let brace_pos = state.buffer.find('{');
|
||||
let bracket_pos = state.buffer.find('[');
|
||||
let json_pos = brace_pos.iter().chain(bracket_pos.iter()).min().copied();
|
||||
|
||||
if let Some(pos) = json_pos {
|
||||
if pos > 0 {
|
||||
// We have text before JSON structure - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract JSON content based on whether we have python_tag
|
||||
let (json_content, content_start_pos) = if self.has_python_tag(&state.buffer) {
|
||||
// Extract content after python_tag
|
||||
if let Some(tag_pos) = state.buffer.find("<|python_tag|>") {
|
||||
let start = tag_pos + "<|python_tag|>".len();
|
||||
(&state.buffer[start..], start)
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
} else {
|
||||
(&state.buffer[..], 0)
|
||||
// Might be partial bot_token, keep buffering
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Build tool indices
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
// Determine start index for JSON parsing
|
||||
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
|
||||
pos + self.bot_token.len()
|
||||
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
|
||||
self.tool_call_separator.len()
|
||||
} else {
|
||||
// Find where the actual content starts after trimming
|
||||
let trimmed = state.buffer.trim_start();
|
||||
let trim_offset = state.buffer.len() - trimmed.len();
|
||||
(trimmed.trim_end(), trim_offset)
|
||||
0
|
||||
};
|
||||
|
||||
// Check if we have a semicolon separator (multiple tools)
|
||||
if let Some(semicolon_pos) = json_content.find(';') {
|
||||
// We have multiple tools - try to parse the first one
|
||||
let first_json = &json_content[..semicolon_pos];
|
||||
|
||||
if let Ok(value) = serde_json::from_str::<Value>(first_json.trim()) {
|
||||
if let Some(tool) = self.parse_single_object(&value)? {
|
||||
// Remove the parsed JSON and semicolon from the buffer
|
||||
let end_pos = content_start_pos + semicolon_pos + 1; // +1 to include the semicolon
|
||||
state.buffer.drain(content_start_pos..end_pos);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(json_content) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_content.len() {
|
||||
// Check if this is truly complete
|
||||
let looks_complete = json_content.ends_with('}') || json_content.ends_with(']');
|
||||
|
||||
if looks_complete {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
|
||||
// Return the first tool as complete
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON, try to extract tool name for streaming
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// Return tool name once we see it
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as a flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for complete arguments
|
||||
if let Some(args) =
|
||||
value.get("arguments").or_else(|| value.get("parameters"))
|
||||
{
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Continue waiting for more data
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
helpers::handle_json_tool_streaming(
|
||||
current_text,
|
||||
start_idx,
|
||||
&mut self.partial_json,
|
||||
&tool_indices,
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_id,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&mut self.prev_tool_call_arr,
|
||||
)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Mistral format parser for tool calls
|
||||
@@ -21,6 +23,25 @@ use crate::tool_parser::{
|
||||
pub struct MistralParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Token configuration
|
||||
bot_token: &'static str,
|
||||
tool_call_separator: &'static str,
|
||||
}
|
||||
|
||||
impl MistralParser {
|
||||
@@ -28,19 +49,16 @@ impl MistralParser {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
bot_token: "[TOOL_CALLS] [",
|
||||
tool_call_separator: ", ",
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract JSON array using bracket counting
|
||||
///
|
||||
/// Handles nested brackets in JSON content by tracking:
|
||||
/// - String boundaries (quotes)
|
||||
/// - Escape sequences
|
||||
/// - Bracket depth
|
||||
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
|
||||
self.extract_json_array_with_pos(text).map(|(_, json)| json)
|
||||
}
|
||||
|
||||
fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
|
||||
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
|
||||
|
||||
@@ -100,14 +118,14 @@ impl MistralParser {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
if let Value::Array(arr) = value {
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
for item in arr.iter() {
|
||||
if let Some(tool) = self.parse_single_object(item)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single object case (shouldn't happen with Mistral format, but handle it)
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
if let Some(tool) = self.parse_single_object(&value)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
@@ -116,7 +134,7 @@ impl MistralParser {
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
@@ -128,8 +146,12 @@ impl MistralParser {
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("mistral_call_{}", index);
|
||||
// Generate unique ID
|
||||
let id = obj
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4()));
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
@@ -188,95 +210,57 @@ impl ToolParser for MistralParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Append new text to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
// Check if current_text has tool_call
|
||||
let has_tool_start = self.has_tool_markers(current_text)
|
||||
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|
||||
|
||||
// Check for text before [TOOL_CALLS] and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("[TOOL_CALLS]") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
if !has_tool_start {
|
||||
// Only clear buffer if we're sure no tool call is starting
|
||||
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
|
||||
let normal_text = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
} else {
|
||||
// Might be partial bot_token, keep buffering
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Try to extract complete JSON array
|
||||
if let Some(json_array) = self.extract_json_array(&state.buffer) {
|
||||
// Parse with partial JSON to handle incomplete content
|
||||
match self.partial_json.parse_value(json_array) {
|
||||
Ok((value, consumed)) => {
|
||||
// Check if we have a complete JSON structure
|
||||
if consumed == json_array.len() {
|
||||
// Complete JSON, parse tool calls
|
||||
let tools = if let Value::Array(arr) = value {
|
||||
let mut result = Vec::new();
|
||||
for (index, item) in arr.iter().enumerate() {
|
||||
if let Some(tool) = self.parse_single_object(item, index)? {
|
||||
result.push(tool);
|
||||
}
|
||||
}
|
||||
result
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
// Build tool indices
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
if !tools.is_empty() {
|
||||
// Clear buffer since we consumed everything
|
||||
state.buffer.clear();
|
||||
// Determine start index for JSON parsing
|
||||
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
|
||||
pos + self.bot_token.len()
|
||||
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
|
||||
self.tool_call_separator.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Return the first tool (simplified for Phase 3)
|
||||
// Full multi-tool streaming will be implemented later
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Partial JSON - try to extract tool name for streaming
|
||||
if let Value::Array(arr) = value {
|
||||
if let Some(first_tool) = arr.first() {
|
||||
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
|
||||
{
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = first_tool.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
helpers::handle_json_tool_streaming(
|
||||
current_text,
|
||||
start_idx,
|
||||
&mut self.partial_json,
|
||||
&tool_indices,
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_id,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&mut self.prev_tool_call_arr,
|
||||
)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -15,6 +15,9 @@ pub mod pythonic_parser;
|
||||
pub mod qwen_parser;
|
||||
pub mod step3_parser;
|
||||
|
||||
// Shared helpers and utilities
|
||||
pub mod helpers;
|
||||
|
||||
// Re-export parser types for convenience
|
||||
pub use deepseek_parser::DeepSeekParser;
|
||||
pub use glm4_moe_parser::Glm4MoeParser;
|
||||
|
||||
@@ -15,11 +15,13 @@ use rustpython_parser::{parse, Mode};
|
||||
use serde_json::{Map, Number, Value};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
@@ -37,13 +39,23 @@ fn pythonic_block_regex() -> &'static Regex {
|
||||
}
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
#[derive(Default)]
|
||||
pub struct PythonicParser;
|
||||
pub struct PythonicParser {
|
||||
/// Buffer for accumulating chunks
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl Default for PythonicParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PythonicParser {
|
||||
/// Create a new Pythonic parser
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the first pythonic tool call block and return it along with the
|
||||
@@ -105,23 +117,90 @@ impl ToolParser for PythonicParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
|
||||
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
let cleaned = Self::strip_special_tokens(&self.buffer);
|
||||
|
||||
// Look for opening bracket
|
||||
if let Some(start) = cleaned.find('[') {
|
||||
let normal_text = if start > 0 {
|
||||
cleaned[..start].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Look for matching closing bracket
|
||||
if let Some(end) = find_matching_bracket(&cleaned, start) {
|
||||
// Found complete tool call - extract it and parse using parse_complete
|
||||
let call_text = &cleaned[start..=end];
|
||||
|
||||
match self.parse_complete(call_text).await {
|
||||
Ok((_, calls)) => {
|
||||
// Update buffer with remaining text after tool call
|
||||
let remaining_text = &cleaned[end + 1..];
|
||||
self.buffer = remaining_text.to_string();
|
||||
|
||||
// Validate tool names and convert ToolCall to ToolCallItem
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
let items: Vec<ToolCallItem> = calls
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, tool)| {
|
||||
if !tool_indices.contains_key(&tool.function.name) {
|
||||
tracing::warn!(
|
||||
"Invalid tool name '{}' - skipping",
|
||||
tool.function.name
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(ToolCallItem {
|
||||
tool_index: idx,
|
||||
name: Some(tool.function.name),
|
||||
parameters: tool.function.arguments,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: items,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse pythonic tool call: {}", e);
|
||||
// Clear buffer on error
|
||||
self.buffer.clear();
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We have an opening bracket but no closing bracket yet
|
||||
// Put back everything from the bracket onwards
|
||||
self.buffer = cleaned[start..].to_string();
|
||||
|
||||
if !normal_text.is_empty() {
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Still accumulating a potential tool call
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
// No tool call bracket found
|
||||
self.buffer.clear();
|
||||
Ok(StreamingParseResult {
|
||||
normal_text: cleaned,
|
||||
calls: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
@@ -134,6 +213,25 @@ impl ToolParser for PythonicParser {
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the matching closing bracket for the opening bracket at start position.
|
||||
/// Properly handles nested brackets.
|
||||
fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
|
||||
let mut bracket_count = 0;
|
||||
let chars: Vec<char> = buffer.chars().collect();
|
||||
|
||||
for (i, &ch) in chars.iter().enumerate().skip(start) {
|
||||
if ch == '[' {
|
||||
bracket_count += 1;
|
||||
} else if ch == ']' {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
None // No matching bracket found
|
||||
}
|
||||
|
||||
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
|
||||
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
|
||||
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
|
||||
|
||||
@@ -2,12 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Qwen format parser for tool calls
|
||||
@@ -19,11 +21,36 @@ use crate::tool_parser::{
|
||||
/// - XML-style tags with JSON content
|
||||
/// - Support for multiple sequential tool calls
|
||||
/// - Newline-aware parsing
|
||||
/// - Buffering for partial end tokens
|
||||
pub struct QwenParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting tool calls
|
||||
|
||||
/// Regex for extracting tool calls in parse_complete
|
||||
extractor: Regex,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Buffer for normal text that might precede partial end tokens
|
||||
normal_text_buffer: String,
|
||||
|
||||
/// Token configuration
|
||||
bot_token: &'static str,
|
||||
eot_token: &'static str,
|
||||
tool_call_separator: &'static str,
|
||||
}
|
||||
|
||||
impl QwenParser {
|
||||
@@ -36,11 +63,20 @@ impl QwenParser {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
extractor,
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
normal_text_buffer: String::new(),
|
||||
bot_token: "<tool_call>\n",
|
||||
eot_token: "\n</tool_call>",
|
||||
tool_call_separator: "\n",
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
@@ -52,8 +88,12 @@ impl QwenParser {
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("qwen_call_{}", index);
|
||||
// Generate unique ID
|
||||
let id = obj
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
@@ -73,42 +113,9 @@ impl QwenParser {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
|
||||
/// Find the start position of a tool call
|
||||
fn find_tool_start(&self, text: &str) -> Option<usize> {
|
||||
text.find("<tool_call>\n")
|
||||
}
|
||||
|
||||
/// Find the end position of a tool call
|
||||
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
|
||||
let search_from = start_pos + "<tool_call>\n".len();
|
||||
text[search_from..]
|
||||
.find("\n</tool_call>")
|
||||
.map(|pos| search_from + pos + "\n</tool_call>".len())
|
||||
}
|
||||
|
||||
/// Check if buffer ends with a partial token
|
||||
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
|
||||
// Check for partial start token
|
||||
let start_token = "<tool_call>\n";
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
for i in 1..=start_token.len().min(buffer.len()) {
|
||||
if start_token.starts_with(&buffer[buffer.len() - i..]) {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial end token
|
||||
let end_token = "\n</tool_call>";
|
||||
// Only check if buffer ends with a partial match (not the complete token without newline)
|
||||
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
|
||||
if buffer.ends_with("</tool_call>") {
|
||||
// This is a complete end tag, just missing the leading newline
|
||||
// Not a partial token situation
|
||||
return None;
|
||||
}
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
(1..=end_token.len().min(buffer.len()))
|
||||
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
|
||||
/// Check if text has tool call
|
||||
fn has_tool_call(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,17 +139,17 @@ impl ToolParser for QwenParser {
|
||||
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for (index, captures) in self.extractor.captures_iter(text).enumerate() {
|
||||
for captures in self.extractor.captures_iter(text) {
|
||||
if let Some(json_str) = captures.get(1) {
|
||||
let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
|
||||
.and_then(|v| self.parse_single_object(&v, index));
|
||||
.and_then(|v| self.parse_single_object(&v));
|
||||
|
||||
match parsed {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call {}: {:?}", index, e);
|
||||
tracing::warn!("Failed to parse tool call: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -158,103 +165,91 @@ impl ToolParser for QwenParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Append new text to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check for partial token at end of buffer
|
||||
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
|
||||
// Hold back the partial token
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
// Check if current_text has tool_call
|
||||
let has_tool_start = self.has_tool_call(current_text)
|
||||
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
if !has_tool_start {
|
||||
// Only clear buffer if we're sure no tool call is starting
|
||||
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
|
||||
let normal_text = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("<tool_call>") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
}
|
||||
|
||||
// Find start and end positions
|
||||
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
|
||||
// Check if we have the complete tool call
|
||||
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
|
||||
// Extract the JSON content
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let json_end = end_pos - "\n</tool_call>".len();
|
||||
let json_str = &state.buffer[json_start..json_end];
|
||||
|
||||
// Parse the complete JSON
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
// Clear the consumed part from buffer using drain for efficiency
|
||||
state.buffer.drain(..end_pos);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete or malformed
|
||||
// If we have what looks like a complete tool call block, treat as normal text
|
||||
if state.buffer[start_pos..end_pos].contains("\n</tool_call>") {
|
||||
let malformed_text: String = state.buffer.drain(..end_pos).collect();
|
||||
return Ok(StreamResult::NormalText(malformed_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
} else {
|
||||
// We have start but no end yet - try partial parsing
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let partial_json = &state.buffer[json_start..];
|
||||
// Might be partial bot_token, keep buffering
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing newline if present (might be start of end token)
|
||||
let partial_json = partial_json.trim_end();
|
||||
// Build tool indices
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(partial_json) {
|
||||
Ok((value, _consumed)) => {
|
||||
// Extract tool name if available
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
// Determine start index for JSON parsing
|
||||
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
|
||||
pos + self.bot_token.len()
|
||||
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
|
||||
self.tool_call_separator.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = value.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
let mut result = helpers::handle_json_tool_streaming(
|
||||
current_text,
|
||||
start_idx,
|
||||
&mut self.partial_json,
|
||||
&tool_indices,
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_id,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&mut self.prev_tool_call_arr,
|
||||
)?;
|
||||
|
||||
// Qwen-specific: Handle partial end tokens in normal text
|
||||
// After tool calls complete, normal text might contain partial "</tool_call>" tags
|
||||
if !result.normal_text.is_empty() {
|
||||
self.normal_text_buffer.push_str(&result.normal_text);
|
||||
|
||||
// Check if buffer contains complete end token (without leading newline)
|
||||
let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
|
||||
if self.normal_text_buffer.contains(end_token_without_newline) {
|
||||
// Complete end token found - clean it and return
|
||||
let cleaned_text = self
|
||||
.normal_text_buffer
|
||||
.replace(end_token_without_newline, "");
|
||||
self.normal_text_buffer.clear();
|
||||
result.normal_text = cleaned_text;
|
||||
} else {
|
||||
// Check if buffer might contain partial end token at the end
|
||||
if let Some(partial_match_len) = helpers::ends_with_partial_token(
|
||||
&self.normal_text_buffer,
|
||||
end_token_without_newline,
|
||||
) {
|
||||
// Keep potential partial match in buffer, return the rest
|
||||
let split_point = self.normal_text_buffer.len() - partial_match_len;
|
||||
result.normal_text = self.normal_text_buffer[..split_point].to_string();
|
||||
self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
|
||||
} else {
|
||||
// No partial match, return all buffered text
|
||||
result.normal_text = self.normal_text_buffer.clone();
|
||||
self.normal_text_buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
/// Step3 format parser for tool calls
|
||||
@@ -25,6 +28,29 @@ pub struct Step3Parser {
|
||||
invoke_extractor: Regex,
|
||||
/// Regex for extracting parameters
|
||||
param_extractor: Regex,
|
||||
|
||||
/// Buffer for accumulating chunks
|
||||
buffer: String,
|
||||
|
||||
/// Token configuration
|
||||
bot_token: &'static str,
|
||||
eot_token: &'static str,
|
||||
tool_call_begin: &'static str,
|
||||
tool_call_end: &'static str,
|
||||
tool_sep: &'static str,
|
||||
|
||||
/// Streaming state variables (mirrors Python's Step3Detector)
|
||||
in_tool_block: bool,
|
||||
tool_block_finished: bool,
|
||||
current_function_name: String,
|
||||
current_parameters: serde_json::Map<String, Value>,
|
||||
in_tool_call: bool,
|
||||
function_name_sent: bool,
|
||||
|
||||
/// Standard state machine fields
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
current_tool_id: i32,
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
}
|
||||
|
||||
impl Step3Parser {
|
||||
@@ -46,12 +72,254 @@ impl Step3Parser {
|
||||
tool_call_extractor,
|
||||
invoke_extractor,
|
||||
param_extractor,
|
||||
|
||||
buffer: String::new(),
|
||||
|
||||
bot_token: "<|tool_calls_begin|>",
|
||||
eot_token: "<|tool_calls_end|>",
|
||||
tool_call_begin: "<|tool_call_begin|>",
|
||||
tool_call_end: "<|tool_call_end|>",
|
||||
tool_sep: "<|tool_sep|>",
|
||||
|
||||
// Streaming state variables
|
||||
in_tool_block: false,
|
||||
tool_block_finished: false,
|
||||
current_function_name: String::new(),
|
||||
current_parameters: serde_json::Map::new(),
|
||||
in_tool_call: false,
|
||||
function_name_sent: false,
|
||||
|
||||
// Standard state machine fields
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text contains Step3 tool markers
|
||||
fn has_tool_markers(&self, text: &str) -> bool {
|
||||
text.contains("<|tool_calls_begin|>")
|
||||
text.contains(self.bot_token)
|
||||
}
|
||||
|
||||
/// Reset streaming state for the next tool call
|
||||
fn reset_streaming_state(&mut self) {
|
||||
self.in_tool_call = false;
|
||||
self.function_name_sent = false;
|
||||
self.current_function_name.clear();
|
||||
self.current_parameters.clear();
|
||||
}
|
||||
|
||||
/// Parse partial tool call for streaming scenarios (mirrors Python's _parse_partial_tool_call)
|
||||
fn parse_partial_tool_call(
|
||||
&mut self,
|
||||
tool_indices: &HashMap<String, usize>,
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
let mut calls = Vec::new();
|
||||
|
||||
// Check if we have tool_sep (means we're past the type declaration)
|
||||
if !self.buffer.contains(self.tool_sep) {
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
});
|
||||
}
|
||||
|
||||
// Clone the buffer to avoid borrow conflicts
|
||||
let buffer_clone = self.buffer.clone();
|
||||
let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
|
||||
if parts.len() != 2 {
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
});
|
||||
}
|
||||
|
||||
let type_part = parts[0].trim();
|
||||
let invoke_part = parts[1];
|
||||
|
||||
// Check if it's a function type
|
||||
if type_part != "function" {
|
||||
// Invalid tool type, skip this tool call
|
||||
self.reset_streaming_state();
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract function name if not sent yet
|
||||
if !self.function_name_sent {
|
||||
if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
|
||||
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Validate function name
|
||||
if tool_indices.contains_key(func_name) {
|
||||
self.current_function_name = func_name.to_string();
|
||||
self.function_name_sent = true;
|
||||
|
||||
// Initialize tool tracking
|
||||
if self.current_tool_id == -1 {
|
||||
self.current_tool_id = 0;
|
||||
}
|
||||
|
||||
// Ensure tracking arrays are large enough
|
||||
helpers::ensure_capacity(
|
||||
self.current_tool_id,
|
||||
&mut self.prev_tool_call_arr,
|
||||
&mut self.streamed_args_for_tool,
|
||||
);
|
||||
|
||||
// Store tool call info
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
self.prev_tool_call_arr[tool_id] = serde_json::json!({
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
});
|
||||
|
||||
// Send tool name with empty parameters
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: Some(func_name.to_string()),
|
||||
parameters: String::new(),
|
||||
});
|
||||
} else {
|
||||
// Invalid function name
|
||||
tracing::warn!("Invalid function name: {}", func_name);
|
||||
self.reset_streaming_state();
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Function name not complete yet
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parameters incrementally
|
||||
if self.function_name_sent {
|
||||
// Extract all complete parameters
|
||||
let mut new_params = serde_json::Map::new();
|
||||
for capture in self.param_extractor.captures_iter(invoke_part) {
|
||||
let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
|
||||
let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Try to parse the value as JSON first, fallback to string
|
||||
let param_value =
|
||||
if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
|
||||
json_val
|
||||
} else {
|
||||
// Try parsing as Python literal
|
||||
if param_value_str == "true" || param_value_str == "True" {
|
||||
Value::Bool(true)
|
||||
} else if param_value_str == "false" || param_value_str == "False" {
|
||||
Value::Bool(false)
|
||||
} else if param_value_str == "null" || param_value_str == "None" {
|
||||
Value::Null
|
||||
} else if let Ok(num) = param_value_str.parse::<i64>() {
|
||||
Value::Number(num.into())
|
||||
} else if let Ok(num) = param_value_str.parse::<f64>() {
|
||||
if let Some(n) = serde_json::Number::from_f64(num) {
|
||||
Value::Number(n)
|
||||
} else {
|
||||
Value::String(param_value_str.to_string())
|
||||
}
|
||||
} else {
|
||||
Value::String(param_value_str.to_string())
|
||||
}
|
||||
};
|
||||
|
||||
new_params.insert(param_name.to_string(), param_value);
|
||||
}
|
||||
|
||||
// Check if we have new parameters to stream
|
||||
if new_params != self.current_parameters {
|
||||
// Build the JSON content without the closing brace for streaming
|
||||
let diff = if self.current_parameters.is_empty() {
|
||||
// First parameters - send opening brace and content
|
||||
let params_content =
|
||||
serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
|
||||
if params_content.len() > 2 {
|
||||
// Send everything except the closing brace
|
||||
params_content[..params_content.len() - 1].to_string()
|
||||
} else {
|
||||
"{".to_string()
|
||||
}
|
||||
} else {
|
||||
// Subsequent parameters - calculate the incremental diff
|
||||
let old_json = serde_json::to_string(&self.current_parameters)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
let new_json =
|
||||
serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
// Remove closing braces for comparison
|
||||
let old_without_brace = &old_json[..old_json.len() - 1];
|
||||
let new_without_brace = &new_json[..new_json.len() - 1];
|
||||
|
||||
// The new content should extend the old content
|
||||
new_without_brace
|
||||
.strip_prefix(old_without_brace)
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
if !diff.is_empty() {
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: None,
|
||||
parameters: diff.clone(),
|
||||
});
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.streamed_args_for_tool.len() {
|
||||
self.streamed_args_for_tool[tool_id].push_str(&diff);
|
||||
}
|
||||
}
|
||||
|
||||
// Update current state
|
||||
self.current_parameters = new_params.clone();
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.prev_tool_call_arr.len() {
|
||||
if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
|
||||
obj.insert("arguments".to_string(), Value::Object(new_params));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool call is complete
|
||||
if self.buffer.contains(self.tool_call_end) {
|
||||
// Send closing brace if we've sent any parameters
|
||||
let tool_id = self.current_tool_id as usize;
|
||||
if tool_id < self.streamed_args_for_tool.len()
|
||||
&& !self.streamed_args_for_tool[tool_id].is_empty()
|
||||
{
|
||||
calls.push(ToolCallItem {
|
||||
tool_index: self.current_tool_id as usize,
|
||||
name: None,
|
||||
parameters: "}".to_string(),
|
||||
});
|
||||
self.streamed_args_for_tool[tool_id].push('}');
|
||||
}
|
||||
|
||||
// Find the end position
|
||||
if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
|
||||
// Remove the processed tool call from buffer
|
||||
self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
|
||||
}
|
||||
|
||||
// Reset state for next tool call
|
||||
self.reset_streaming_state();
|
||||
self.current_tool_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamingParseResult {
|
||||
normal_text: String::new(),
|
||||
calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse parameters from steptml format
|
||||
@@ -188,96 +456,106 @@ impl ToolParser for Step3Parser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Check for tool markers
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
// Build tool indices for validation
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
// Stage 1: If we've finished the tool block, everything is normal text
|
||||
if self.tool_block_finished {
|
||||
let normal_text = std::mem::take(&mut self.buffer);
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("<|tool_calls_begin|>") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
}
|
||||
|
||||
// Look for start of tool calls
|
||||
if let Some(start_pos) = state.buffer.find("<|tool_calls_begin|>") {
|
||||
let search_from = start_pos + "<|tool_calls_begin|>".len();
|
||||
|
||||
// Look for individual tool call start
|
||||
if let Some(call_start) = state.buffer[search_from..].find("<|tool_call_begin|>") {
|
||||
let call_start_abs = search_from + call_start;
|
||||
|
||||
// Look for the end of this tool call
|
||||
let search_end_from = call_start_abs + "<|tool_call_begin|>".len();
|
||||
if let Some(call_end) = state.buffer[search_end_from..].find("<|tool_call_end|>")
|
||||
{
|
||||
let call_end_abs = search_end_from + call_end + "<|tool_call_end|>".len();
|
||||
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
// Stage 2: Check if tool block hasn't started yet
|
||||
if !self.in_tool_block {
|
||||
if self.buffer.contains(self.bot_token) {
|
||||
let idx = self.buffer.find(self.bot_token).unwrap();
|
||||
let normal_text = self.buffer[..idx].to_string();
|
||||
self.buffer = self.buffer[idx + self.bot_token.len()..].to_string();
|
||||
self.in_tool_block = true;
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
} else {
|
||||
// Check if we might have a partial bot_token
|
||||
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_some() {
|
||||
return Ok(StreamingParseResult::default()); // Wait for more text
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
let partial = &state.buffer[search_end_from..];
|
||||
|
||||
// Check for tool separator
|
||||
if let Some(sep_pos) = partial.find("<|tool_sep|>") {
|
||||
// Check if it's a function
|
||||
if partial[..sep_pos].contains("function") {
|
||||
let after_sep = &partial[sep_pos + "<|tool_sep|>".len()..];
|
||||
|
||||
// Try to extract function name from steptml:invoke
|
||||
if let Some(name_match) = self.invoke_extractor.captures(after_sep) {
|
||||
let func_name = name_match.get(1).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
if !state.in_string && !func_name.is_empty() {
|
||||
state.in_string = true; // Mark name as sent
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: func_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Try to extract partial parameters
|
||||
if let Some(params_text) = name_match.get(2) {
|
||||
let parameters =
|
||||
self.parse_steptml_parameters(params_text.as_str())?;
|
||||
|
||||
if !parameters.is_empty() {
|
||||
let args_str = serde_json::to_string(¶meters)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let normal_text = std::mem::take(&mut self.buffer);
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
// We're inside the tool block
|
||||
let mut calls = Vec::new();
|
||||
|
||||
// Stage 3: Check if tool block is ending
|
||||
if self.buffer.contains(self.eot_token) {
|
||||
let idx = self.buffer.find(self.eot_token).unwrap();
|
||||
|
||||
// If we're in the middle of a tool call, we need to handle it
|
||||
if self.in_tool_call {
|
||||
// The buffer before eot_token might contain the end of the current tool call
|
||||
let before_eot = &self.buffer[..idx];
|
||||
if before_eot.contains(self.tool_call_end) {
|
||||
// Parse this final tool call
|
||||
let result = self.parse_partial_tool_call(&tool_indices)?;
|
||||
calls.extend(result.calls);
|
||||
} else {
|
||||
// Incomplete tool call - log warning
|
||||
tracing::warn!("Tool block ended with incomplete tool call");
|
||||
}
|
||||
}
|
||||
|
||||
let remaining = self.buffer[idx + self.eot_token.len()..].to_string();
|
||||
self.buffer.clear();
|
||||
self.tool_block_finished = true;
|
||||
|
||||
// Reset any partial tool call state
|
||||
self.reset_streaming_state();
|
||||
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text: remaining,
|
||||
calls,
|
||||
});
|
||||
}
|
||||
|
||||
// Stage 4: Check if we're in a tool call or need to start one
|
||||
if !self.in_tool_call {
|
||||
if self.buffer.contains(self.tool_call_begin) {
|
||||
let idx = self.buffer.find(self.tool_call_begin).unwrap();
|
||||
// Remove any content before tool call begin (shouldn't happen but be safe)
|
||||
self.buffer = self.buffer[idx + self.tool_call_begin.len()..].to_string();
|
||||
self.in_tool_call = true;
|
||||
self.function_name_sent = false;
|
||||
self.current_function_name.clear();
|
||||
self.current_parameters.clear();
|
||||
// Fall through to parse the partial tool call
|
||||
} else {
|
||||
// Wait for tool call to begin
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 5: Parse partial tool call
|
||||
if self.in_tool_call {
|
||||
return self.parse_partial_tool_call(&tool_indices);
|
||||
}
|
||||
|
||||
Ok(StreamingParseResult::default())
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
use crate::tool_parser::parsers::{
|
||||
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
|
||||
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
|
||||
/// Global singleton registry instance - created once and reused
|
||||
pub static GLOBAL_REGISTRY: Lazy<ParserRegistry> = Lazy::new(ParserRegistry::new_internal);
|
||||
|
||||
/// Registry for tool parsers and model mappings
|
||||
pub struct ParserRegistry {
|
||||
/// Map of parser name to parser instance
|
||||
parsers: HashMap<String, Arc<dyn ToolParser>>,
|
||||
/// Map of model name/pattern to parser name
|
||||
model_mapping: HashMap<String, String>,
|
||||
/// Default parser to use when no match found
|
||||
default_parser: String,
|
||||
}
|
||||
|
||||
impl ParserRegistry {
|
||||
/// Get the global singleton instance
|
||||
pub fn new() -> &'static Self {
|
||||
&GLOBAL_REGISTRY
|
||||
}
|
||||
|
||||
/// Create a new instance for testing (not the singleton)
|
||||
#[cfg(test)]
|
||||
pub fn new_for_testing() -> Self {
|
||||
Self::new_internal()
|
||||
}
|
||||
|
||||
/// Internal constructor for creating the singleton instance
|
||||
fn new_internal() -> Self {
|
||||
let mut registry = Self {
|
||||
parsers: HashMap::new(),
|
||||
model_mapping: HashMap::new(),
|
||||
default_parser: "json".to_string(),
|
||||
};
|
||||
|
||||
// Register default parsers
|
||||
registry.register_default_parsers();
|
||||
|
||||
// Register default model mappings
|
||||
registry.register_default_mappings();
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
/// Register a parser
|
||||
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
|
||||
self.parsers.insert(name.into(), parser);
|
||||
}
|
||||
|
||||
/// Map a model name/pattern to a parser
|
||||
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
|
||||
self.model_mapping.insert(model.into(), parser.into());
|
||||
}
|
||||
|
||||
/// Get parser for a specific model
|
||||
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
|
||||
// Try exact match first
|
||||
if let Some(parser_name) = self.model_mapping.get(model) {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching with more specific patterns first
|
||||
// Collect all matching patterns and sort by specificity (longer = more specific)
|
||||
let mut matches: Vec<(&String, &String)> = self
|
||||
.model_mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
model.starts_with(prefix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by pattern length in descending order (longer patterns are more specific)
|
||||
matches.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));
|
||||
|
||||
// Return the first matching parser
|
||||
for (_, parser_name) in matches {
|
||||
if let Some(parser) = self.parsers.get(parser_name) {
|
||||
return Some(parser.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default parser if it exists
|
||||
self.parsers.get(&self.default_parser).cloned()
|
||||
}
|
||||
|
||||
/// List all registered parsers
|
||||
pub fn list_parsers(&self) -> Vec<&str> {
|
||||
self.parsers.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// List all model mappings
|
||||
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
|
||||
self.model_mapping
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Register default parsers
|
||||
fn register_default_parsers(&mut self) {
|
||||
// JSON parser - most common format
|
||||
self.register_parser("json", Arc::new(JsonParser::new()));
|
||||
|
||||
// Mistral parser - [TOOL_CALLS] [...] format
|
||||
self.register_parser("mistral", Arc::new(MistralParser::new()));
|
||||
|
||||
// Qwen parser - <tool_call>...</tool_call> format
|
||||
self.register_parser("qwen", Arc::new(QwenParser::new()));
|
||||
|
||||
// Pythonic parser - [func(arg=val)] format
|
||||
self.register_parser("pythonic", Arc::new(PythonicParser::new()));
|
||||
|
||||
// Llama parser - <|python_tag|>{...} or plain JSON format
|
||||
self.register_parser("llama", Arc::new(LlamaParser::new()));
|
||||
|
||||
// DeepSeek V3 parser - Unicode tokens with JSON blocks
|
||||
self.register_parser("deepseek", Arc::new(DeepSeekParser::new()));
|
||||
|
||||
// GLM-4 MoE parser - XML-style key-value format
|
||||
self.register_parser("glm4_moe", Arc::new(Glm4MoeParser::new()));
|
||||
|
||||
// Step3 parser - StepTML XML format
|
||||
self.register_parser("step3", Arc::new(Step3Parser::new()));
|
||||
|
||||
// Kimi K2 parser - Token-based with indexed functions
|
||||
self.register_parser("kimik2", Arc::new(KimiK2Parser::new()));
|
||||
|
||||
// GPT-OSS parsers - register legacy and Harmony variants
|
||||
let gpt_oss_legacy = Arc::new(GptOssParser::new());
|
||||
let gpt_oss_harmony = Arc::new(GptOssHarmonyParser::new());
|
||||
|
||||
self.register_parser("gpt_oss_legacy", gpt_oss_legacy.clone());
|
||||
self.register_parser("gpt_oss_harmony", gpt_oss_harmony.clone());
|
||||
|
||||
if use_harmony_gpt_oss() {
|
||||
self.register_parser("gpt_oss", gpt_oss_harmony);
|
||||
} else {
|
||||
self.register_parser("gpt_oss", gpt_oss_legacy);
|
||||
}
|
||||
}
|
||||
|
||||
/// Register default model mappings
|
||||
fn register_default_mappings(&mut self) {
|
||||
// OpenAI models
|
||||
self.map_model("gpt-4*", "json");
|
||||
self.map_model("gpt-3.5*", "json");
|
||||
self.map_model("gpt-4o*", "json");
|
||||
|
||||
// Anthropic models
|
||||
self.map_model("claude-*", "json");
|
||||
|
||||
// Mistral models - use Mistral parser
|
||||
self.map_model("mistral-*", "mistral");
|
||||
self.map_model("mixtral-*", "mistral");
|
||||
|
||||
// Qwen models - use Qwen parser
|
||||
self.map_model("qwen*", "qwen");
|
||||
self.map_model("Qwen*", "qwen");
|
||||
|
||||
// Llama models
|
||||
// Llama 4 uses pythonic format
|
||||
self.map_model("llama-4*", "pythonic");
|
||||
self.map_model("meta-llama-4*", "pythonic");
|
||||
// Llama 3.2 uses python_tag format
|
||||
self.map_model("llama-3.2*", "llama");
|
||||
self.map_model("meta-llama-3.2*", "llama");
|
||||
// Other Llama models use JSON
|
||||
self.map_model("llama-*", "json");
|
||||
self.map_model("meta-llama-*", "json");
|
||||
|
||||
// DeepSeek models
|
||||
// DeepSeek V3 uses custom Unicode token format
|
||||
self.map_model("deepseek-v3*", "deepseek");
|
||||
self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
|
||||
// DeepSeek V2 uses pythonic format
|
||||
self.map_model("deepseek-*", "pythonic");
|
||||
|
||||
// GLM models
|
||||
// GLM-4.5 and GLM-4.6 uses XML-style format
|
||||
self.map_model("glm-4.5*", "glm4_moe");
|
||||
self.map_model("glm-4.6*", "glm4_moe");
|
||||
// Other GLM models may use JSON
|
||||
self.map_model("glm-*", "json");
|
||||
|
||||
// Step3 models
|
||||
self.map_model("step3*", "step3");
|
||||
self.map_model("Step-3*", "step3");
|
||||
|
||||
// Kimi models
|
||||
self.map_model("kimi-k2*", "kimik2");
|
||||
self.map_model("Kimi-K2*", "kimik2");
|
||||
self.map_model("moonshot*/Kimi-K2*", "kimik2");
|
||||
|
||||
// GPT-OSS models (T4-style)
|
||||
self.map_model("gpt-oss*", "gpt_oss");
|
||||
self.map_model("t4-*", "gpt_oss");
|
||||
|
||||
// Other models default to JSON
|
||||
self.map_model("gemini-*", "json");
|
||||
self.map_model("palm-*", "json");
|
||||
self.map_model("gemma-*", "json");
|
||||
}
|
||||
|
||||
/// Set the default parser
|
||||
pub fn set_default_parser(&mut self, name: impl Into<String>) {
|
||||
self.default_parser = name.into();
|
||||
}
|
||||
|
||||
/// Check if a parser is registered
|
||||
pub fn has_parser(&self, name: &str) -> bool {
|
||||
self.parsers.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
fn use_harmony_gpt_oss() -> bool {
|
||||
env::var("ROUTER_USE_HARMONY_GPT_OSS")
|
||||
.ok()
|
||||
.map(|value| {
|
||||
let normalized = value.trim();
|
||||
matches!(
|
||||
normalized,
|
||||
"1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes" | "on" | "ON" | "On"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
impl Default for &'static ParserRegistry {
|
||||
fn default() -> Self {
|
||||
ParserRegistry::new()
|
||||
}
|
||||
}
|
||||
@@ -1,189 +1,3 @@
|
||||
use crate::tool_parser::types::{PartialToolCall, ToolCall};
|
||||
|
||||
/// Current phase of parsing
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ParsePhase {
|
||||
/// Looking for start of tool call
|
||||
Searching,
|
||||
/// Parsing function name
|
||||
InName,
|
||||
/// Parsing function arguments
|
||||
InArguments,
|
||||
/// Tool call complete
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// State for streaming parser
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParseState {
|
||||
/// Buffer for accumulating input
|
||||
pub buffer: String,
|
||||
/// Position of last consumed character
|
||||
pub consumed: usize,
|
||||
/// Current partial tool being parsed
|
||||
pub partial_tool: Option<PartialToolCall>,
|
||||
/// Completed tool calls
|
||||
pub completed_tools: Vec<ToolCall>,
|
||||
/// Current parsing phase
|
||||
pub phase: ParsePhase,
|
||||
/// Bracket/brace depth for JSON parsing
|
||||
pub bracket_depth: i32,
|
||||
/// Whether currently inside a string literal
|
||||
pub in_string: bool,
|
||||
/// Whether next character should be escaped
|
||||
pub escape_next: bool,
|
||||
/// Current tool index (for streaming)
|
||||
pub tool_index: usize,
|
||||
/// Optional Harmony-specific streaming state (populated by token-aware parsers)
|
||||
pub harmony_stream: Option<HarmonyStreamState>,
|
||||
}
|
||||
|
||||
impl ParseState {
|
||||
/// Create a new parse state
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
consumed: 0,
|
||||
partial_tool: None,
|
||||
completed_tools: Vec::new(),
|
||||
phase: ParsePhase::Searching,
|
||||
bracket_depth: 0,
|
||||
in_string: false,
|
||||
escape_next: false,
|
||||
tool_index: 0,
|
||||
harmony_stream: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset state for parsing next tool
|
||||
pub fn reset(&mut self) {
|
||||
self.partial_tool = None;
|
||||
self.phase = ParsePhase::Searching;
|
||||
self.bracket_depth = 0;
|
||||
self.in_string = false;
|
||||
self.escape_next = false;
|
||||
self.harmony_stream = None;
|
||||
}
|
||||
|
||||
/// Process a single character for JSON parsing
|
||||
pub fn process_char(&mut self, ch: char) {
|
||||
// Handle escape sequences
|
||||
if self.escape_next {
|
||||
self.escape_next = false;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
if ch == '\\' && self.in_string {
|
||||
self.escape_next = true;
|
||||
self.buffer.push(ch);
|
||||
return;
|
||||
}
|
||||
|
||||
// Track string boundaries
|
||||
if ch == '"' && !self.escape_next {
|
||||
self.in_string = !self.in_string;
|
||||
}
|
||||
|
||||
// Track bracket depth for JSON
|
||||
if !self.in_string {
|
||||
match ch {
|
||||
'{' | '[' => {
|
||||
self.bracket_depth += 1;
|
||||
}
|
||||
'}' | ']' => {
|
||||
self.bracket_depth -= 1;
|
||||
if self.bracket_depth == 0 && self.partial_tool.is_some() {
|
||||
// Complete tool call found
|
||||
self.phase = ParsePhase::Complete;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.push(ch);
|
||||
}
|
||||
|
||||
/// Check if we have a complete JSON object/array
|
||||
pub fn has_complete_json(&self) -> bool {
|
||||
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Extract content from buffer starting at position
|
||||
pub fn extract_from(&self, start: usize) -> &str {
|
||||
if start >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after start
|
||||
let mut safe_start = start;
|
||||
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
|
||||
safe_start += 1;
|
||||
}
|
||||
|
||||
if safe_start < self.buffer.len() {
|
||||
&self.buffer[safe_start..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark content as consumed up to position
|
||||
pub fn consume_to(&mut self, position: usize) {
|
||||
if position > self.consumed {
|
||||
self.consumed = position;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get unconsumed content
|
||||
pub fn unconsumed(&self) -> &str {
|
||||
if self.consumed >= self.buffer.len() {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the nearest character boundary at or after consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed += 1;
|
||||
}
|
||||
|
||||
if safe_consumed < self.buffer.len() {
|
||||
&self.buffer[safe_consumed..]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear consumed content from buffer
|
||||
pub fn clear_consumed(&mut self) {
|
||||
if self.consumed > 0 {
|
||||
// Find the nearest character boundary at or before consumed
|
||||
let mut safe_consumed = self.consumed;
|
||||
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
|
||||
safe_consumed -= 1;
|
||||
}
|
||||
|
||||
if safe_consumed > 0 {
|
||||
self.buffer.drain(..safe_consumed);
|
||||
self.consumed = self.consumed.saturating_sub(safe_consumed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add completed tool
|
||||
pub fn add_completed_tool(&mut self, tool: ToolCall) {
|
||||
self.completed_tools.push(tool);
|
||||
self.tool_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParseState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Placeholder for Harmony streaming metadata captured during token-aware parsing.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct HarmonyStreamState {
|
||||
|
||||
@@ -5,64 +5,27 @@ use crate::tool_parser::partial_json::{
|
||||
};
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
|
||||
#[test]
|
||||
fn test_parse_state_new() {
|
||||
let state = ParseState::new();
|
||||
assert_eq!(state.phase, ParsePhase::Searching);
|
||||
assert_eq!(state.buffer, "");
|
||||
assert_eq!(state.consumed, 0);
|
||||
assert_eq!(state.bracket_depth, 0);
|
||||
assert!(!state.in_string);
|
||||
assert!(!state.escape_next);
|
||||
#[tokio::test]
|
||||
async fn test_tool_parser_factory() {
|
||||
let factory = ToolParserFactory::new();
|
||||
|
||||
// Test that we can get a pooled parser
|
||||
let pooled_parser = factory.get_pooled("gpt-4");
|
||||
let parser = pooled_parser.lock().await;
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_state_process_char() {
|
||||
let mut state = ParseState::new();
|
||||
#[tokio::test]
|
||||
async fn test_tool_parser_factory_model_mapping() {
|
||||
let factory = ToolParserFactory::new();
|
||||
|
||||
state.process_char('{');
|
||||
assert_eq!(state.bracket_depth, 1);
|
||||
// Test model mapping
|
||||
factory.registry().map_model("test-model", "json");
|
||||
|
||||
state.process_char('}');
|
||||
assert_eq!(state.bracket_depth, 0);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(state.in_string);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(!state.in_string);
|
||||
|
||||
state.process_char('"');
|
||||
state.process_char('\\');
|
||||
assert!(state.escape_next);
|
||||
|
||||
state.process_char('"');
|
||||
assert!(!state.escape_next);
|
||||
assert!(state.in_string); // Still in string because quote was escaped
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parser_registry() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
assert!(!registry.list_mappings().is_empty());
|
||||
|
||||
let mappings = registry.list_mappings();
|
||||
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
|
||||
assert!(has_gpt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parser_registry_pattern_matching() {
|
||||
let mut registry = ParserRegistry::new_for_testing();
|
||||
|
||||
registry.map_model("test-model", "json");
|
||||
|
||||
let mappings = registry.list_mappings();
|
||||
let has_test = mappings
|
||||
.iter()
|
||||
.any(|(m, p)| *m == "test-model" && *p == "json");
|
||||
assert!(has_test);
|
||||
// Get parser for the test model
|
||||
let pooled_parser = factory.get_pooled("test-model");
|
||||
let parser = pooled_parser.lock().await;
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -165,37 +128,7 @@ fn test_compute_diff() {
|
||||
assert_eq!(compute_diff("test", "hello"), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_result_variants() {
|
||||
let result = StreamResult::Incomplete;
|
||||
matches!(result, StreamResult::Incomplete);
|
||||
|
||||
let result = StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: "test".to_string(),
|
||||
};
|
||||
if let StreamResult::ToolName { index, name } = result {
|
||||
assert_eq!(index, 0);
|
||||
assert_eq!(name, "test");
|
||||
} else {
|
||||
panic!("Expected ToolName variant");
|
||||
}
|
||||
|
||||
let tool = ToolCall {
|
||||
id: "123".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "test".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
let result = StreamResult::ToolComplete(tool.clone());
|
||||
if let StreamResult::ToolComplete(t) = result {
|
||||
assert_eq!(t.id, "123");
|
||||
} else {
|
||||
panic!("Expected ToolComplete variant");
|
||||
}
|
||||
}
|
||||
// NOTE: test_stream_result_variants removed - StreamResult enum replaced by StreamingParseResult
|
||||
|
||||
#[test]
|
||||
fn test_partial_tool_call() {
|
||||
@@ -310,14 +243,12 @@ fn test_json_parser_format_detection() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_registry_with_json_parser() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// JSON parser should be registered by default
|
||||
assert!(registry.has_parser("json"));
|
||||
async fn test_factory_with_json_parser() {
|
||||
let factory = ToolParserFactory::new();
|
||||
|
||||
// Should get JSON parser for OpenAI models
|
||||
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
||||
let pooled_parser = factory.get_pooled("gpt-4-turbo");
|
||||
let parser = pooled_parser.lock().await;
|
||||
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
@@ -546,62 +477,6 @@ mod edge_cases {
|
||||
assert!(tools[0].function.arguments.contains("null"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_partial_chunks() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let mut state1 = ParseState::new();
|
||||
let partial = r#"{"#;
|
||||
let result = parser
|
||||
.parse_incremental(partial, &mut state1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
matches!(result, StreamResult::Incomplete),
|
||||
"Should return Incomplete for just opening brace"
|
||||
);
|
||||
|
||||
let mut state2 = ParseState::new();
|
||||
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
let result = parser
|
||||
.parse_incremental(complete, &mut state2)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "SF");
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON"),
|
||||
}
|
||||
|
||||
// The PartialJson parser can complete partial JSON by filling in missing values
|
||||
let mut state3 = ParseState::new();
|
||||
let partial_with_name = r#"{"name": "test", "argum"#;
|
||||
let result = parser
|
||||
.parse_incremental(partial_with_name, &mut state3)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test");
|
||||
// Arguments will be empty object since "argum" is incomplete
|
||||
assert_eq!(tool.function.arguments, "{}");
|
||||
}
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "test");
|
||||
}
|
||||
StreamResult::Incomplete => {
|
||||
// Also acceptable if parser decides to wait
|
||||
}
|
||||
_ => panic!("Unexpected result for partial JSON with name"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_json_values() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::protocols::spec::Tool;
|
||||
use crate::tool_parser::{
|
||||
errors::ToolParserResult,
|
||||
state::ParseState,
|
||||
types::{StreamResult, ToolCall},
|
||||
types::{StreamingParseResult, ToolCall},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
@@ -13,11 +13,16 @@ pub trait ToolParser: Send + Sync {
|
||||
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>;
|
||||
|
||||
/// Parse tool calls from model output (streaming)
|
||||
/// Parsers now maintain internal state, so self is mutable
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `chunk` - New text chunk from model output
|
||||
/// * `tools` - List of available tools for validation
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult>;
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult>;
|
||||
|
||||
/// Check if text contains tool calls in this parser's format
|
||||
fn detect_format(&self, text: &str) -> bool;
|
||||
@@ -50,9 +55,10 @@ pub trait TokenToolParser: ToolParser {
|
||||
) -> ToolParserResult<(String, Vec<ToolCall>)>;
|
||||
|
||||
/// Streaming parser entrypoint for token chunks.
|
||||
/// Parsers maintain internal state, so self is mutable
|
||||
async fn parse_incremental_tokens(
|
||||
&self,
|
||||
&mut self,
|
||||
tokens: &[u32],
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult>;
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult>;
|
||||
}
|
||||
|
||||
@@ -71,3 +71,23 @@ pub struct PartialToolCall {
|
||||
/// Arguments already streamed
|
||||
pub streamed_args: String,
|
||||
}
|
||||
|
||||
/// Result of streaming parse operation (matches Python StreamingParseResult)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StreamingParseResult {
|
||||
/// Normal text that's not part of tool calls
|
||||
pub normal_text: String,
|
||||
/// Tool call items parsed from the chunk
|
||||
pub calls: Vec<ToolCallItem>,
|
||||
}
|
||||
|
||||
/// Simple encapsulation of parsed tool call for streaming (matches Python ToolCallItem)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallItem {
|
||||
/// Tool index in the array
|
||||
pub tool_index: usize,
|
||||
/// Tool name (only present on first chunk)
|
||||
pub name: Option<String>,
|
||||
/// Incremental JSON arguments
|
||||
pub parameters: String,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user