router-grpc: Support jinja chat template content format detection (#10832)
This commit is contained in:
@@ -4,39 +4,291 @@
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use minijinja::{context, machinery, Environment, Value};
|
||||
use serde_json;
|
||||
|
||||
/// Represents a chat message with role and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
/// The shape a chat template expects for each message's `content` field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string.
    String,
    /// `content` is a list of structured parts (OpenAI format).
    OpenAI,
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
impl Default for ChatTemplateContentFormat {
|
||||
fn default() -> Self {
|
||||
Self::String
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ChatTemplateContentFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::String => write!(f, "string"),
|
||||
Self::OpenAI => write!(f, "openai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
/// Detect the content format expected by a Jinja2 chat template
|
||||
///
|
||||
/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
|
||||
/// which uses AST parsing to look for content iteration patterns.
|
||||
///
|
||||
/// Returns:
|
||||
/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
|
||||
/// - ChatTemplateContentFormat::String if template expects simple string content
|
||||
pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
|
||||
// Use AST-based detection (enabled by default)
|
||||
if let Some(format) = detect_format_with_ast(template) {
|
||||
return format;
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
// Default to string format if AST parsing fails
|
||||
ChatTemplateContentFormat::String
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
/// AST-based detection using minijinja's unstable machinery
|
||||
/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions
|
||||
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
||||
use minijinja::machinery::{parse, WhitespaceConfig};
|
||||
use minijinja::syntax::SyntaxConfig;
|
||||
|
||||
// Parse the template into AST
|
||||
let ast = match parse(
|
||||
template,
|
||||
"template",
|
||||
SyntaxConfig {},
|
||||
WhitespaceConfig::default(),
|
||||
) {
|
||||
Ok(ast) => ast,
|
||||
Err(_) => return Some(ChatTemplateContentFormat::String),
|
||||
};
|
||||
|
||||
// Traverse AST looking for patterns that indicate OpenAI format
|
||||
let has_iteration = find_content_iteration_in_ast(&ast);
|
||||
let has_structure_checks = find_content_structure_checks_in_ast(&ast);
|
||||
let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast);
|
||||
|
||||
if has_iteration || has_structure_checks || has_assignment_patterns {
|
||||
Some(ChatTemplateContentFormat::OpenAI)
|
||||
} else {
|
||||
Some(ChatTemplateContentFormat::String)
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2
|
||||
/// Find content iteration patterns in AST
|
||||
/// Implements the same logic as SGLang's AST traversal
|
||||
fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => {
|
||||
// Recursively check all children
|
||||
template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_content_iteration_in_ast(child))
|
||||
}
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if this for-loop iterates over message content
|
||||
is_var_or_elems_access(&for_loop.iter, "message", "content") ||
|
||||
is_var_or_elems_access(&for_loop.iter, "msg", "content") ||
|
||||
is_var_or_elems_access(&for_loop.iter, "m", "content") ||
|
||||
// Also check the body for nested loops
|
||||
for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
}
|
||||
Stmt::IfCond(if_cond) => {
|
||||
// Check true and false branches
|
||||
if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
}
|
||||
_ => false, // Other statement types don't contain loops
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression accesses varname['key'] or varname.key
|
||||
/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes
|
||||
fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool {
|
||||
use machinery::ast::Expr;
|
||||
|
||||
match expr {
|
||||
// Check for attribute access: varname.key
|
||||
Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key,
|
||||
// Check for item access: varname['key'] or varname["key"]
|
||||
Expr::GetItem(getitem) => {
|
||||
is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key)
|
||||
}
|
||||
// Handle filters and tests that might wrap the access
|
||||
Expr::Filter(filter) => {
|
||||
if let Some(ref expr) = filter.expr {
|
||||
is_var_or_elems_access(expr, varname, key)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression is a variable access (like {{ varname }})
|
||||
/// Implements SGLang's _is_var_access logic
|
||||
fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Var(var) if var.id == varname)
|
||||
}
|
||||
|
||||
/// Check if expression is a constant string with the given value
|
||||
fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Const(const_expr)
|
||||
if const_expr.value.as_str() == Some(value))
|
||||
}
|
||||
|
||||
/// Find content structure checks in AST (like content[0], content|length)
|
||||
fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_content_structure_checks_in_ast(child)),
|
||||
Stmt::ForLoop(for_loop) => for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt)),
|
||||
Stmt::IfCond(if_cond) => {
|
||||
// Check if condition has content structure checks
|
||||
has_content_structure_check_expr(&if_cond.expr)
|
||||
|| if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|
||||
}
|
||||
Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find variable assignment patterns like set content = message['content']
|
||||
fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_variable_assignment_patterns_in_ast(child)),
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if this for-loop body contains both assignment and iteration
|
||||
let has_assignment = for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| is_content_assignment_stmt(stmt));
|
||||
let has_iteration = for_loop.body.iter().any(|stmt| {
|
||||
is_content_variable_iteration(stmt)
|
||||
|| matches!(stmt, Stmt::IfCond(if_cond) if
|
||||
if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) ||
|
||||
if_cond.false_body.iter().any(|s| is_content_variable_iteration(s))
|
||||
)
|
||||
});
|
||||
|
||||
(has_assignment && has_iteration)
|
||||
|| for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
}
|
||||
Stmt::IfCond(if_cond) => {
|
||||
if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression has content structure checks (index access, length, etc.)
|
||||
fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool {
|
||||
use machinery::ast::Expr;
|
||||
|
||||
match expr {
|
||||
// Check for content[0] - index access
|
||||
Expr::GetItem(getitem) => {
|
||||
is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr)
|
||||
}
|
||||
// Check for content|length - filter with length
|
||||
Expr::Filter(filter) => {
|
||||
if let Some(ref filter_expr) = filter.expr {
|
||||
is_content_access(filter_expr) && filter.name == "length"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
// Check for content is sequence/iterable
|
||||
Expr::Test(test) => {
|
||||
is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if statement assigns message content to a variable
|
||||
fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match stmt {
|
||||
Stmt::Set(set_stmt) => {
|
||||
// Check if this is setting content = message['content']
|
||||
is_var_access(&set_stmt.target, "content")
|
||||
&& is_var_or_elems_access(&set_stmt.expr, "message", "content")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if statement iterates over content variable
|
||||
fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::{Expr, Stmt};
|
||||
|
||||
match stmt {
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if iterating over a variable named "content"
|
||||
matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression accesses content (message.content, message['content'], etc.)
|
||||
fn is_content_access(expr: &machinery::ast::Expr) -> bool {
|
||||
is_var_or_elems_access(expr, "message", "content")
|
||||
|| is_var_or_elems_access(expr, "msg", "content")
|
||||
|| is_var_or_elems_access(expr, "m", "content")
|
||||
}
|
||||
|
||||
/// Check if expression is a numeric constant (for index access)
|
||||
fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
|
||||
pub struct ChatTemplateProcessor {
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
@@ -57,9 +309,10 @@ impl ChatTemplateProcessor {
|
||||
///
|
||||
/// This mimics the behavior of HuggingFace's apply_chat_template method
|
||||
/// but returns the formatted string instead of token IDs.
|
||||
/// Messages should be pre-processed into the format expected by the template.
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
@@ -73,21 +326,13 @@ impl ChatTemplateProcessor {
|
||||
.get_template("chat")
|
||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||
|
||||
// Convert messages to a format Jinja can work with
|
||||
let messages_value: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
context! {
|
||||
role => msg.role.clone(),
|
||||
content => msg.content.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
|
||||
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
|
||||
|
||||
// Render the template
|
||||
// Render the template directly with the provided values
|
||||
let rendered = tmpl
|
||||
.render(context! {
|
||||
messages => messages_value,
|
||||
messages => minijinja_messages,
|
||||
add_generation_prompt => add_generation_prompt,
|
||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
||||
@@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
// Simple template that formats messages
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{Error, Result};
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
@@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer {
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
chat_template: Option<String>,
|
||||
/// Detected chat template content format (computed once at initialization)
|
||||
content_format: ChatTemplateContentFormat,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
@@ -49,12 +54,20 @@ impl HuggingFaceTokenizer {
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
// Detect content format once at initialization
|
||||
let content_format = if let Some(ref template) = chat_template {
|
||||
detect_chat_template_content_format(template)
|
||||
} else {
|
||||
ChatTemplateContentFormat::String // Default if no template
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template,
|
||||
content_format,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,6 +86,7 @@ impl HuggingFaceTokenizer {
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template: None,
|
||||
content_format: ChatTemplateContentFormat::String, // Default
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,13 +149,22 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
/// Set or override the chat template
|
||||
pub fn set_chat_template(&mut self, template: String) {
|
||||
// Detect format for the new template
|
||||
self.content_format = detect_chat_template_content_format(&template);
|
||||
self.chat_template = Some(template);
|
||||
}
|
||||
|
||||
/// Get the content format expected by the chat template
|
||||
pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
|
||||
self.content_format
|
||||
}
|
||||
|
||||
/// Apply chat template if available
|
||||
///
|
||||
/// Takes transformed JSON Values (already transformed based on content format)
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
if let Some(ref template) = self.chat_template {
|
||||
@@ -150,17 +173,15 @@ impl HuggingFaceTokenizer {
|
||||
self.special_tokens.bos_token.clone(),
|
||||
self.special_tokens.eos_token.clone(),
|
||||
);
|
||||
|
||||
processor.apply_chat_template(messages, add_generation_prompt)
|
||||
} else {
|
||||
// Fallback to simple formatting if no template is available
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
Err(Error::msg(
|
||||
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
|
||||
argument was passed! For information about writing templates and setting the \
|
||||
tokenizer.chat_template attribute, please see the documentation at \
|
||||
https://huggingface.co/docs/transformers/main/en/chat_templating"
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ChatMessage;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
// Note: Actual tokenizer tests would require a real tokenizer file
|
||||
// These would be integration tests rather than unit tests
|
||||
}
|
||||
|
||||
@@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
|
||||
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
pub use chat_template::ChatMessage;
|
||||
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
|
||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||
|
||||
Reference in New Issue
Block a user