[router][grpc] Fix streaming bugs: empty tool names, state pollution, and panics (#11373)
This commit is contained in:
@@ -11,25 +11,25 @@ use crate::tool_parser::parsers::{
|
||||
use crate::tool_parser::traits::ToolParser;
|
||||
|
||||
/// Type alias for pooled parser instances.
|
||||
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
|
||||
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
|
||||
|
||||
/// Type alias for parser creator functions.
|
||||
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
|
||||
|
||||
/// Registry for model-specific tool parsers with pooling support.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolParserRegistry {
|
||||
pub struct ParserRegistry {
|
||||
/// Creator functions for parsers (used when pool is empty)
|
||||
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
|
||||
/// Pooled parser instances for reuse
|
||||
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
|
||||
pool: Arc<RwLock<HashMap<String, PooledParser>>>,
|
||||
/// Model pattern to parser name mappings
|
||||
model_mapping: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Default parser name
|
||||
default_parser: Arc<RwLock<String>>,
|
||||
}
|
||||
|
||||
impl ToolParserRegistry {
|
||||
impl ParserRegistry {
|
||||
/// Create a new empty registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -57,7 +57,7 @@ impl ToolParserRegistry {
|
||||
|
||||
/// Get a pooled parser by exact name.
|
||||
/// Returns a shared parser instance from the pool, creating one if needed.
|
||||
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
|
||||
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
|
||||
// First check if we have a pooled instance
|
||||
{
|
||||
let pool = self.pool.read().unwrap();
|
||||
@@ -81,8 +81,91 @@ impl ToolParserRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a parser with the given name is registered.
|
||||
pub fn has_parser(&self, name: &str) -> bool {
|
||||
let creators = self.creators.read().unwrap();
|
||||
creators.contains_key(name)
|
||||
}
|
||||
|
||||
/// Create a fresh (non-pooled) parser instance by exact name.
|
||||
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
|
||||
pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
|
||||
let creators = self.creators.read().unwrap();
|
||||
creators.get(name).map(|creator| creator())
|
||||
}
|
||||
|
||||
/// Check if a parser can be created for a specific model without actually creating it.
|
||||
/// Returns true if a parser is available (registered) for this model.
|
||||
pub fn has_parser_for_model(&self, model: &str) -> bool {
|
||||
// Try exact match first
|
||||
{
|
||||
let mapping = self.model_mapping.read().unwrap();
|
||||
if let Some(parser_name) = mapping.get(model) {
|
||||
let creators = self.creators.read().unwrap();
|
||||
if creators.contains_key(parser_name) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching
|
||||
let model_mapping = self.model_mapping.read().unwrap();
|
||||
let best_match = model_mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
|
||||
})
|
||||
.max_by_key(|(pattern, _)| pattern.len());
|
||||
|
||||
if let Some((_, parser_name)) = best_match {
|
||||
let creators = self.creators.read().unwrap();
|
||||
if creators.contains_key(parser_name) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if default parser exists
|
||||
let default = self.default_parser.read().unwrap().clone();
|
||||
let creators = self.creators.read().unwrap();
|
||||
creators.contains_key(&default)
|
||||
}
|
||||
|
||||
/// Create a fresh (non-pooled) parser instance for a specific model.
|
||||
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
|
||||
pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
|
||||
// Try exact match first
|
||||
{
|
||||
let mapping = self.model_mapping.read().unwrap();
|
||||
if let Some(parser_name) = mapping.get(model) {
|
||||
if let Some(parser) = self.create_parser(parser_name) {
|
||||
return Some(parser);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try prefix matching with more specific patterns first
|
||||
let model_mapping = self.model_mapping.read().unwrap();
|
||||
let best_match = model_mapping
|
||||
.iter()
|
||||
.filter(|(pattern, _)| {
|
||||
pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
|
||||
})
|
||||
.max_by_key(|(pattern, _)| pattern.len());
|
||||
|
||||
// Return the best matching parser
|
||||
if let Some((_, parser_name)) = best_match {
|
||||
if let Some(parser) = self.create_parser(parser_name) {
|
||||
return Some(parser);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default parser
|
||||
let default = self.default_parser.read().unwrap().clone();
|
||||
self.create_parser(&default)
|
||||
}
|
||||
|
||||
/// Get parser for a specific model
|
||||
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
|
||||
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
|
||||
// Try exact match first
|
||||
{
|
||||
let mapping = self.model_mapping.read().unwrap();
|
||||
@@ -127,7 +210,7 @@ impl ToolParserRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolParserRegistry {
|
||||
impl Default for ParserRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
@@ -135,14 +218,14 @@ impl Default for ToolParserRegistry {
|
||||
|
||||
/// Factory for creating tool parsers based on model type.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolParserFactory {
|
||||
registry: ToolParserRegistry,
|
||||
pub struct ParserFactory {
|
||||
registry: ParserRegistry,
|
||||
}
|
||||
|
||||
impl ToolParserFactory {
|
||||
impl ParserFactory {
|
||||
/// Create a new factory with default parsers registered.
|
||||
pub fn new() -> Self {
|
||||
let registry = ToolParserRegistry::new();
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Register default parsers
|
||||
registry.register_parser("json", || Box::new(JsonParser::new()));
|
||||
@@ -172,7 +255,7 @@ impl ToolParserFactory {
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
fn register_default_mappings(registry: &ToolParserRegistry) {
|
||||
fn register_default_mappings(registry: &ParserRegistry) {
|
||||
// OpenAI models
|
||||
registry.map_model("gpt-4*", "json");
|
||||
registry.map_model("gpt-3.5*", "json");
|
||||
@@ -229,7 +312,7 @@ impl ToolParserFactory {
|
||||
/// Get a pooled parser for the given model ID.
|
||||
/// Returns a shared instance that can be used concurrently.
|
||||
/// Falls back to JSON parser if model is not recognized.
|
||||
pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
|
||||
pub fn get_pooled(&self, model_id: &str) -> PooledParser {
|
||||
self.registry
|
||||
.get_pooled_for_model(model_id)
|
||||
.unwrap_or_else(|| {
|
||||
@@ -241,7 +324,7 @@ impl ToolParserFactory {
|
||||
}
|
||||
|
||||
/// Get the internal registry for custom registration.
|
||||
pub fn registry(&self) -> &ToolParserRegistry {
|
||||
pub fn registry(&self) -> &ParserRegistry {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
@@ -299,7 +382,7 @@ impl ToolParserFactory {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolParserFactory {
|
||||
impl Default for ParserFactory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user