router-grpc: Support jinja chat template content format detection (#10832)

This commit is contained in:
Chang Su
2025-09-24 11:45:01 -07:00
committed by GitHub
parent adba172fd1
commit 9209b209be
9 changed files with 1276 additions and 353 deletions

View File

@@ -4,39 +4,291 @@
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
use minijinja::{context, Environment, Value};
use serde::{Deserialize, Serialize};
use minijinja::{context, machinery, Environment, Value};
use serde_json;
/// Represents a chat message with role and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
/// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
/// Content is a simple string
String,
/// Content is a list of structured parts (OpenAI format)
OpenAI,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
impl Default for ChatTemplateContentFormat {
fn default() -> Self {
Self::String
}
}
impl std::fmt::Display for ChatTemplateContentFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::String => write!(f, "string"),
Self::OpenAI => write!(f, "openai"),
}
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
/// Detect the content format expected by a Jinja2 chat template
///
/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
/// which uses AST parsing to look for content iteration patterns.
///
/// Returns:
/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
/// - ChatTemplateContentFormat::String if template expects simple string content
pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
// Use AST-based detection (enabled by default)
if let Some(format) = detect_format_with_ast(template) {
return format;
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
// Default to string format if AST parsing fails
ChatTemplateContentFormat::String
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
/// AST-based detection using minijinja's unstable machinery
/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig};
use minijinja::syntax::SyntaxConfig;
// Parse the template into AST
let ast = match parse(
template,
"template",
SyntaxConfig {},
WhitespaceConfig::default(),
) {
Ok(ast) => ast,
Err(_) => return Some(ChatTemplateContentFormat::String),
};
// Traverse AST looking for patterns that indicate OpenAI format
let has_iteration = find_content_iteration_in_ast(&ast);
let has_structure_checks = find_content_structure_checks_in_ast(&ast);
let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast);
if has_iteration || has_structure_checks || has_assignment_patterns {
Some(ChatTemplateContentFormat::OpenAI)
} else {
Some(ChatTemplateContentFormat::String)
}
}
/// Chat template processor using Jinja2
/// Find content iteration patterns in AST
/// Implements the same logic as SGLang's AST traversal
fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match ast {
Stmt::Template(template) => {
// Recursively check all children
template
.children
.iter()
.any(|child| find_content_iteration_in_ast(child))
}
Stmt::ForLoop(for_loop) => {
// Check if this for-loop iterates over message content
is_var_or_elems_access(&for_loop.iter, "message", "content") ||
is_var_or_elems_access(&for_loop.iter, "msg", "content") ||
is_var_or_elems_access(&for_loop.iter, "m", "content") ||
// Also check the body for nested loops
for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt))
}
Stmt::IfCond(if_cond) => {
// Check true and false branches
if_cond
.true_body
.iter()
.any(|stmt| find_content_iteration_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_content_iteration_in_ast(stmt))
}
_ => false, // Other statement types don't contain loops
}
}
/// Check if expression accesses varname['key'] or varname.key
/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes
fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool {
use machinery::ast::Expr;
match expr {
// Check for attribute access: varname.key
Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key,
// Check for item access: varname['key'] or varname["key"]
Expr::GetItem(getitem) => {
is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key)
}
// Handle filters and tests that might wrap the access
Expr::Filter(filter) => {
if let Some(ref expr) = filter.expr {
is_var_or_elems_access(expr, varname, key)
} else {
false
}
}
Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key),
_ => false,
}
}
/// Check if expression is a variable access (like {{ varname }})
/// Implements SGLang's _is_var_access logic
fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool {
matches!(expr, machinery::ast::Expr::Var(var) if var.id == varname)
}
/// Check if expression is a constant string with the given value
fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool {
matches!(expr, machinery::ast::Expr::Const(const_expr)
if const_expr.value.as_str() == Some(value))
}
/// Find content structure checks in AST (like content[0], content|length)
fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match ast {
Stmt::Template(template) => template
.children
.iter()
.any(|child| find_content_structure_checks_in_ast(child)),
Stmt::ForLoop(for_loop) => for_loop
.body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt)),
Stmt::IfCond(if_cond) => {
// Check if condition has content structure checks
has_content_structure_check_expr(&if_cond.expr)
|| if_cond
.true_body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt))
}
Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr),
_ => false,
}
}
/// Find variable assignment patterns like set content = message['content']
fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match ast {
Stmt::Template(template) => template
.children
.iter()
.any(|child| find_variable_assignment_patterns_in_ast(child)),
Stmt::ForLoop(for_loop) => {
// Check if this for-loop body contains both assignment and iteration
let has_assignment = for_loop
.body
.iter()
.any(|stmt| is_content_assignment_stmt(stmt));
let has_iteration = for_loop.body.iter().any(|stmt| {
is_content_variable_iteration(stmt)
|| matches!(stmt, Stmt::IfCond(if_cond) if
if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) ||
if_cond.false_body.iter().any(|s| is_content_variable_iteration(s))
)
});
(has_assignment && has_iteration)
|| for_loop
.body
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
}
Stmt::IfCond(if_cond) => {
if_cond
.true_body
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
}
_ => false,
}
}
/// Check if expression has content structure checks (index access, length, etc.)
fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool {
use machinery::ast::Expr;
match expr {
// Check for content[0] - index access
Expr::GetItem(getitem) => {
is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr)
}
// Check for content|length - filter with length
Expr::Filter(filter) => {
if let Some(ref filter_expr) = filter.expr {
is_content_access(filter_expr) && filter.name == "length"
} else {
false
}
}
// Check for content is sequence/iterable
Expr::Test(test) => {
is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable")
}
_ => false,
}
}
/// Check if statement assigns message content to a variable
fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match stmt {
Stmt::Set(set_stmt) => {
// Check if this is setting content = message['content']
is_var_access(&set_stmt.target, "content")
&& is_var_or_elems_access(&set_stmt.expr, "message", "content")
}
_ => false,
}
}
/// Check if statement iterates over content variable
fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool {
use machinery::ast::{Expr, Stmt};
match stmt {
Stmt::ForLoop(for_loop) => {
// Check if iterating over a variable named "content"
matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content")
}
_ => false,
}
}
/// Check if expression accesses content (message.content, message['content'], etc.)
fn is_content_access(expr: &machinery::ast::Expr) -> bool {
is_var_or_elems_access(expr, "message", "content")
|| is_var_or_elems_access(expr, "msg", "content")
|| is_var_or_elems_access(expr, "m", "content")
}
/// Check if expression is a numeric constant (for index access)
fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
}
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
pub struct ChatTemplateProcessor {
template: String,
bos_token: Option<String>,
@@ -57,9 +309,10 @@ impl ChatTemplateProcessor {
///
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// but returns the formatted string instead of token IDs.
/// Messages should be pre-processed into the format expected by the template.
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
messages: &[serde_json::Value],
add_generation_prompt: bool,
) -> Result<String> {
let mut env = Environment::new();
@@ -73,21 +326,13 @@ impl ChatTemplateProcessor {
.get_template("chat")
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
// Convert messages to a format Jinja can work with
let messages_value: Vec<Value> = messages
.iter()
.map(|msg| {
context! {
role => msg.role.clone(),
content => msg.content.clone()
}
})
.collect();
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
// Render the template
// Render the template directly with the provided values
let rendered = tmpl
.render(context! {
messages => messages_value,
messages => minijinja_messages,
add_generation_prompt => add_generation_prompt,
bos_token => self.bos_token.clone().unwrap_or_default(),
eos_token => self.eos_token.clone().unwrap_or_default()
@@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String
Ok(None)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
#[test]
fn test_simple_chat_template() {
// Simple template that formats messages
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
assert!(result.contains("system: You are helpful"));
assert!(result.contains("user: Hello"));
assert!(result.contains("assistant:"));
}
#[test]
fn test_chat_template_with_tokens() {
// Template that uses special tokens
let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Test")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert!(result.contains("<s>"));
assert!(result.contains("</s>"));
}
}

View File

@@ -1,11 +1,14 @@
use std::collections::HashMap;
use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
};
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
@@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer {
vocab: HashMap<String, TokenIdType>,
reverse_vocab: HashMap<TokenIdType, String>,
chat_template: Option<String>,
/// Detected chat template content format (computed once at initialization)
content_format: ChatTemplateContentFormat,
}
impl HuggingFaceTokenizer {
@@ -49,12 +54,20 @@ impl HuggingFaceTokenizer {
Self::load_chat_template(file_path)
};
// Detect content format once at initialization
let content_format = if let Some(ref template) = chat_template {
detect_chat_template_content_format(template)
} else {
ChatTemplateContentFormat::String // Default if no template
};
Ok(HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
chat_template,
content_format,
})
}
@@ -73,6 +86,7 @@ impl HuggingFaceTokenizer {
vocab,
reverse_vocab,
chat_template: None,
content_format: ChatTemplateContentFormat::String, // Default
}
}
@@ -135,13 +149,22 @@ impl HuggingFaceTokenizer {
/// Set or override the chat template
pub fn set_chat_template(&mut self, template: String) {
// Detect format for the new template
self.content_format = detect_chat_template_content_format(&template);
self.chat_template = Some(template);
}
/// Get the content format expected by the chat template
pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
self.content_format
}
/// Apply chat template if available
///
/// Takes transformed JSON Values (already transformed based on content format)
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
messages: &[serde_json::Value],
add_generation_prompt: bool,
) -> Result<String> {
if let Some(ref template) = self.chat_template {
@@ -150,17 +173,15 @@ impl HuggingFaceTokenizer {
self.special_tokens.bos_token.clone(),
self.special_tokens.eos_token.clone(),
);
processor.apply_chat_template(messages, add_generation_prompt)
} else {
// Fallback to simple formatting if no template is available
let mut result = String::new();
for msg in messages {
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
if add_generation_prompt {
result.push_str("assistant: ");
}
Ok(result)
Err(Error::msg(
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
argument was passed! For information about writing templates and setting the \
tokenizer.chat_template attribute, please see the documentation at \
https://huggingface.co/docs/transformers/main/en/chat_templating"
))
}
}
}
@@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer {
#[cfg(test)]
mod tests {
use super::ChatMessage;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
// Note: Actual tokenizer tests would require a real tokenizer file
// These would be integration tests rather than unit tests
}

View File

@@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
pub use huggingface::HuggingFaceTokenizer;
pub use chat_template::ChatMessage;
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations