//! Registry that resolves model names to tool parser implementations.
use crate::tool_parser::traits::ToolParser;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
/// Registry for tool parsers and model mappings
|
|
pub struct ParserRegistry {
|
|
/// Map of parser name to parser instance
|
|
parsers: HashMap<String, Arc<dyn ToolParser>>,
|
|
/// Map of model name/pattern to parser name
|
|
model_mapping: HashMap<String, String>,
|
|
/// Default parser to use when no match found
|
|
default_parser: String,
|
|
}
|
|
|
|
impl ParserRegistry {
|
|
/// Create a new parser registry with default mappings
|
|
pub fn new() -> Self {
|
|
let mut registry = Self {
|
|
parsers: HashMap::new(),
|
|
model_mapping: HashMap::new(),
|
|
default_parser: "json".to_string(),
|
|
};
|
|
|
|
// Register default model mappings
|
|
registry.register_default_mappings();
|
|
|
|
registry
|
|
}
|
|
|
|
/// Register a parser
|
|
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
|
|
self.parsers.insert(name.into(), parser);
|
|
}
|
|
|
|
/// Map a model name/pattern to a parser
|
|
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
|
|
self.model_mapping.insert(model.into(), parser.into());
|
|
}
|
|
|
|
/// Get parser for a specific model
|
|
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
|
|
// Try exact match first
|
|
if let Some(parser_name) = self.model_mapping.get(model) {
|
|
if let Some(parser) = self.parsers.get(parser_name) {
|
|
return Some(parser.clone());
|
|
}
|
|
}
|
|
|
|
// Try prefix matching (e.g., "gpt-4" matches "gpt-*")
|
|
for (pattern, parser_name) in &self.model_mapping {
|
|
if pattern.ends_with('*') {
|
|
let prefix = &pattern[..pattern.len() - 1];
|
|
if model.starts_with(prefix) {
|
|
if let Some(parser) = self.parsers.get(parser_name) {
|
|
return Some(parser.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fall back to default parser if it exists
|
|
self.parsers.get(&self.default_parser).cloned()
|
|
}
|
|
|
|
/// List all registered parsers
|
|
pub fn list_parsers(&self) -> Vec<&str> {
|
|
self.parsers.keys().map(|s| s.as_str()).collect()
|
|
}
|
|
|
|
/// List all model mappings
|
|
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
|
|
self.model_mapping
|
|
.iter()
|
|
.map(|(k, v)| (k.as_str(), v.as_str()))
|
|
.collect()
|
|
}
|
|
|
|
/// Register default model mappings
|
|
fn register_default_mappings(&mut self) {
|
|
// OpenAI models
|
|
self.map_model("gpt-4*", "json");
|
|
self.map_model("gpt-3.5*", "json");
|
|
self.map_model("gpt-4o*", "json");
|
|
|
|
// Anthropic models
|
|
self.map_model("claude-*", "json");
|
|
|
|
// Mistral models
|
|
self.map_model("mistral-*", "mistral");
|
|
self.map_model("mixtral-*", "mistral");
|
|
|
|
// Qwen models
|
|
self.map_model("qwen*", "qwen");
|
|
|
|
// Llama models
|
|
self.map_model("llama-*", "llama");
|
|
self.map_model("meta-llama-*", "llama");
|
|
|
|
// Other models default to JSON
|
|
self.map_model("gemini-*", "json");
|
|
self.map_model("palm-*", "json");
|
|
}
|
|
|
|
/// Set the default parser
|
|
pub fn set_default_parser(&mut self, name: impl Into<String>) {
|
|
self.default_parser = name.into();
|
|
}
|
|
|
|
/// Check if a parser is registered
|
|
pub fn has_parser(&self, name: &str) -> bool {
|
|
self.parsers.contains_key(name)
|
|
}
|
|
}
|
|
|
|
impl Default for ParserRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|