[router][grpc] Remove continue_final_message in ChatTemplateParams and add minijinja-contrib (#11882)
This commit is contained in:
@@ -64,6 +64,7 @@ anyhow = "1.0"
|
|||||||
tokenizers = { version = "0.22.0" }
|
tokenizers = { version = "0.22.0" }
|
||||||
tiktoken-rs = { version = "0.7.0" }
|
tiktoken-rs = { version = "0.7.0" }
|
||||||
minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins"] }
|
minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins"] }
|
||||||
|
minijinja-contrib = { version = "2.0", features = ["pycompat"] }
|
||||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||||
hf-hub = { version = "0.4.3", features = ["tokio"] }
|
hf-hub = { version = "0.4.3", features = ["tokio"] }
|
||||||
rmcp = { version = "0.6.3", features = ["client", "server",
|
rmcp = { version = "0.6.3", features = ["client", "server",
|
||||||
|
|||||||
@@ -382,7 +382,6 @@ pub fn process_chat_messages(
|
|||||||
|
|
||||||
let params = ChatTemplateParams {
|
let params = ChatTemplateParams {
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
continue_final_message: request.continue_final_message,
|
|
||||||
tools: tools_json.as_deref(),
|
tools: tools_json.as_deref(),
|
||||||
template_kwargs: final_template_kwargs,
|
template_kwargs: final_template_kwargs,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|||||||
@@ -3,12 +3,16 @@
|
|||||||
//! This module provides functionality to apply chat templates to messages,
|
//! This module provides functionality to apply chat templates to messages,
|
||||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fs};
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use minijinja::{
|
use minijinja::{
|
||||||
context,
|
context,
|
||||||
machinery::ast::{Expr, Stmt},
|
machinery::{
|
||||||
|
ast::{Expr, Stmt},
|
||||||
|
parse, WhitespaceConfig,
|
||||||
|
},
|
||||||
|
syntax::SyntaxConfig,
|
||||||
Environment, Value,
|
Environment, Value,
|
||||||
};
|
};
|
||||||
use serde_json;
|
use serde_json;
|
||||||
@@ -323,11 +327,6 @@ impl<'a> Detector<'a> {
|
|||||||
/// AST-based detection using minijinja's unstable machinery
|
/// AST-based detection using minijinja's unstable machinery
|
||||||
/// Single-pass detector with scope tracking
|
/// Single-pass detector with scope tracking
|
||||||
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
||||||
use minijinja::{
|
|
||||||
machinery::{parse, WhitespaceConfig},
|
|
||||||
syntax::SyntaxConfig,
|
|
||||||
};
|
|
||||||
|
|
||||||
let ast = match parse(
|
let ast = match parse(
|
||||||
template,
|
template,
|
||||||
"template",
|
"template",
|
||||||
@@ -350,7 +349,6 @@ fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
|||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct ChatTemplateParams<'a> {
|
pub struct ChatTemplateParams<'a> {
|
||||||
pub add_generation_prompt: bool,
|
pub add_generation_prompt: bool,
|
||||||
pub continue_final_message: bool,
|
|
||||||
pub tools: Option<&'a [serde_json::Value]>,
|
pub tools: Option<&'a [serde_json::Value]>,
|
||||||
pub documents: Option<&'a [serde_json::Value]>,
|
pub documents: Option<&'a [serde_json::Value]>,
|
||||||
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
|
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
|
||||||
@@ -377,16 +375,15 @@ impl ChatTemplateProcessor {
|
|||||||
messages: &[serde_json::Value],
|
messages: &[serde_json::Value],
|
||||||
params: ChatTemplateParams,
|
params: ChatTemplateParams,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
// Validate incompatible options
|
|
||||||
if params.continue_final_message && params.add_generation_prompt {
|
|
||||||
return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."));
|
|
||||||
}
|
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
|
|
||||||
// Register the template
|
// Register the template
|
||||||
env.add_template("chat", &self.template)
|
env.add_template("chat", &self.template)
|
||||||
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
|
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
|
||||||
|
|
||||||
|
// Enable Python method compatibility (e.g., str.startswith, str.endswith)
|
||||||
|
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
|
||||||
|
|
||||||
// Get the template
|
// Get the template
|
||||||
let tmpl = env
|
let tmpl = env
|
||||||
.get_template("chat")
|
.get_template("chat")
|
||||||
@@ -423,8 +420,6 @@ impl ChatTemplateProcessor {
|
|||||||
|
|
||||||
/// Load chat template from tokenizer config JSON
|
/// Load chat template from tokenizer config JSON
|
||||||
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
||||||
use std::fs;
|
|
||||||
|
|
||||||
let content = fs::read_to_string(config_path)?;
|
let content = fs::read_to_string(config_path)?;
|
||||||
let config: serde_json::Value = serde_json::from_str(&content)?;
|
let config: serde_json::Value = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user