[router][tool call] Full support for ToolChoice (#11085)
Co-authored-by: Simo Lin <linsimo.mark@gmail.com>
This commit is contained in:
@@ -1491,6 +1491,7 @@ impl ResponsesResponse {
|
|||||||
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
|
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
|
||||||
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
|
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
|
||||||
ToolChoice::Function { .. } => "function".to_string(),
|
ToolChoice::Function { .. } => "function".to_string(),
|
||||||
|
ToolChoice::AllowedTools { mode, .. } => mode.clone(),
|
||||||
},
|
},
|
||||||
tools: request.tools.clone(),
|
tools: request.tools.clone(),
|
||||||
top_p: request.top_p,
|
top_p: request.top_p,
|
||||||
@@ -1718,6 +1719,12 @@ pub enum ToolChoice {
|
|||||||
tool_type: String, // "function"
|
tool_type: String, // "function"
|
||||||
function: FunctionChoice,
|
function: FunctionChoice,
|
||||||
},
|
},
|
||||||
|
AllowedTools {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
tool_type: String, // "allowed_tools"
|
||||||
|
mode: String, // "auto" | "required" TODO: need validation
|
||||||
|
tools: Vec<ToolReference>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ToolChoice {
|
impl Default for ToolChoice {
|
||||||
@@ -1732,6 +1739,14 @@ pub struct FunctionChoice {
|
|||||||
pub name: String,
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Tool reference for ToolChoice::AllowedTools
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ToolReference {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String, // "function"
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Tool {
|
pub struct Tool {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// gRPC Router Implementation
|
// gRPC Router Implementation
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -20,8 +21,9 @@ use crate::policies::PolicyRegistry;
|
|||||||
use crate::protocols::spec::ChatMessage;
|
use crate::protocols::spec::ChatMessage;
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||||
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
|
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
|
||||||
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
|
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
|
||||||
|
ToolChoiceValue, Usage,
|
||||||
};
|
};
|
||||||
use crate::reasoning_parser::ParserFactory;
|
use crate::reasoning_parser::ParserFactory;
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
@@ -34,7 +36,7 @@ use crate::tokenizer::traits::Tokenizer;
|
|||||||
use crate::tokenizer::HuggingFaceTokenizer;
|
use crate::tokenizer::HuggingFaceTokenizer;
|
||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -132,8 +134,39 @@ impl GrpcRouter {
|
|||||||
Err(response) => return response,
|
Err(response) => return response,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 3: Process messages and apply chat template
|
// Step 3: Filter tools if needed for allowed_tools or specific function
|
||||||
let processed_messages = match self.process_chat_messages(body) {
|
// Only clone body if we need to modify tools
|
||||||
|
let mut body_with_filtered_tools;
|
||||||
|
let body_ref = match &body.tool_choice {
|
||||||
|
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
|
||||||
|
body_with_filtered_tools = body.clone();
|
||||||
|
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
|
||||||
|
let allowed_names: std::collections::HashSet<&str> =
|
||||||
|
allowed.iter().map(|t| t.name.as_str()).collect();
|
||||||
|
let filtered_tools: Vec<Tool> = all_tools
|
||||||
|
.iter()
|
||||||
|
.filter(|t| allowed_names.contains(t.function.name.as_str()))
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
body_with_filtered_tools.tools = Some(filtered_tools);
|
||||||
|
&body_with_filtered_tools
|
||||||
|
}
|
||||||
|
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
|
||||||
|
body_with_filtered_tools = body.clone();
|
||||||
|
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
|
||||||
|
let filtered_tools: Vec<Tool> = all_tools
|
||||||
|
.iter()
|
||||||
|
.filter(|t| t.function.name == function.name)
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
body_with_filtered_tools.tools = Some(filtered_tools);
|
||||||
|
&body_with_filtered_tools
|
||||||
|
}
|
||||||
|
_ => body, // No filtering needed, use original
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 4: Process messages and apply chat template
|
||||||
|
let processed_messages = match self.process_chat_messages(body_ref) {
|
||||||
Ok(msgs) => msgs,
|
Ok(msgs) => msgs,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to process chat messages: {}", e);
|
error!("Failed to process chat messages: {}", e);
|
||||||
@@ -141,7 +174,7 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 4: Tokenize the processed text
|
// Step 5: Tokenize the processed text
|
||||||
let encoding = match self.tokenizer.encode(&processed_messages.text) {
|
let encoding = match self.tokenizer.encode(&processed_messages.text) {
|
||||||
Ok(encoding) => encoding,
|
Ok(encoding) => encoding,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -157,18 +190,17 @@ impl GrpcRouter {
|
|||||||
let token_ids = encoding.token_ids().to_vec();
|
let token_ids = encoding.token_ids().to_vec();
|
||||||
debug!("Tokenized {} tokens from input", token_ids.len());
|
debug!("Tokenized {} tokens from input", token_ids.len());
|
||||||
|
|
||||||
// Step 5: Build tool constraints if needed
|
// Step 6: Build tool constraints if needed
|
||||||
let tool_call_constraint = if let Some(tools) = &body.tools {
|
// body_ref already has filtered tools if needed
|
||||||
|
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
|
||||||
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
|
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
|
||||||
} else {
|
});
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Step 6: Build the base gRPC request
|
// Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable)
|
||||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||||
let request = match client.build_generate_request(
|
let request = match client.build_generate_request(
|
||||||
request_id,
|
request_id,
|
||||||
body,
|
body_ref,
|
||||||
processed_messages.text.clone(),
|
processed_messages.text.clone(),
|
||||||
token_ids,
|
token_ids,
|
||||||
processed_messages.multimodal_inputs,
|
processed_messages.multimodal_inputs,
|
||||||
@@ -561,17 +593,228 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generate tool constraints for structured generation
|
/// Generate tool constraints for structured generation
|
||||||
|
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
|
||||||
fn generate_tool_constraints(
|
fn generate_tool_constraints(
|
||||||
&self,
|
&self,
|
||||||
_tools: &[Tool],
|
tools: &[Tool],
|
||||||
_tool_choice: &Option<ToolChoice>,
|
tool_choice: &Option<ToolChoice>,
|
||||||
model: &str,
|
_model: &str,
|
||||||
) -> Option<(String, String)> {
|
) -> Option<(String, String)> {
|
||||||
let _parser = self.tool_parser_registry.get_parser(model)?;
|
let choice = tool_choice.as_ref()?;
|
||||||
// TODO: Implement actual constraint generation logic
|
|
||||||
// For now, return None as this is placeholder implementation
|
match choice {
|
||||||
|
// Specific function: Return parameters schema directly
|
||||||
|
// tools should already be filtered to contain only the specific function
|
||||||
|
ToolChoice::Function { .. } => {
|
||||||
|
if tools.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let tool = &tools[0];
|
||||||
|
|
||||||
|
// Return the tool's parameters schema directly (not wrapped in array)
|
||||||
|
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
|
||||||
|
Some(("json_schema".to_string(), params_schema))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required: Array of tool calls with minItems: 1
|
||||||
|
ToolChoice::Value(ToolChoiceValue::Required) => {
|
||||||
|
let schema = self.build_required_array_schema(tools)?;
|
||||||
|
Some(("json_schema".to_string(), schema))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedTools with required mode: tools are already filtered
|
||||||
|
ToolChoice::AllowedTools { mode, .. } => {
|
||||||
|
if mode == "required" {
|
||||||
|
if tools.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let schema = self.build_required_array_schema(tools)?;
|
||||||
|
Some(("json_schema".to_string(), schema))
|
||||||
|
} else {
|
||||||
|
// "auto" mode - no constraint needed
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// "auto" or "none" - no constraint
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build JSON schema for required tool calls (array with minItems: 1)
|
||||||
|
/// Includes $defs consolidation from all tools (matching Python's behavior)
|
||||||
|
fn build_required_array_schema(&self, tools: &[Tool]) -> Option<String> {
|
||||||
|
// Build anyOf schemas for each tool
|
||||||
|
let mut any_of_schemas = Vec::new();
|
||||||
|
for tool in tools {
|
||||||
|
let tool_schema = json!({
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [tool.function.name]
|
||||||
|
},
|
||||||
|
"parameters": tool.function.parameters
|
||||||
|
},
|
||||||
|
"required": ["name", "parameters"]
|
||||||
|
});
|
||||||
|
any_of_schemas.push(tool_schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
|
||||||
|
let mut all_defs: HashMap<String, Value> = HashMap::new();
|
||||||
|
for tool in tools {
|
||||||
|
if let Value::Object(params) = &tool.function.parameters {
|
||||||
|
if let Some(Value::Object(defs)) = params.get("$defs") {
|
||||||
|
for (def_name, def_schema) in defs {
|
||||||
|
if let Some(existing) = all_defs.get(def_name) {
|
||||||
|
// Check for conflicts
|
||||||
|
if existing != def_schema {
|
||||||
|
error!(
|
||||||
|
"Tool definition '{}' has multiple schemas, which is not supported",
|
||||||
|
def_name
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
all_defs.insert(def_name.clone(), def_schema.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the full array schema
|
||||||
|
let mut array_schema = json!({
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"anyOf": any_of_schemas
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add $defs if any were found (matching Python's behavior)
|
||||||
|
if !all_defs.is_empty() {
|
||||||
|
if let Value::Object(ref mut schema_obj) = array_schema {
|
||||||
|
let defs_value =
|
||||||
|
Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
|
||||||
|
schema_obj.insert("$defs".to_string(), defs_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_json::to_string(&array_schema).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse tool calls from JSON schema constrained response
|
||||||
|
fn parse_json_schema_response(
|
||||||
|
&self,
|
||||||
|
processed_text: &str,
|
||||||
|
tool_choice: &Option<ToolChoice>,
|
||||||
|
) -> (Option<Vec<ToolCall>>, String) {
|
||||||
|
match tool_choice {
|
||||||
|
Some(ToolChoice::Function { function, .. }) => {
|
||||||
|
// Specific function: Parse parameters directly
|
||||||
|
match serde_json::from_str::<Value>(processed_text) {
|
||||||
|
Ok(params) => {
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: FunctionCallResponse {
|
||||||
|
name: function.name.clone(),
|
||||||
|
arguments: Some(
|
||||||
|
serde_json::to_string(¶ms)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
(Some(vec![tool_call]), String::new())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to parse specific function parameters: {}", e);
|
||||||
|
(None, processed_text.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(ToolChoice::Value(ToolChoiceValue::Required))
|
||||||
|
| Some(ToolChoice::AllowedTools { .. }) => {
|
||||||
|
// Required mode: Parse array of tool calls
|
||||||
|
match serde_json::from_str::<Vec<Value>>(processed_text) {
|
||||||
|
Ok(parsed_array) => {
|
||||||
|
let spec_tool_calls: Vec<ToolCall> = parsed_array
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, item)| {
|
||||||
|
let obj = item.as_object()?;
|
||||||
|
let name = obj.get("name")?.as_str()?.to_string();
|
||||||
|
let parameters = obj.get("parameters")?;
|
||||||
|
|
||||||
|
Some(ToolCall {
|
||||||
|
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: FunctionCallResponse {
|
||||||
|
name,
|
||||||
|
arguments: Some(
|
||||||
|
serde_json::to_string(parameters)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(Some(spec_tool_calls), String::new())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to parse required tool call array: {}", e);
|
||||||
|
(None, processed_text.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => (None, processed_text.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse tool calls using model-specific parser
|
||||||
|
async fn parse_with_model_parser(
|
||||||
|
&self,
|
||||||
|
processed_text: &str,
|
||||||
|
model: &str,
|
||||||
|
) -> (Option<Vec<ToolCall>>, String) {
|
||||||
|
let Some(parser) = self.tool_parser_registry.get_parser(model) else {
|
||||||
|
return (None, processed_text.to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
if !parser.detect_format(processed_text) {
|
||||||
|
return (None, processed_text.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
match parser.parse_complete(processed_text).await {
|
||||||
|
Ok((normal_text, parsed_tool_calls)) => {
|
||||||
|
if parsed_tool_calls.is_empty() {
|
||||||
|
return (None, normal_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let spec_tool_calls = parsed_tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.id,
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: FunctionCallResponse {
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: Some(
|
||||||
|
serde_json::to_string(&tc.function.arguments)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(Some(spec_tool_calls), normal_text)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Tool call parsing error: {}", e);
|
||||||
|
(None, processed_text.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Resolve the generate input into optional original text and token IDs
|
/// Resolve the generate input into optional original text and token IDs
|
||||||
fn resolve_generate_input(
|
fn resolve_generate_input(
|
||||||
@@ -1130,36 +1373,21 @@ impl GrpcRouter {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if tool_choice_enabled && original_request.tools.is_some() {
|
if tool_choice_enabled && original_request.tools.is_some() {
|
||||||
if let Some(parser) = self
|
// Check if JSON schema constraint was used (specific function or required mode)
|
||||||
.tool_parser_registry
|
let used_json_schema = match &original_request.tool_choice {
|
||||||
.get_parser(&original_request.model)
|
Some(ToolChoice::Function { .. }) => true,
|
||||||
{
|
Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true,
|
||||||
match parser.parse_complete(&processed_text).await {
|
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
|
||||||
Ok((normal_text, parsed_tool_calls)) => {
|
_ => false,
|
||||||
if !parsed_tool_calls.is_empty() {
|
};
|
||||||
let spec_tool_calls = parsed_tool_calls
|
|
||||||
.into_iter()
|
if used_json_schema {
|
||||||
.map(|tc| crate::protocols::spec::ToolCall {
|
(tool_calls, processed_text) =
|
||||||
id: tc.id,
|
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
|
||||||
tool_type: "function".to_string(),
|
} else {
|
||||||
function: crate::protocols::spec::FunctionCallResponse {
|
(tool_calls, processed_text) = self
|
||||||
name: tc.function.name,
|
.parse_with_model_parser(&processed_text, &original_request.model)
|
||||||
arguments: Some(
|
.await;
|
||||||
serde_json::to_string(&tc.function.arguments)
|
|
||||||
.unwrap_or_else(|_| "{}".to_string()),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
tool_calls = Some(spec_tool_calls);
|
|
||||||
processed_text = normal_text;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!("Tool call parsing error: {}", e);
|
|
||||||
// Continue without tool calls rather than failing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user