[router][grpc] Support tool call parser in streaming (#11160)

This commit is contained in:
Chang Su
2025-10-02 03:18:50 -07:00
committed by GitHub
parent 5e786cca3a
commit b658be6f6a
38 changed files with 3086 additions and 2245 deletions

View File

@@ -0,0 +1,319 @@
// Factory and pool for creating model-specific tool parsers with pooling support.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Mutex;
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
};
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances.
///
/// A boxed [`ToolParser`] behind a `tokio::sync::Mutex`, wrapped in an `Arc`.
/// Cloning the alias clones the handle only; the parser itself is reached by
/// locking the mutex.
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
/// Type alias for parser creator functions.
///
/// A zero-argument callable returning a boxed [`ToolParser`]. It is
/// `Send + Sync` and reference-counted so the registry can store it in a
/// shared map.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
/// Registry for model-specific tool parsers with pooling support.
///
/// Every field is wrapped in an `Arc`, so the derived `Clone` yields another
/// handle onto the same underlying maps rather than an independent copy.
#[derive(Clone)]
pub struct ToolParserRegistry {
/// Creator functions for parsers (used when pool is empty), keyed by parser name.
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse, keyed by parser name.
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
/// Model pattern to parser name mappings.
/// A key is either an exact model name or a prefix pattern ending in `*`.
model_mapping: Arc<RwLock<HashMap<String, String>>>,
/// Default parser name, used when no mapping yields a parser for a model.
default_parser: Arc<RwLock<String>>,
}
impl ToolParserRegistry {
    /// Create a new empty registry.
    ///
    /// No creators or model mappings are registered; the default parser name
    /// starts out as `"json"`.
    pub fn new() -> Self {
        Self {
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
            model_mapping: Arc::new(RwLock::new(HashMap::new())),
            default_parser: Arc::new(RwLock::new("json".to_string())),
        }
    }

    /// Register a parser creator for a given parser type.
    ///
    /// Replaces any creator previously registered under `name`. Only the
    /// creator table is updated: an instance already pooled under `name` is
    /// left in place until [`clear_pool`](Self::clear_pool) is called.
    ///
    /// # Panics
    /// Panics if the creators lock is poisoned.
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
    {
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
    }

    /// Map a model name/pattern to a parser.
    ///
    /// `model` is either an exact model name or a prefix pattern ending in
    /// `*`. An existing mapping for the same key is overwritten.
    ///
    /// # Panics
    /// Panics if the mapping lock is poisoned.
    pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
        let mut mapping = self.model_mapping.write().unwrap();
        mapping.insert(model.into(), parser.into());
    }

    /// Get a pooled parser by exact name.
    ///
    /// Returns the shared instance from the pool, creating and pooling one if
    /// needed. Returns `None` when no creator is registered under `name`.
    ///
    /// # Panics
    /// Panics if the pool or creators lock is poisoned.
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
        // Fast path: a shared lock is enough when the instance already exists.
        let cached = self.pool.read().unwrap().get(name).cloned();
        if cached.is_some() {
            return cached;
        }

        // Clone the creator handle and drop the creators lock before invoking
        // it, so no registry lock is held while user code runs.
        let creator = self.creators.read().unwrap().get(name).cloned()?;
        let fresh: PooledToolParser = Arc::new(Mutex::new(creator()));

        // Another caller may have pooled an instance since the fast-path
        // check. Keep whichever got there first so every caller shares the
        // same parser; in that case `fresh` is simply dropped.
        let mut pool = self.pool.write().unwrap();
        let shared = pool.entry(name.to_string()).or_insert(fresh);
        Some(Arc::clone(shared))
    }

    /// Get parser for a specific model.
    ///
    /// Resolution order:
    /// 1. an exact match of `model` in the mapping table;
    /// 2. the longest pattern ending in `*` whose prefix `model` starts with;
    /// 3. the default parser.
    ///
    /// A step is skipped when it names a parser that has no registered
    /// creator. Returns `None` only if the default parser is unregistered too.
    ///
    /// # Panics
    /// Panics if any registry lock is poisoned.
    pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
        // Resolve candidate parser names first and release the mapping lock
        // before any parser is instantiated.
        let (exact, wildcard) = {
            let mapping = self.model_mapping.read().unwrap();
            let exact = mapping.get(model).cloned();
            let wildcard = mapping
                .iter()
                .filter(|(pattern, _)| {
                    pattern
                        .strip_suffix('*')
                        .map_or(false, |stem| model.starts_with(stem))
                })
                // Longer pattern == more specific prefix.
                .max_by_key(|(pattern, _)| pattern.len())
                .map(|(_, parser_name)| parser_name.clone());
            (exact, wildcard)
        };

        exact
            .and_then(|parser_name| self.get_pooled_parser(&parser_name))
            .or_else(|| wildcard.and_then(|parser_name| self.get_pooled_parser(&parser_name)))
            .or_else(|| {
                let fallback = self.default_parser.read().unwrap().clone();
                self.get_pooled_parser(&fallback)
            })
    }

    /// Clear the parser pool, forcing new instances to be created.
    ///
    /// Handles already given out stay valid; they just stop being shared with
    /// later callers.
    ///
    /// # Panics
    /// Panics if the pool lock is poisoned.
    pub fn clear_pool(&self) {
        let mut pool = self.pool.write().unwrap();
        pool.clear();
    }

    /// Set the default parser.
    ///
    /// The name is not validated against the registered creators.
    ///
    /// # Panics
    /// Panics if the default-parser lock is poisoned.
    pub fn set_default_parser(&self, name: impl Into<String>) {
        let mut default = self.default_parser.write().unwrap();
        *default = name.into();
    }
}
impl Default for ToolParserRegistry {
fn default() -> Self {
Self::new()
}
}
/// Factory for creating tool parsers based on model type.
///
/// Wraps a single [`ToolParserRegistry`]; cloning the factory clones that
/// registry handle.
#[derive(Clone)]
pub struct ToolParserFactory {
/// Registry holding the parser creators, pooled instances and model mappings.
registry: ToolParserRegistry,
}
impl ToolParserFactory {
    /// Create a new factory with default parsers registered.
    ///
    /// Registers one creator per built-in parser, selects the implementation
    /// behind the `"gpt_oss"` name via `use_harmony_gpt_oss()`, and installs
    /// the default model-to-parser mappings.
    pub fn new() -> Self {
        let registry = ToolParserRegistry::new();

        // Built-in parsers.
        registry.register_parser("json", || Box::new(JsonParser::new()));
        registry.register_parser("mistral", || Box::new(MistralParser::new()));
        registry.register_parser("qwen", || Box::new(QwenParser::new()));
        registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
        registry.register_parser("llama", || Box::new(LlamaParser::new()));
        registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
        registry.register_parser("glm4_moe", || Box::new(Glm4MoeParser::new()));
        registry.register_parser("step3", || Box::new(Step3Parser::new()));
        registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));

        // Both GPT-OSS implementations stay addressable under explicit names.
        registry.register_parser("gpt_oss_legacy", || Box::new(GptOssParser::new()));
        registry.register_parser("gpt_oss_harmony", || Box::new(GptOssHarmonyParser::new()));

        // The plain "gpt_oss" name points at whichever variant is selected.
        if use_harmony_gpt_oss() {
            registry.register_parser("gpt_oss", || Box::new(GptOssHarmonyParser::new()));
        } else {
            registry.register_parser("gpt_oss", || Box::new(GptOssParser::new()));
        }

        Self::register_default_mappings(&registry);
        Self { registry }
    }

    /// Install the built-in model-pattern to parser-name table.
    ///
    /// Only a trailing `*` acts as a wildcard; any other `*` in a pattern is
    /// compared literally by the lookup code.
    fn register_default_mappings(registry: &ToolParserRegistry) {
        const DEFAULT_MAPPINGS: &[(&str, &str)] = &[
            // OpenAI models
            ("gpt-4*", "json"),
            ("gpt-3.5*", "json"),
            ("gpt-4o*", "json"),
            // Anthropic models
            ("claude-*", "json"),
            // Mistral models
            ("mistral-*", "mistral"),
            ("mixtral-*", "mistral"),
            // Qwen models
            ("qwen*", "qwen"),
            ("Qwen*", "qwen"),
            // Llama models
            ("llama-4*", "pythonic"),
            ("meta-llama-4*", "pythonic"),
            ("llama-3.2*", "llama"),
            ("meta-llama-3.2*", "llama"),
            ("llama-*", "json"),
            ("meta-llama-*", "json"),
            // DeepSeek models
            ("deepseek-v3*", "deepseek"),
            ("deepseek-ai/DeepSeek-V3*", "deepseek"),
            ("deepseek-*", "pythonic"),
            // GLM models
            ("glm-4.5*", "glm4_moe"),
            ("glm-4.6*", "glm4_moe"),
            ("glm-*", "json"),
            // Step3 models
            ("step3*", "step3"),
            ("Step-3*", "step3"),
            // Kimi models
            ("kimi-k2*", "kimik2"),
            ("Kimi-K2*", "kimik2"),
            // Previously "moonshot*/Kimi-K2*": the inner `*` was matched
            // literally, so the entry could never match an
            // organisation-qualified model id.
            ("moonshotai/Kimi-K2*", "kimik2"),
            // GPT-OSS models
            ("gpt-oss*", "gpt_oss"),
            ("t4-*", "gpt_oss"),
            // Other models
            ("gemini-*", "json"),
            ("palm-*", "json"),
            ("gemma-*", "json"),
        ];

        for &(pattern, parser) in DEFAULT_MAPPINGS {
            registry.map_model(pattern, parser);
        }
    }

    /// Get a pooled parser for the given model ID.
    ///
    /// Returns a shared instance; access to it is serialized through its
    /// mutex. Falls back to the JSON parser if the registry lookup yields
    /// nothing.
    ///
    /// # Panics
    /// Panics if the fallback is needed and no `"json"` parser is registered.
    pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
        self.registry
            .get_pooled_for_model(model_id)
            .unwrap_or_else(|| {
                // Fallback to JSON parser
                self.registry
                    .get_pooled_parser("json")
                    .expect("JSON parser should always be registered")
            })
    }

    /// Get the internal registry for custom registration.
    pub fn registry(&self) -> &ToolParserRegistry {
        &self.registry
    }

    /// Clear the parser pool.
    pub fn clear_pool(&self) {
        self.registry.clear_pool();
    }

    /// Get a non-pooled parser for the given model ID (creates a fresh
    /// instance each time).
    ///
    /// This is useful for benchmarks and testing where you want independent
    /// parser instances.
    ///
    /// The parser name is resolved as: exact mapping, else the longest
    /// matching `*`-suffixed prefix pattern, else the registry default.
    /// Returns `None` when the resolved name has no registered creator;
    /// unlike [`get_pooled`](Self::get_pooled) there is no JSON fallback.
    pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
        // Determine which parser type to use.
        let parser_type = {
            let mapping = self.registry.model_mapping.read().unwrap();
            mapping
                .get(model_id)
                .cloned()
                .or_else(|| {
                    mapping
                        .iter()
                        .filter(|(pattern, _)| {
                            pattern
                                .strip_suffix('*')
                                .map_or(false, |stem| model_id.starts_with(stem))
                        })
                        // Longer pattern == more specific prefix.
                        .max_by_key(|(pattern, _)| pattern.len())
                        .map(|(_, parser_name)| parser_name.clone())
                })
                .unwrap_or_else(|| self.registry.default_parser.read().unwrap().clone())
        };

        // Clone the creator handle so the creators lock is released before
        // the creator runs.
        let creator = self
            .registry
            .creators
            .read()
            .unwrap()
            .get(&parser_type)
            .cloned()?;

        // The creator yields a Box<dyn ToolParser>; convert it into an Arc.
        Some(Arc::from(creator()))
    }

    /// List all registered parsers (for compatibility with old API).
    ///
    /// The order of the returned names is unspecified.
    pub fn list_parsers(&self) -> Vec<String> {
        self.registry
            .creators
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }
}
impl Default for ToolParserFactory {
fn default() -> Self {
Self::new()
}
}
/// Report whether the `ROUTER_USE_HARMONY_GPT_OSS` environment variable
/// enables the Harmony GPT-OSS parser.
///
/// The value is trimmed and compared ASCII-case-insensitively against `1`,
/// `true`, `yes` and `on`. An unset or non-Unicode variable, or any other
/// value, yields `false`.
fn use_harmony_gpt_oss() -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];

    std::env::var("ROUTER_USE_HARMONY_GPT_OSS")
        .map(|raw| {
            let flag = raw.trim();
            TRUTHY
                .iter()
                .any(|accepted| flag.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false)
}