router-grpc: Add tools processing and other parameters for apply_chat_template (#10877)
This commit is contained in:
@@ -27,7 +27,7 @@ use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
use serde_json::Value;
|
||||
|
||||
// Data structures for processing
|
||||
@@ -300,12 +300,87 @@ impl GrpcRouter {
|
||||
{
|
||||
// Get content format and transform messages accordingly
|
||||
let content_format = hf_tokenizer.chat_template_content_format();
|
||||
let transformed_messages =
|
||||
Self::transform_messages_for_content_format(&request.messages, content_format)?;
|
||||
let mut transformed_messages =
|
||||
Self::process_content_format(&request.messages, content_format)?;
|
||||
|
||||
hf_tokenizer
|
||||
.apply_chat_template(&transformed_messages, true)
|
||||
.map_err(|e| format!("Failed to apply chat template: {}", e))?
|
||||
// Process tool call arguments in assistant messages
|
||||
Self::process_tool_call_arguments(&mut transformed_messages)?;
|
||||
|
||||
// Convert tools to JSON values for template processing
|
||||
let tools_json: Option<Vec<serde_json::Value>> = request
|
||||
.tools
|
||||
.as_ref()
|
||||
.map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.map(serde_json::to_value)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()
|
||||
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
|
||||
|
||||
// Build template kwargs, merging reasoning_effort if present
|
||||
let mut combined_template_kwargs = std::collections::HashMap::new();
|
||||
|
||||
// Add reasoning_effort if present (like Python does)
|
||||
if let Some(reasoning_effort) = &request.reasoning_effort {
|
||||
combined_template_kwargs.insert(
|
||||
"reasoning_effort".to_string(),
|
||||
serde_json::Value::String(reasoning_effort.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
// Add any additional template kwargs from request
|
||||
if let Some(template_kwargs) = &request.chat_template_kwargs {
|
||||
for (key, value) in template_kwargs {
|
||||
combined_template_kwargs.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let final_template_kwargs = if combined_template_kwargs.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(&combined_template_kwargs)
|
||||
};
|
||||
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
continue_final_message: request.continue_final_message,
|
||||
tools: tools_json.as_deref(),
|
||||
template_kwargs: final_template_kwargs,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Handle assistant prefix for continue_final_message
|
||||
let assistant_prefix = if request.continue_final_message
|
||||
&& !transformed_messages.is_empty()
|
||||
&& transformed_messages
|
||||
.last()
|
||||
.and_then(|msg| msg.get("role"))
|
||||
.and_then(|v| v.as_str())
|
||||
== Some("assistant")
|
||||
{
|
||||
// Pop the last message to handle it separately
|
||||
let last_msg = transformed_messages.pop().unwrap();
|
||||
last_msg
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Apply chat template with the (now possibly shorter) list of messages
|
||||
let rendered = hf_tokenizer
|
||||
.apply_chat_template(&transformed_messages, params)
|
||||
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
|
||||
|
||||
// Append assistant prefix if we have one
|
||||
if let Some(prefix) = assistant_prefix {
|
||||
format!("{}{}", rendered, prefix)
|
||||
} else {
|
||||
rendered
|
||||
}
|
||||
} else {
|
||||
return Err(
|
||||
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
|
||||
@@ -322,8 +397,8 @@ impl GrpcRouter {
|
||||
})
|
||||
}
|
||||
|
||||
/// Transform messages based on content format for ANY message type
|
||||
fn transform_messages_for_content_format(
|
||||
/// Process messages based on content format for ANY message type
|
||||
fn process_content_format(
|
||||
messages: &[crate::protocols::spec::ChatMessage],
|
||||
content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat,
|
||||
) -> Result<Vec<serde_json::Value>, String> {
|
||||
@@ -394,6 +469,49 @@ impl GrpcRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Process tool call arguments in messages
|
||||
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
|
||||
fn process_tool_call_arguments(messages: &mut [serde_json::Value]) -> Result<(), String> {
|
||||
for msg in messages {
|
||||
// Early return if not assistant message
|
||||
let role = msg.get("role").and_then(|v| v.as_str());
|
||||
if role != Some("assistant") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Early return if no tool_calls
|
||||
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Process each tool call's arguments
|
||||
for call in tool_calls {
|
||||
let Some(function) = call.get_mut("function") else {
|
||||
continue;
|
||||
};
|
||||
let Some(args) = function.get_mut("arguments") else {
|
||||
continue;
|
||||
};
|
||||
let Some(args_str) = args.as_str() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Parse JSON string to object (like Python json.loads)
|
||||
match serde_json::from_str::<serde_json::Value>(args_str) {
|
||||
Ok(parsed) => *args = parsed,
|
||||
Err(e) => {
|
||||
return Err(format!(
|
||||
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
|
||||
args_str, e
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build gRPC SamplingParams from OpenAI request
|
||||
fn build_grpc_sampling_params(
|
||||
&self,
|
||||
@@ -410,6 +528,19 @@ impl GrpcRouter {
|
||||
.or(request.max_tokens)
|
||||
.map(|v| v as i32);
|
||||
|
||||
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
|
||||
let skip_special_tokens = if request.tools.is_some() {
|
||||
match &request.tool_choice {
|
||||
Some(crate::protocols::spec::ToolChoice::Value(
|
||||
crate::protocols::spec::ToolChoiceValue::None,
|
||||
)) => request.skip_special_tokens,
|
||||
Some(_) => false, // tool_choice is not "none"
|
||||
None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present
|
||||
}
|
||||
} else {
|
||||
request.skip_special_tokens
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
Ok(proto::SamplingParams {
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
@@ -422,7 +553,7 @@ impl GrpcRouter {
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
||||
skip_special_tokens: request.skip_special_tokens,
|
||||
skip_special_tokens,
|
||||
n: request.n.unwrap_or(1) as i32,
|
||||
structural_tag: structural_tag.unwrap_or_default(),
|
||||
constraint: self.build_constraint(request)?,
|
||||
@@ -700,11 +831,9 @@ mod tests {
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
@@ -735,11 +864,9 @@ mod tests {
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::OpenAI,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
@@ -764,11 +891,9 @@ mod tests {
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
@@ -791,11 +916,9 @@ mod tests {
|
||||
reasoning_content: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
@@ -832,11 +955,9 @@ mod tests {
|
||||
},
|
||||
];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
@@ -862,11 +983,9 @@ mod tests {
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
@@ -902,22 +1021,18 @@ mod tests {
|
||||
];
|
||||
|
||||
// Test String format
|
||||
let result_string = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
let result_string =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result_string.len(), 2);
|
||||
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
||||
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
||||
|
||||
// Test OpenAI format
|
||||
let result_openai = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::OpenAI,
|
||||
)
|
||||
.unwrap();
|
||||
let result_openai =
|
||||
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result_openai.len(), 2);
|
||||
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
|
||||
|
||||
Reference in New Issue
Block a user