router-grpc: Add tools processing and other parameters for apply_chat_template (#10877)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::{context, machinery, Environment, Value};
|
||||
use serde_json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Chat template content format
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -288,21 +289,25 @@ fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
|
||||
}
|
||||
|
||||
/// Parameters for chat template application.
///
/// Bundles the optional inputs accepted by `apply_chat_template` so callers
/// pass one value instead of a growing list of positional arguments.
/// The derived `Default` yields both flags `false` and every option `None`.
|
||||
#[derive(Default)]
|
||||
pub struct ChatTemplateParams<'a> {
|
||||
/// Exposed to the template as `add_generation_prompt`.
/// Rejected by `apply_chat_template` when combined with `continue_final_message`.
pub add_generation_prompt: bool,
|
||||
/// Requests that the model continue the final message; `apply_chat_template`
/// returns an error if this is set together with `add_generation_prompt`.
pub continue_final_message: bool,
|
||||
/// Optional tool definitions, exposed to the template as `tools`.
pub tools: Option<&'a [serde_json::Value]>,
|
||||
/// Optional documents, exposed to the template as `documents`.
pub documents: Option<&'a [serde_json::Value]>,
|
||||
/// Optional extra key/value pairs merged into the template render context.
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
|
||||
pub struct ChatTemplateProcessor {
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatTemplateProcessor {
|
||||
/// Create a new chat template processor
|
||||
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
ChatTemplateProcessor {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
}
|
||||
pub fn new(template: String) -> Self {
|
||||
ChatTemplateProcessor { template }
|
||||
}
|
||||
|
||||
/// Apply the chat template to a list of messages
|
||||
@@ -313,8 +318,12 @@ impl ChatTemplateProcessor {
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
params: ChatTemplateParams,
|
||||
) -> Result<String> {
|
||||
// Validate incompatible options
|
||||
if params.continue_final_message && params.add_generation_prompt {
|
||||
return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."));
|
||||
}
|
||||
let mut env = Environment::new();
|
||||
|
||||
// Register the template
|
||||
@@ -326,17 +335,29 @@ impl ChatTemplateProcessor {
|
||||
.get_template("chat")
|
||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||
|
||||
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
|
||||
// Convert messages to minijinja::Value (messages already processed by router)
|
||||
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
|
||||
|
||||
// Render the template directly with the provided values
|
||||
let base_context = context! {
|
||||
messages => &minijinja_messages,
|
||||
add_generation_prompt => params.add_generation_prompt,
|
||||
tools => params.tools,
|
||||
documents => params.documents,
|
||||
};
|
||||
|
||||
// Merge with template_kwargs if provided
|
||||
let ctx = if let Some(kwargs) = params.template_kwargs {
|
||||
context! {
|
||||
..base_context,
|
||||
..Value::from_serialize(kwargs)
|
||||
}
|
||||
} else {
|
||||
base_context
|
||||
};
|
||||
|
||||
// Render the template
|
||||
let rendered = tmpl
|
||||
.render(context! {
|
||||
messages => minijinja_messages,
|
||||
add_generation_prompt => add_generation_prompt,
|
||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
||||
})
|
||||
.render(&ctx)
|
||||
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
|
||||
|
||||
Ok(rendered)
|
||||
|
||||
@@ -4,7 +4,8 @@ use anyhow::{Error, Result};
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||
ChatTemplateProcessor,
|
||||
};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
@@ -165,16 +166,11 @@ impl HuggingFaceTokenizer {
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
params: ChatTemplateParams,
|
||||
) -> Result<String> {
|
||||
if let Some(ref template) = self.chat_template {
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.clone(),
|
||||
self.special_tokens.bos_token.clone(),
|
||||
self.special_tokens.eos_token.clone(),
|
||||
);
|
||||
|
||||
processor.apply_chat_template(messages, add_generation_prompt)
|
||||
let processor = ChatTemplateProcessor::new(template.clone());
|
||||
processor.apply_chat_template(messages, params)
|
||||
} else {
|
||||
Err(Error::msg(
|
||||
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
|
||||
|
||||
Reference in New Issue
Block a user