[router][tool call] Full support for ToolChoice (#11085)
Co-authored-by: Simo Lin <linsimo.mark@gmail.com>
This commit is contained in:
@@ -1491,6 +1491,7 @@ impl ResponsesResponse {
|
||||
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
|
||||
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
|
||||
ToolChoice::Function { .. } => "function".to_string(),
|
||||
ToolChoice::AllowedTools { mode, .. } => mode.clone(),
|
||||
},
|
||||
tools: request.tools.clone(),
|
||||
top_p: request.top_p,
|
||||
@@ -1718,6 +1719,12 @@ pub enum ToolChoice {
|
||||
tool_type: String, // "function"
|
||||
function: FunctionChoice,
|
||||
},
|
||||
AllowedTools {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // "allowed_tools"
|
||||
mode: String, // "auto" | "required" TODO: need validation
|
||||
tools: Vec<ToolReference>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for ToolChoice {
|
||||
@@ -1732,6 +1739,14 @@ pub struct FunctionChoice {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Tool reference for ToolChoice::AllowedTools
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolReference {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// gRPC Router Implementation
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -20,8 +21,9 @@ use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::ChatMessage;
|
||||
use crate::protocols::spec::{
|
||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
|
||||
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
|
||||
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
|
||||
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
|
||||
ToolChoiceValue, Usage,
|
||||
};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
@@ -34,7 +36,7 @@ use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tokenizer::HuggingFaceTokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
use tokio_stream::StreamExt;
|
||||
use uuid::Uuid;
|
||||
@@ -132,8 +134,39 @@ impl GrpcRouter {
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
// Step 3: Process messages and apply chat template
|
||||
let processed_messages = match self.process_chat_messages(body) {
|
||||
// Step 3: Filter tools if needed for allowed_tools or specific function
|
||||
// Only clone body if we need to modify tools
|
||||
let mut body_with_filtered_tools;
|
||||
let body_ref = match &body.tool_choice {
|
||||
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
|
||||
body_with_filtered_tools = body.clone();
|
||||
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
|
||||
let allowed_names: std::collections::HashSet<&str> =
|
||||
allowed.iter().map(|t| t.name.as_str()).collect();
|
||||
let filtered_tools: Vec<Tool> = all_tools
|
||||
.iter()
|
||||
.filter(|t| allowed_names.contains(t.function.name.as_str()))
|
||||
.cloned()
|
||||
.collect();
|
||||
body_with_filtered_tools.tools = Some(filtered_tools);
|
||||
&body_with_filtered_tools
|
||||
}
|
||||
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
|
||||
body_with_filtered_tools = body.clone();
|
||||
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
|
||||
let filtered_tools: Vec<Tool> = all_tools
|
||||
.iter()
|
||||
.filter(|t| t.function.name == function.name)
|
||||
.cloned()
|
||||
.collect();
|
||||
body_with_filtered_tools.tools = Some(filtered_tools);
|
||||
&body_with_filtered_tools
|
||||
}
|
||||
_ => body, // No filtering needed, use original
|
||||
};
|
||||
|
||||
// Step 4: Process messages and apply chat template
|
||||
let processed_messages = match self.process_chat_messages(body_ref) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
error!("Failed to process chat messages: {}", e);
|
||||
@@ -141,7 +174,7 @@ impl GrpcRouter {
|
||||
}
|
||||
};
|
||||
|
||||
// Step 4: Tokenize the processed text
|
||||
// Step 5: Tokenize the processed text
|
||||
let encoding = match self.tokenizer.encode(&processed_messages.text) {
|
||||
Ok(encoding) => encoding,
|
||||
Err(e) => {
|
||||
@@ -157,18 +190,17 @@ impl GrpcRouter {
|
||||
let token_ids = encoding.token_ids().to_vec();
|
||||
debug!("Tokenized {} tokens from input", token_ids.len());
|
||||
|
||||
// Step 5: Build tool constraints if needed
|
||||
let tool_call_constraint = if let Some(tools) = &body.tools {
|
||||
// Step 6: Build tool constraints if needed
|
||||
// body_ref already has filtered tools if needed
|
||||
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
|
||||
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
});
|
||||
|
||||
// Step 6: Build the base gRPC request
|
||||
// Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable)
|
||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let request = match client.build_generate_request(
|
||||
request_id,
|
||||
body,
|
||||
body_ref,
|
||||
processed_messages.text.clone(),
|
||||
token_ids,
|
||||
processed_messages.multimodal_inputs,
|
||||
@@ -561,17 +593,228 @@ impl GrpcRouter {
|
||||
}
|
||||
|
||||
/// Generate tool constraints for structured generation
|
||||
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
|
||||
fn generate_tool_constraints(
|
||||
&self,
|
||||
_tools: &[Tool],
|
||||
_tool_choice: &Option<ToolChoice>,
|
||||
model: &str,
|
||||
tools: &[Tool],
|
||||
tool_choice: &Option<ToolChoice>,
|
||||
_model: &str,
|
||||
) -> Option<(String, String)> {
|
||||
let _parser = self.tool_parser_registry.get_parser(model)?;
|
||||
// TODO: Implement actual constraint generation logic
|
||||
// For now, return None as this is placeholder implementation
|
||||
let choice = tool_choice.as_ref()?;
|
||||
|
||||
match choice {
|
||||
// Specific function: Return parameters schema directly
|
||||
// tools should already be filtered to contain only the specific function
|
||||
ToolChoice::Function { .. } => {
|
||||
if tools.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let tool = &tools[0];
|
||||
|
||||
// Return the tool's parameters schema directly (not wrapped in array)
|
||||
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
|
||||
Some(("json_schema".to_string(), params_schema))
|
||||
}
|
||||
|
||||
// Required: Array of tool calls with minItems: 1
|
||||
ToolChoice::Value(ToolChoiceValue::Required) => {
|
||||
let schema = self.build_required_array_schema(tools)?;
|
||||
Some(("json_schema".to_string(), schema))
|
||||
}
|
||||
|
||||
// AllowedTools with required mode: tools are already filtered
|
||||
ToolChoice::AllowedTools { mode, .. } => {
|
||||
if mode == "required" {
|
||||
if tools.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let schema = self.build_required_array_schema(tools)?;
|
||||
Some(("json_schema".to_string(), schema))
|
||||
} else {
|
||||
// "auto" mode - no constraint needed
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// "auto" or "none" - no constraint
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build JSON schema for required tool calls (array with minItems: 1)
|
||||
/// Includes $defs consolidation from all tools (matching Python's behavior)
|
||||
fn build_required_array_schema(&self, tools: &[Tool]) -> Option<String> {
|
||||
// Build anyOf schemas for each tool
|
||||
let mut any_of_schemas = Vec::new();
|
||||
for tool in tools {
|
||||
let tool_schema = json!({
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"enum": [tool.function.name]
|
||||
},
|
||||
"parameters": tool.function.parameters
|
||||
},
|
||||
"required": ["name", "parameters"]
|
||||
});
|
||||
any_of_schemas.push(tool_schema);
|
||||
}
|
||||
|
||||
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
|
||||
let mut all_defs: HashMap<String, Value> = HashMap::new();
|
||||
for tool in tools {
|
||||
if let Value::Object(params) = &tool.function.parameters {
|
||||
if let Some(Value::Object(defs)) = params.get("$defs") {
|
||||
for (def_name, def_schema) in defs {
|
||||
if let Some(existing) = all_defs.get(def_name) {
|
||||
// Check for conflicts
|
||||
if existing != def_schema {
|
||||
error!(
|
||||
"Tool definition '{}' has multiple schemas, which is not supported",
|
||||
def_name
|
||||
);
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
all_defs.insert(def_name.clone(), def_schema.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the full array schema
|
||||
let mut array_schema = json!({
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": any_of_schemas
|
||||
}
|
||||
});
|
||||
|
||||
// Add $defs if any were found (matching Python's behavior)
|
||||
if !all_defs.is_empty() {
|
||||
if let Value::Object(ref mut schema_obj) = array_schema {
|
||||
let defs_value =
|
||||
Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
|
||||
schema_obj.insert("$defs".to_string(), defs_value);
|
||||
}
|
||||
}
|
||||
|
||||
serde_json::to_string(&array_schema).ok()
|
||||
}
|
||||
|
||||
/// Parse tool calls from JSON schema constrained response
|
||||
fn parse_json_schema_response(
|
||||
&self,
|
||||
processed_text: &str,
|
||||
tool_choice: &Option<ToolChoice>,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
match tool_choice {
|
||||
Some(ToolChoice::Function { function, .. }) => {
|
||||
// Specific function: Parse parameters directly
|
||||
match serde_json::from_str::<Value>(processed_text) {
|
||||
Ok(params) => {
|
||||
let tool_call = ToolCall {
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name: function.name.clone(),
|
||||
arguments: Some(
|
||||
serde_json::to_string(¶ms)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
};
|
||||
(Some(vec![tool_call]), String::new())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse specific function parameters: {}", e);
|
||||
(None, processed_text.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ToolChoice::Value(ToolChoiceValue::Required))
|
||||
| Some(ToolChoice::AllowedTools { .. }) => {
|
||||
// Required mode: Parse array of tool calls
|
||||
match serde_json::from_str::<Vec<Value>>(processed_text) {
|
||||
Ok(parsed_array) => {
|
||||
let spec_tool_calls: Vec<ToolCall> = parsed_array
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, item)| {
|
||||
let obj = item.as_object()?;
|
||||
let name = obj.get("name")?.as_str()?.to_string();
|
||||
let parameters = obj.get("parameters")?;
|
||||
|
||||
Some(ToolCall {
|
||||
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(parameters)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
(Some(spec_tool_calls), String::new())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse required tool call array: {}", e);
|
||||
(None, processed_text.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (None, processed_text.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse tool calls using model-specific parser
|
||||
async fn parse_with_model_parser(
|
||||
&self,
|
||||
processed_text: &str,
|
||||
model: &str,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
let Some(parser) = self.tool_parser_registry.get_parser(model) else {
|
||||
return (None, processed_text.to_string());
|
||||
};
|
||||
|
||||
if !parser.detect_format(processed_text) {
|
||||
return (None, processed_text.to_string());
|
||||
}
|
||||
|
||||
match parser.parse_complete(processed_text).await {
|
||||
Ok((normal_text, parsed_tool_calls)) => {
|
||||
if parsed_tool_calls.is_empty() {
|
||||
return (None, normal_text);
|
||||
}
|
||||
|
||||
let spec_tool_calls = parsed_tool_calls
|
||||
.into_iter()
|
||||
.map(|tc| ToolCall {
|
||||
id: tc.id,
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name: tc.function.name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
(Some(spec_tool_calls), normal_text)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Tool call parsing error: {}", e);
|
||||
(None, processed_text.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the generate input into optional original text and token IDs
|
||||
fn resolve_generate_input(
|
||||
@@ -1130,36 +1373,21 @@ impl GrpcRouter {
|
||||
);
|
||||
|
||||
if tool_choice_enabled && original_request.tools.is_some() {
|
||||
if let Some(parser) = self
|
||||
.tool_parser_registry
|
||||
.get_parser(&original_request.model)
|
||||
{
|
||||
match parser.parse_complete(&processed_text).await {
|
||||
Ok((normal_text, parsed_tool_calls)) => {
|
||||
if !parsed_tool_calls.is_empty() {
|
||||
let spec_tool_calls = parsed_tool_calls
|
||||
.into_iter()
|
||||
.map(|tc| crate::protocols::spec::ToolCall {
|
||||
id: tc.id,
|
||||
tool_type: "function".to_string(),
|
||||
function: crate::protocols::spec::FunctionCallResponse {
|
||||
name: tc.function.name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
tool_calls = Some(spec_tool_calls);
|
||||
processed_text = normal_text;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Tool call parsing error: {}", e);
|
||||
// Continue without tool calls rather than failing
|
||||
}
|
||||
}
|
||||
// Check if JSON schema constraint was used (specific function or required mode)
|
||||
let used_json_schema = match &original_request.tool_choice {
|
||||
Some(ToolChoice::Function { .. }) => true,
|
||||
Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true,
|
||||
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if used_json_schema {
|
||||
(tool_calls, processed_text) =
|
||||
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
|
||||
} else {
|
||||
(tool_calls, processed_text) = self
|
||||
.parse_with_model_parser(&processed_text, &original_request.model)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user