router-grpc: Support jinja chat template content format detection (#10832)
This commit is contained in:
@@ -57,7 +57,7 @@ tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
tokenizers = { version = "0.22.0" }
|
||||
tiktoken-rs = { version = "0.7.0" }
|
||||
minijinja = { version = "2.0" }
|
||||
minijinja = { version = "2.0", features = ["unstable_machinery"] }
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
hf-hub = { version = "0.4.3", features = ["tokio"] }
|
||||
rmcp = { version = "0.6.3", features = ["client", "server",
|
||||
|
||||
@@ -1,5 +1,18 @@
|
||||
// gRPC Router Implementation
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
|
||||
@@ -7,27 +20,16 @@ use crate::core::{
|
||||
use crate::grpc::{proto, SglangSchedulerClient};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, ContentPart, ResponseFormat, StringOrArray,
|
||||
UserMessageContent,
|
||||
};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat, StringOrArray};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::tokenizer::{chat_template::ChatMessage as TokenizerChatMessage, traits::Tokenizer};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
|
||||
use serde_json::Value;
|
||||
|
||||
// Data structures for processing
|
||||
#[derive(Debug)]
|
||||
pub struct ProcessedMessages {
|
||||
@@ -290,16 +292,19 @@ impl GrpcRouter {
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> Result<ProcessedMessages, String> {
|
||||
let tokenizer_messages = self.convert_messages_for_tokenizer(&request.messages)?;
|
||||
|
||||
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
|
||||
let formatted_text = if let Some(hf_tokenizer) =
|
||||
self.tokenizer
|
||||
.as_any()
|
||||
.downcast_ref::<crate::tokenizer::HuggingFaceTokenizer>()
|
||||
{
|
||||
// Get content format and transform messages accordingly
|
||||
let content_format = hf_tokenizer.chat_template_content_format();
|
||||
let transformed_messages =
|
||||
Self::transform_messages_for_content_format(&request.messages, content_format)?;
|
||||
|
||||
hf_tokenizer
|
||||
.apply_chat_template(&tokenizer_messages, true)
|
||||
.apply_chat_template(&transformed_messages, true)
|
||||
.map_err(|e| format!("Failed to apply chat template: {}", e))?
|
||||
} else {
|
||||
return Err(
|
||||
@@ -317,46 +322,76 @@ impl GrpcRouter {
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert spec ChatMessage enum to tokenizer ChatMessage struct
|
||||
fn convert_messages_for_tokenizer(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
) -> Result<Vec<TokenizerChatMessage>, String> {
|
||||
let mut converted = Vec::new();
|
||||
/// Transform messages based on content format for ANY message type
|
||||
fn transform_messages_for_content_format(
|
||||
messages: &[crate::protocols::spec::ChatMessage],
|
||||
content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat,
|
||||
) -> Result<Vec<serde_json::Value>, String> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
let mut message_json = serde_json::to_value(message)
|
||||
.map_err(|e| format!("Failed to serialize message: {}", e))?;
|
||||
|
||||
for message in messages {
|
||||
let tokenizer_msg = match message {
|
||||
ChatMessage::System { content, .. } => TokenizerChatMessage::new("system", content),
|
||||
ChatMessage::User { content, .. } => {
|
||||
let text_content = match content {
|
||||
UserMessageContent::Text(text) => text.clone(),
|
||||
UserMessageContent::Parts(parts) => {
|
||||
// Simple text extraction for now - multimodal is placeholder
|
||||
parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.as_str()),
|
||||
ContentPart::ImageUrl { .. } => None, // Skip images for now
|
||||
})
|
||||
.collect::<Vec<&str>>()
|
||||
.join(" ")
|
||||
}
|
||||
};
|
||||
TokenizerChatMessage::new("user", text_content)
|
||||
if let Some(obj) = message_json.as_object_mut() {
|
||||
if let Some(content_value) = obj.get_mut("content") {
|
||||
Self::transform_content_field(content_value, content_format);
|
||||
}
|
||||
}
|
||||
ChatMessage::Assistant { content, .. } => {
|
||||
// Simple content extraction - no special tool/reasoning formatting
|
||||
TokenizerChatMessage::new("assistant", content.as_deref().unwrap_or(""))
|
||||
|
||||
Ok(message_json)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Transform a single content field based on content format
|
||||
fn transform_content_field(
|
||||
content_value: &mut Value,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) {
|
||||
let Some(content_array) = content_value.as_array() else {
|
||||
return; // Not multimodal, keep as-is
|
||||
};
|
||||
|
||||
match content_format {
|
||||
ChatTemplateContentFormat::String => {
|
||||
// Extract and join text parts only
|
||||
let text_parts: Vec<String> = content_array
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
part.as_object()?
|
||||
.get("type")?
|
||||
.as_str()
|
||||
.filter(|&t| t == "text")
|
||||
.and_then(|_| part.as_object()?.get("text")?.as_str())
|
||||
.map(String::from)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !text_parts.is_empty() {
|
||||
*content_value = Value::String(text_parts.join(" "));
|
||||
}
|
||||
ChatMessage::Tool { content, .. } => TokenizerChatMessage::new("tool", content),
|
||||
ChatMessage::Function { content, .. } => {
|
||||
TokenizerChatMessage::new("function", content)
|
||||
}
|
||||
};
|
||||
converted.push(tokenizer_msg);
|
||||
}
|
||||
ChatTemplateContentFormat::OpenAI => {
|
||||
// Replace media URLs with simple type placeholders
|
||||
let processed_parts: Vec<Value> = content_array
|
||||
.iter()
|
||||
.map(|part| {
|
||||
part.as_object()
|
||||
.and_then(|obj| obj.get("type")?.as_str())
|
||||
.and_then(|type_str| match type_str {
|
||||
"image_url" => Some(serde_json::json!({"type": "image"})),
|
||||
"video_url" => Some(serde_json::json!({"type": "video"})),
|
||||
"audio_url" => Some(serde_json::json!({"type": "audio"})),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_else(|| part.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
*content_value = Value::Array(processed_parts);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
/// Build gRPC SamplingParams from OpenAI request
|
||||
@@ -636,3 +671,260 @@ impl RouterTrait for GrpcRouter {
|
||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent};
|
||||
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_string_format() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
ContentPart::Text {
|
||||
text: "World".to_string(),
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
|
||||
// Should flatten multimodal content to text only
|
||||
assert_eq!(
|
||||
transformed_message["content"].as_str().unwrap(),
|
||||
"Hello World"
|
||||
);
|
||||
assert_eq!(transformed_message["role"].as_str().unwrap(), "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_openai_format() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "Describe this image:".to_string(),
|
||||
},
|
||||
ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: Some("high".to_string()),
|
||||
},
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::OpenAI,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
|
||||
// Should replace media URLs with simple type placeholders
|
||||
let content_array = transformed_message["content"].as_array().unwrap();
|
||||
assert_eq!(content_array.len(), 2);
|
||||
|
||||
// Text part should remain unchanged
|
||||
assert_eq!(content_array[0]["type"], "text");
|
||||
assert_eq!(content_array[0]["text"], "Describe this image:");
|
||||
|
||||
// Image part should be replaced with simple type placeholder
|
||||
assert_eq!(content_array[1], json!({"type": "image"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_simple_string_content() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Simple text message".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
|
||||
// Simple string content should remain unchanged
|
||||
assert_eq!(
|
||||
transformed_message["content"].as_str().unwrap(),
|
||||
"Simple text message"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_assistant_message() {
|
||||
let messages = vec![ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Assistant response".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
|
||||
assert_eq!(transformed_message["role"].as_str().unwrap(), "assistant");
|
||||
assert_eq!(
|
||||
transformed_message["content"].as_str().unwrap(),
|
||||
"Assistant response"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_multiple_messages() {
|
||||
let messages = vec![
|
||||
ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "System prompt".to_string(),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "User message".to_string(),
|
||||
},
|
||||
ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// System message should remain unchanged
|
||||
assert_eq!(result[0]["role"].as_str().unwrap(), "system");
|
||||
assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt");
|
||||
|
||||
// User message should be flattened to text only
|
||||
assert_eq!(result[1]["role"].as_str().unwrap(), "user");
|
||||
assert_eq!(result[1]["content"].as_str().unwrap(), "User message");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_empty_text_parts() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
}]),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let result = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
let transformed_message = &result[0];
|
||||
|
||||
// Should keep original multimodal content when no text parts exist
|
||||
assert!(transformed_message["content"].is_array());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_messages_mixed_content_types() {
|
||||
// Test with both text and multimodal content
|
||||
let messages = vec![
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Plain text".to_string()),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "With image".to_string(),
|
||||
},
|
||||
ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: Some("low".to_string()),
|
||||
},
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Test String format
|
||||
let result_string = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::String,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result_string.len(), 2);
|
||||
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
||||
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
||||
|
||||
// Test OpenAI format
|
||||
let result_openai = GrpcRouter::transform_messages_for_content_format(
|
||||
&messages,
|
||||
ChatTemplateContentFormat::OpenAI,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result_openai.len(), 2);
|
||||
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
|
||||
|
||||
let content_array = result_openai[1]["content"].as_array().unwrap();
|
||||
assert_eq!(content_array.len(), 2);
|
||||
assert_eq!(content_array[0]["type"], "text");
|
||||
assert_eq!(content_array[1], json!({"type": "image"}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,39 +4,291 @@
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use minijinja::{context, machinery, Environment, Value};
|
||||
use serde_json;
|
||||
|
||||
/// Represents a chat message with role and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
/// Chat template content format
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ChatTemplateContentFormat {
|
||||
/// Content is a simple string
|
||||
String,
|
||||
/// Content is a list of structured parts (OpenAI format)
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
impl Default for ChatTemplateContentFormat {
|
||||
fn default() -> Self {
|
||||
Self::String
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ChatTemplateContentFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::String => write!(f, "string"),
|
||||
Self::OpenAI => write!(f, "openai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
/// Detect the content format expected by a Jinja2 chat template
|
||||
///
|
||||
/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
|
||||
/// which uses AST parsing to look for content iteration patterns.
|
||||
///
|
||||
/// Returns:
|
||||
/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
|
||||
/// - ChatTemplateContentFormat::String if template expects simple string content
|
||||
pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
|
||||
// Use AST-based detection (enabled by default)
|
||||
if let Some(format) = detect_format_with_ast(template) {
|
||||
return format;
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
// Default to string format if AST parsing fails
|
||||
ChatTemplateContentFormat::String
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
/// AST-based detection using minijinja's unstable machinery
|
||||
/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions
|
||||
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
|
||||
use minijinja::machinery::{parse, WhitespaceConfig};
|
||||
use minijinja::syntax::SyntaxConfig;
|
||||
|
||||
// Parse the template into AST
|
||||
let ast = match parse(
|
||||
template,
|
||||
"template",
|
||||
SyntaxConfig {},
|
||||
WhitespaceConfig::default(),
|
||||
) {
|
||||
Ok(ast) => ast,
|
||||
Err(_) => return Some(ChatTemplateContentFormat::String),
|
||||
};
|
||||
|
||||
// Traverse AST looking for patterns that indicate OpenAI format
|
||||
let has_iteration = find_content_iteration_in_ast(&ast);
|
||||
let has_structure_checks = find_content_structure_checks_in_ast(&ast);
|
||||
let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast);
|
||||
|
||||
if has_iteration || has_structure_checks || has_assignment_patterns {
|
||||
Some(ChatTemplateContentFormat::OpenAI)
|
||||
} else {
|
||||
Some(ChatTemplateContentFormat::String)
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2
|
||||
/// Find content iteration patterns in AST
|
||||
/// Implements the same logic as SGLang's AST traversal
|
||||
fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => {
|
||||
// Recursively check all children
|
||||
template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_content_iteration_in_ast(child))
|
||||
}
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if this for-loop iterates over message content
|
||||
is_var_or_elems_access(&for_loop.iter, "message", "content") ||
|
||||
is_var_or_elems_access(&for_loop.iter, "msg", "content") ||
|
||||
is_var_or_elems_access(&for_loop.iter, "m", "content") ||
|
||||
// Also check the body for nested loops
|
||||
for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
}
|
||||
Stmt::IfCond(if_cond) => {
|
||||
// Check true and false branches
|
||||
if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_iteration_in_ast(stmt))
|
||||
}
|
||||
_ => false, // Other statement types don't contain loops
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression accesses varname['key'] or varname.key
|
||||
/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes
|
||||
fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool {
|
||||
use machinery::ast::Expr;
|
||||
|
||||
match expr {
|
||||
// Check for attribute access: varname.key
|
||||
Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key,
|
||||
// Check for item access: varname['key'] or varname["key"]
|
||||
Expr::GetItem(getitem) => {
|
||||
is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key)
|
||||
}
|
||||
// Handle filters and tests that might wrap the access
|
||||
Expr::Filter(filter) => {
|
||||
if let Some(ref expr) = filter.expr {
|
||||
is_var_or_elems_access(expr, varname, key)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression is a variable access (like {{ varname }})
|
||||
/// Implements SGLang's _is_var_access logic
|
||||
fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Var(var) if var.id == varname)
|
||||
}
|
||||
|
||||
/// Check if expression is a constant string with the given value
|
||||
fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Const(const_expr)
|
||||
if const_expr.value.as_str() == Some(value))
|
||||
}
|
||||
|
||||
/// Find content structure checks in AST (like content[0], content|length)
|
||||
fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_content_structure_checks_in_ast(child)),
|
||||
Stmt::ForLoop(for_loop) => for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt)),
|
||||
Stmt::IfCond(if_cond) => {
|
||||
// Check if condition has content structure checks
|
||||
has_content_structure_check_expr(&if_cond.expr)
|
||||
|| if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|
||||
}
|
||||
Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find variable assignment patterns like set content = message['content']
|
||||
fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match ast {
|
||||
Stmt::Template(template) => template
|
||||
.children
|
||||
.iter()
|
||||
.any(|child| find_variable_assignment_patterns_in_ast(child)),
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if this for-loop body contains both assignment and iteration
|
||||
let has_assignment = for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| is_content_assignment_stmt(stmt));
|
||||
let has_iteration = for_loop.body.iter().any(|stmt| {
|
||||
is_content_variable_iteration(stmt)
|
||||
|| matches!(stmt, Stmt::IfCond(if_cond) if
|
||||
if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) ||
|
||||
if_cond.false_body.iter().any(|s| is_content_variable_iteration(s))
|
||||
)
|
||||
});
|
||||
|
||||
(has_assignment && has_iteration)
|
||||
|| for_loop
|
||||
.body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
}
|
||||
Stmt::IfCond(if_cond) => {
|
||||
if_cond
|
||||
.true_body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
|| if_cond
|
||||
.false_body
|
||||
.iter()
|
||||
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression has content structure checks (index access, length, etc.)
|
||||
fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool {
|
||||
use machinery::ast::Expr;
|
||||
|
||||
match expr {
|
||||
// Check for content[0] - index access
|
||||
Expr::GetItem(getitem) => {
|
||||
is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr)
|
||||
}
|
||||
// Check for content|length - filter with length
|
||||
Expr::Filter(filter) => {
|
||||
if let Some(ref filter_expr) = filter.expr {
|
||||
is_content_access(filter_expr) && filter.name == "length"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
// Check for content is sequence/iterable
|
||||
Expr::Test(test) => {
|
||||
is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if statement assigns message content to a variable
|
||||
fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::Stmt;
|
||||
|
||||
match stmt {
|
||||
Stmt::Set(set_stmt) => {
|
||||
// Check if this is setting content = message['content']
|
||||
is_var_access(&set_stmt.target, "content")
|
||||
&& is_var_or_elems_access(&set_stmt.expr, "message", "content")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if statement iterates over content variable
|
||||
fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool {
|
||||
use machinery::ast::{Expr, Stmt};
|
||||
|
||||
match stmt {
|
||||
Stmt::ForLoop(for_loop) => {
|
||||
// Check if iterating over a variable named "content"
|
||||
matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if expression accesses content (message.content, message['content'], etc.)
|
||||
fn is_content_access(expr: &machinery::ast::Expr) -> bool {
|
||||
is_var_or_elems_access(expr, "message", "content")
|
||||
|| is_var_or_elems_access(expr, "msg", "content")
|
||||
|| is_var_or_elems_access(expr, "m", "content")
|
||||
}
|
||||
|
||||
/// Check if expression is a numeric constant (for index access)
|
||||
fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
|
||||
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
|
||||
pub struct ChatTemplateProcessor {
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
@@ -57,9 +309,10 @@ impl ChatTemplateProcessor {
|
||||
///
|
||||
/// This mimics the behavior of HuggingFace's apply_chat_template method
|
||||
/// but returns the formatted string instead of token IDs.
|
||||
/// Messages should be pre-processed into the format expected by the template.
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
@@ -73,21 +326,13 @@ impl ChatTemplateProcessor {
|
||||
.get_template("chat")
|
||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||
|
||||
// Convert messages to a format Jinja can work with
|
||||
let messages_value: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
context! {
|
||||
role => msg.role.clone(),
|
||||
content => msg.content.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
|
||||
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
|
||||
|
||||
// Render the template
|
||||
// Render the template directly with the provided values
|
||||
let rendered = tmpl
|
||||
.render(context! {
|
||||
messages => messages_value,
|
||||
messages => minijinja_messages,
|
||||
add_generation_prompt => add_generation_prompt,
|
||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
||||
@@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
// Simple template that formats messages
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{Error, Result};
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
@@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer {
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
chat_template: Option<String>,
|
||||
/// Detected chat template content format (computed once at initialization)
|
||||
content_format: ChatTemplateContentFormat,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
@@ -49,12 +54,20 @@ impl HuggingFaceTokenizer {
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
// Detect content format once at initialization
|
||||
let content_format = if let Some(ref template) = chat_template {
|
||||
detect_chat_template_content_format(template)
|
||||
} else {
|
||||
ChatTemplateContentFormat::String // Default if no template
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template,
|
||||
content_format,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,6 +86,7 @@ impl HuggingFaceTokenizer {
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
chat_template: None,
|
||||
content_format: ChatTemplateContentFormat::String, // Default
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,13 +149,22 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
/// Set or override the chat template
|
||||
pub fn set_chat_template(&mut self, template: String) {
|
||||
// Detect format for the new template
|
||||
self.content_format = detect_chat_template_content_format(&template);
|
||||
self.chat_template = Some(template);
|
||||
}
|
||||
|
||||
/// Get the content format expected by the chat template
|
||||
pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
|
||||
self.content_format
|
||||
}
|
||||
|
||||
/// Apply chat template if available
|
||||
///
|
||||
/// Takes transformed JSON Values (already transformed based on content format)
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
messages: &[serde_json::Value],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
if let Some(ref template) = self.chat_template {
|
||||
@@ -150,17 +173,15 @@ impl HuggingFaceTokenizer {
|
||||
self.special_tokens.bos_token.clone(),
|
||||
self.special_tokens.eos_token.clone(),
|
||||
);
|
||||
|
||||
processor.apply_chat_template(messages, add_generation_prompt)
|
||||
} else {
|
||||
// Fallback to simple formatting if no template is available
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
Err(Error::msg(
|
||||
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
|
||||
argument was passed! For information about writing templates and setting the \
|
||||
tokenizer.chat_template attribute, please see the documentation at \
|
||||
https://huggingface.co/docs/transformers/main/en/chat_templating"
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ChatMessage;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
// Note: Actual tokenizer tests would require a real tokenizer file
|
||||
// These would be integration tests rather than unit tests
|
||||
}
|
||||
|
||||
@@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
|
||||
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
pub use chat_template::ChatMessage;
|
||||
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
|
||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||
|
||||
238
sgl-router/tests/chat_template_format_detection.rs
Normal file
238
sgl-router/tests/chat_template_format_detection.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_detect_string_format_deepseek() {
|
||||
// DeepSeek style template - expects string content
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
User: {{ message['content'] }}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
Assistant: {{ message['content'] }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_llama4() {
|
||||
// Llama4 style template - expects structured content
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message['content'] is iterable %}
|
||||
{%- for content in message['content'] %}
|
||||
{%- if content['type'] == 'text' %}
|
||||
{{ content['text'] }}
|
||||
{%- elif content['type'] == 'image' %}
|
||||
<image>
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- else %}
|
||||
{{ message['content'] }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_dot_notation() {
|
||||
// Template using dot notation
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- for part in message.content %}
|
||||
{%- if part.type == 'text' %}
|
||||
{{ part.text }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_variable_assignment() {
|
||||
// Template that assigns content to variable then iterates
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- set content = message['content'] %}
|
||||
{%- if content is sequence %}
|
||||
{%- for item in content %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_glm4v_style() {
|
||||
// GLM4V uses 'msg' instead of 'message'
|
||||
let template = r#"
|
||||
{%- for msg in messages %}
|
||||
{%- for part in msg.content %}
|
||||
{%- if part.type == 'text' %}{{ part.text }}{%- endif %}
|
||||
{%- if part.type == 'image' %}<image>{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_with_length_check() {
|
||||
// Template that checks content length
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message.content|length > 0 %}
|
||||
{%- for item in message.content %}
|
||||
{{ item.text }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_with_index_access() {
|
||||
// Template that accesses content by index
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message.content[0] %}
|
||||
First item: {{ message.content[0].text }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_template_defaults_to_string() {
|
||||
let template = "Not a valid {% jinja template";
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_template_defaults_to_string() {
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(""),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template_unit_test() {
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt %}
|
||||
assistant:
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are helpful".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, true)
|
||||
.unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens_unit_test() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, false)
|
||||
.unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
314
sgl-router/tests/chat_template_integration.rs
Normal file
314
sgl-router/tests/chat_template_integration.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
<|{{ message.role }}|>{{ message.content }}<|end|>
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt %}
|
||||
<|assistant|>
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, true)
|
||||
.unwrap();
|
||||
assert!(result.contains("<|user|>Test<|end|>"));
|
||||
assert!(result.contains("<|assistant|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, false)
|
||||
.unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_style_template() {
|
||||
// Test a Llama-style chat template
|
||||
let template = r#"
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("What is 2+2?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||
assert!(result.contains("You are a helpful assistant"));
|
||||
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||
assert!(result.contains("What is 2+2?"));
|
||||
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chatml_template() {
|
||||
// Test a ChatML-style template
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("How are you?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_without_generation_prompt() {
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_messages_template() {
|
||||
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages: Vec<serde_json::Value> = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_format_detection() {
|
||||
// Test string format detection
|
||||
let string_template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{%- endfor -%}
|
||||
"#;
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(string_template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
|
||||
// Test OpenAI format detection
|
||||
let openai_template = r#"
|
||||
{%- for message in messages -%}
|
||||
{%- for content in message.content -%}
|
||||
{{ content.type }}: {{ content.text }}
|
||||
{%- endfor -%}
|
||||
{%- endfor -%}
|
||||
"#;
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(openai_template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_with_multimodal_content() {
|
||||
// Test that multimodal messages work correctly when serialized to JSON
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{ message.role }}:
|
||||
{%- if message.content is string %}
|
||||
{{ message.content }}
|
||||
{%- else %}
|
||||
{%- for part in message.content %}
|
||||
{%- if part.type == "text" %}
|
||||
{{ part.text }}
|
||||
{%- elif part.type == "image_url" %}
|
||||
[IMAGE]
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% endfor %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Parts(vec![
|
||||
spec::ContentPart::Text {
|
||||
text: "Look at this:".to_string(),
|
||||
},
|
||||
spec::ContentPart::ImageUrl {
|
||||
image_url: spec::ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
// Should contain both text and image parts
|
||||
assert!(result.contains("user:"));
|
||||
assert!(result.contains("Look at this:"));
|
||||
assert!(result.contains("[IMAGE]"));
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_load_chat_template_from_file() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let template_path = temp_dir.path().join("template.jinja");
|
||||
@@ -59,11 +58,28 @@ mod tests {
|
||||
|
||||
// Test that the custom template is used
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there"),
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
];
|
||||
|
||||
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
|
||||
// Convert to JSON values like the router does
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Verify the custom template format
|
||||
assert!(result.contains("<|user|>Hello"));
|
||||
@@ -73,9 +89,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_override_existing_template() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
@@ -124,8 +137,21 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
// Should use CUSTOM template, not built-in
|
||||
assert!(result.starts_with("CUSTOM:"));
|
||||
@@ -135,9 +161,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_set_chat_template_after_creation() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory and tokenizer file
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tokenizer_json = r#"{
|
||||
@@ -173,8 +196,31 @@ mod tests {
|
||||
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
|
||||
tokenizer.set_chat_template(new_template.to_string());
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("World".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("NEW:"));
|
||||
assert!(result.contains("user: Hello;"));
|
||||
@@ -1,150 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_helpers() {
|
||||
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(system_msg.role, "system");
|
||||
assert_eq!(system_msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
assert_eq!(user_msg.content, "Hello!");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
assert_eq!(assistant_msg.content, "Hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_style_template() {
|
||||
// Test a Llama-style chat template
|
||||
let template = r#"
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant"),
|
||||
ChatMessage::user("What is 2+2?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||
assert!(result.contains("You are a helpful assistant"));
|
||||
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||
assert!(result.contains("What is 2+2?"));
|
||||
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chatml_template() {
|
||||
// Test a ChatML-style template
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
ChatMessage::user("How are you?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_without_generation_prompt() {
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_with_special_tokens() {
|
||||
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "<s>Hello</s>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_messages() {
|
||||
let template =
|
||||
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
// Integration test with actual tokenizer file loading would go here
|
||||
// but requires a real tokenizer_config.json file
|
||||
}
|
||||
Reference in New Issue
Block a user