diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index a441bbf7f..086d152df 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1491,6 +1491,7 @@ impl ResponsesResponse { ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(), ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(), ToolChoice::Function { .. } => "function".to_string(), + ToolChoice::AllowedTools { mode, .. } => mode.clone(), }, tools: request.tools.clone(), top_p: request.top_p, @@ -1718,6 +1719,12 @@ pub enum ToolChoice { tool_type: String, // "function" function: FunctionChoice, }, + AllowedTools { + #[serde(rename = "type")] + tool_type: String, // "allowed_tools" + mode: String, // "auto" | "required" TODO: need validation + tools: Vec, + }, } impl Default for ToolChoice { @@ -1732,6 +1739,14 @@ pub struct FunctionChoice { pub name: String, } +/// Tool reference for ToolChoice::AllowedTools +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolReference { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub name: String, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Tool { #[serde(rename = "type")] diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 483a8127f..7deec2d11 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -1,5 +1,6 @@ // gRPC Router Implementation +use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; @@ -20,8 +21,9 @@ use crate::policies::PolicyRegistry; use crate::protocols::spec::ChatMessage; use crate::protocols::spec::{ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, - ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage, + CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest, + ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice, + ToolChoiceValue, Usage, }; use crate::reasoning_parser::ParserFactory; use crate::routers::RouterTrait; @@ -34,7 +36,7 @@ use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::HuggingFaceTokenizer; use crate::tool_parser::ParserRegistry; use proto::generate_response::Response::{Chunk, Complete, Error}; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use std::time::{Instant, SystemTime, UNIX_EPOCH}; use tokio_stream::StreamExt; use uuid::Uuid; @@ -132,8 +134,39 @@ impl GrpcRouter { Err(response) => return response, }; - // Step 3: Process messages and apply chat template - let processed_messages = match self.process_chat_messages(body) { + // Step 3: Filter tools if needed for allowed_tools or specific function + // Only clone body if we need to modify tools + let mut body_with_filtered_tools; + let body_ref = match &body.tool_choice { + Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => { + body_with_filtered_tools = body.clone(); + let all_tools = body_with_filtered_tools.tools.as_ref().unwrap(); + let allowed_names: std::collections::HashSet<&str> = + allowed.iter().map(|t| t.name.as_str()).collect(); + let filtered_tools: Vec = all_tools + .iter() + .filter(|t| allowed_names.contains(t.function.name.as_str())) + .cloned() + .collect(); + body_with_filtered_tools.tools = Some(filtered_tools); + &body_with_filtered_tools + } + Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => { + body_with_filtered_tools = body.clone(); + let all_tools = body_with_filtered_tools.tools.as_ref().unwrap(); + let filtered_tools: Vec = all_tools + .iter() + .filter(|t| t.function.name == function.name) + .cloned() + .collect(); + body_with_filtered_tools.tools = Some(filtered_tools); + &body_with_filtered_tools + } + _ => body, // No filtering needed, use original + }; + + // Step 4: Process messages and apply chat template + let processed_messages = match self.process_chat_messages(body_ref) { Ok(msgs) => msgs, Err(e) => { error!("Failed to process chat messages: {}", e); @@ -141,7 +174,7 @@ impl GrpcRouter { } }; - // Step 4: Tokenize the processed text + // Step 5: Tokenize the processed text let encoding = match self.tokenizer.encode(&processed_messages.text) { Ok(encoding) => encoding, Err(e) => { @@ -157,18 +190,17 @@ impl GrpcRouter { let token_ids = encoding.token_ids().to_vec(); debug!("Tokenized {} tokens from input", token_ids.len()); - // Step 5: Build tool constraints if needed - let tool_call_constraint = if let Some(tools) = &body.tools { + // Step 6: Build tool constraints if needed + // body_ref already has filtered tools if needed + let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { self.generate_tool_constraints(tools, &body.tool_choice, &body.model) - } else { - None - }; + }); - // Step 6: Build the base gRPC request + // Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable) let request_id = format!("chatcmpl-{}", Uuid::new_v4()); let request = match client.build_generate_request( request_id, - body, + body_ref, processed_messages.text.clone(), token_ids, processed_messages.multimodal_inputs, @@ -561,16 +593,227 @@ impl GrpcRouter { } /// Generate tool constraints for structured generation + /// Note: tools should already be filtered if needed (by allowed_tools or specific function) fn generate_tool_constraints( &self, - _tools: &[Tool], - _tool_choice: &Option, - model: &str, + tools: &[Tool], + tool_choice: &Option, + _model: &str, ) -> Option<(String, String)> { - let _parser = self.tool_parser_registry.get_parser(model)?; - // TODO: Implement actual constraint generation logic - // For now, return None as this is placeholder implementation - None + let choice = tool_choice.as_ref()?; + + match choice { + // Specific function: Return parameters schema directly + // tools should already be filtered to contain only the specific function + ToolChoice::Function { .. } => { + if tools.is_empty() { + return None; + } + let tool = &tools[0]; + + // Return the tool's parameters schema directly (not wrapped in array) + let params_schema = serde_json::to_string(&tool.function.parameters).ok()?; + Some(("json_schema".to_string(), params_schema)) + } + + // Required: Array of tool calls with minItems: 1 + ToolChoice::Value(ToolChoiceValue::Required) => { + let schema = self.build_required_array_schema(tools)?; + Some(("json_schema".to_string(), schema)) + } + + // AllowedTools with required mode: tools are already filtered + ToolChoice::AllowedTools { mode, .. } => { + if mode == "required" { + if tools.is_empty() { + return None; + } + let schema = self.build_required_array_schema(tools)?; + Some(("json_schema".to_string(), schema)) + } else { + // "auto" mode - no constraint needed + None + } + } + + // "auto" or "none" - no constraint + _ => None, + } + } + + /// Build JSON schema for required tool calls (array with minItems: 1) + /// Includes $defs consolidation from all tools (matching Python's behavior) + fn build_required_array_schema(&self, tools: &[Tool]) -> Option { + // Build anyOf schemas for each tool + let mut any_of_schemas = Vec::new(); + for tool in tools { + let tool_schema = json!({ + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name] + }, + "parameters": tool.function.parameters + }, + "required": ["name", "parameters"] + }); + any_of_schemas.push(tool_schema); + } + + // Consolidate $defs from all tools (matching Python's _get_tool_schema_defs) + let mut all_defs: HashMap = HashMap::new(); + for tool in tools { + if let Value::Object(params) = &tool.function.parameters { + if let Some(Value::Object(defs)) = params.get("$defs") { + for (def_name, def_schema) in defs { + if let Some(existing) = all_defs.get(def_name) { + // Check for conflicts + if existing != def_schema { + error!( + "Tool definition '{}' has multiple schemas, which is not supported", + def_name + ); + return None; + } + } else { + all_defs.insert(def_name.clone(), def_schema.clone()); + } + } + } + } + } + + // Build the full array schema + let mut array_schema = json!({ + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": any_of_schemas + } + }); + + // Add $defs if any were found (matching Python's behavior) + if !all_defs.is_empty() { + if let Value::Object(ref mut schema_obj) = array_schema { + let defs_value = + Value::Object(all_defs.into_iter().collect::>()); + schema_obj.insert("$defs".to_string(), defs_value); + } + } + + serde_json::to_string(&array_schema).ok() + } + + /// Parse tool calls from JSON schema constrained response + fn parse_json_schema_response( + &self, + processed_text: &str, + tool_choice: &Option, + ) -> (Option>, String) { + match tool_choice { + Some(ToolChoice::Function { function, .. }) => { + // Specific function: Parse parameters directly + match serde_json::from_str::(processed_text) { + Ok(params) => { + let tool_call = ToolCall { + id: format!("call_{}", uuid::Uuid::new_v4()), + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: function.name.clone(), + arguments: Some( + serde_json::to_string(¶ms) + .unwrap_or_else(|_| "{}".to_string()), + ), + }, + }; + (Some(vec![tool_call]), String::new()) + } + Err(e) => { + error!("Failed to parse specific function parameters: {}", e); + (None, processed_text.to_string()) + } + } + } + Some(ToolChoice::Value(ToolChoiceValue::Required)) + | Some(ToolChoice::AllowedTools { .. }) => { + // Required mode: Parse array of tool calls + match serde_json::from_str::>(processed_text) { + Ok(parsed_array) => { + let spec_tool_calls: Vec = parsed_array + .into_iter() + .enumerate() + .filter_map(|(i, item)| { + let obj = item.as_object()?; + let name = obj.get("name")?.as_str()?.to_string(); + let parameters = obj.get("parameters")?; + + Some(ToolCall { + id: format!("call_{}_{}", i, uuid::Uuid::new_v4()), + tool_type: "function".to_string(), + function: FunctionCallResponse { + name, + arguments: Some( + serde_json::to_string(parameters) + .unwrap_or_else(|_| "{}".to_string()), + ), + }, + }) + }) + .collect(); + (Some(spec_tool_calls), String::new()) + } + Err(e) => { + error!("Failed to parse required tool call array: {}", e); + (None, processed_text.to_string()) + } + } + } + _ => (None, processed_text.to_string()), + } + } + + /// Parse tool calls using model-specific parser + async fn parse_with_model_parser( + &self, + processed_text: &str, + model: &str, + ) -> (Option>, String) { + let Some(parser) = self.tool_parser_registry.get_parser(model) else { + return (None, processed_text.to_string()); + }; + + if !parser.detect_format(processed_text) { + return (None, processed_text.to_string()); + } + + match parser.parse_complete(processed_text).await { + Ok((normal_text, parsed_tool_calls)) => { + if parsed_tool_calls.is_empty() { + return (None, normal_text); + } + + let spec_tool_calls = parsed_tool_calls + .into_iter() + .map(|tc| ToolCall { + id: tc.id, + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: tc.function.name, + arguments: Some( + serde_json::to_string(&tc.function.arguments) + .unwrap_or_else(|_| "{}".to_string()), + ), + }, + }) + .collect(); + (Some(spec_tool_calls), normal_text) + } + Err(e) => { + error!("Tool call parsing error: {}", e); + (None, processed_text.to_string()) + } + } } /// Resolve the generate input into optional original text and token IDs @@ -1130,36 +1373,21 @@ impl GrpcRouter { ); if tool_choice_enabled && original_request.tools.is_some() { - if let Some(parser) = self - .tool_parser_registry - .get_parser(&original_request.model) - { - match parser.parse_complete(&processed_text).await { - Ok((normal_text, parsed_tool_calls)) => { - if !parsed_tool_calls.is_empty() { - let spec_tool_calls = parsed_tool_calls - .into_iter() - .map(|tc| crate::protocols::spec::ToolCall { - id: tc.id, - tool_type: "function".to_string(), - function: crate::protocols::spec::FunctionCallResponse { - name: tc.function.name, - arguments: Some( - serde_json::to_string(&tc.function.arguments) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, - }) - .collect(); - tool_calls = Some(spec_tool_calls); - processed_text = normal_text; - } - } - Err(e) => { - error!("Tool call parsing error: {}", e); - // Continue without tool calls rather than failing - } - } + // Check if JSON schema constraint was used (specific function or required mode) + let used_json_schema = match &original_request.tool_choice { + Some(ToolChoice::Function { .. }) => true, + Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true, + Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", + _ => false, + }; + + if used_json_schema { + (tool_calls, processed_text) = + self.parse_json_schema_response(&processed_text, &original_request.tool_choice); + } else { + (tool_calls, processed_text) = self + .parse_with_model_parser(&processed_text, &original_request.model) + .await; } }