From 9209b209be1ffed21300621c5645abb693ef3b46 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Wed, 24 Sep 2025 11:45:01 -0700 Subject: [PATCH] router-grpc: Support jinja chat template content format detection (#10832) --- sgl-router/Cargo.toml | 2 +- sgl-router/src/routers/grpc/router.rs | 402 +++++++++++++++--- sgl-router/src/tokenizer/chat_template.rs | 377 +++++++++++----- sgl-router/src/tokenizer/huggingface.rs | 66 +-- sgl-router/src/tokenizer/mod.rs | 2 - .../tests/chat_template_format_detection.rs | 238 +++++++++++ sgl-router/tests/chat_template_integration.rs | 314 ++++++++++++++ ...te_loading.rs => chat_template_loading.rs} | 78 +++- sgl-router/tests/test_chat_template.rs | 150 ------- 9 files changed, 1276 insertions(+), 353 deletions(-) create mode 100644 sgl-router/tests/chat_template_format_detection.rs create mode 100644 sgl-router/tests/chat_template_integration.rs rename sgl-router/tests/{test_chat_template_loading.rs => chat_template_loading.rs} (70%) delete mode 100644 sgl-router/tests/test_chat_template.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 9d24cbabc..5fc9c1a8d 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -57,7 +57,7 @@ tokio-stream = { version = "0.1", features = ["sync"] } anyhow = "1.0" tokenizers = { version = "0.22.0" } tiktoken-rs = { version = "0.7.0" } -minijinja = { version = "2.0" } +minijinja = { version = "2.0", features = ["unstable_machinery"] } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } hf-hub = { version = "0.4.3", features = ["tokio"] } rmcp = { version = "0.6.3", features = ["client", "server", diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index e0efd3e8c..529d40d16 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -1,5 +1,18 @@ // gRPC Router Implementation +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{ + body::Body, + extract::Request, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; +use tracing::{debug, error, info, warn}; + use crate::config::types::RetryConfig; use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, @@ -7,27 +20,16 @@ use crate::core::{ use crate::grpc::{proto, SglangSchedulerClient}; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; -use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, ContentPart, ResponseFormat, StringOrArray, - UserMessageContent, -}; +use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat, StringOrArray}; use crate::reasoning_parser::ParserFactory; use crate::routers::RouterTrait; -use crate::tokenizer::{chat_template::ChatMessage as TokenizerChatMessage, traits::Tokenizer}; +use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserRegistry; -use async_trait::async_trait; -use axum::{ - body::Body, - extract::Request, - http::{HeaderMap, StatusCode}, - response::{IntoResponse, Response}, -}; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; -use tracing::{debug, error, info, warn}; use uuid::Uuid; +use crate::tokenizer::chat_template::ChatTemplateContentFormat; +use serde_json::Value; + // Data structures for processing #[derive(Debug)] pub struct ProcessedMessages { @@ -290,16 +292,19 @@ impl GrpcRouter { &self, request: &ChatCompletionRequest, ) -> Result { - let tokenizer_messages = self.convert_messages_for_tokenizer(&request.messages)?; - // Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC let formatted_text = if let Some(hf_tokenizer) = self.tokenizer .as_any() .downcast_ref::() { + // Get content format and transform messages accordingly + let content_format = hf_tokenizer.chat_template_content_format(); + let transformed_messages = + Self::transform_messages_for_content_format(&request.messages, content_format)?; + hf_tokenizer - .apply_chat_template(&tokenizer_messages, true) + .apply_chat_template(&transformed_messages, true) .map_err(|e| format!("Failed to apply chat template: {}", e))? } else { return Err( @@ -317,46 +322,76 @@ impl GrpcRouter { }) } - /// Convert spec ChatMessage enum to tokenizer ChatMessage struct - fn convert_messages_for_tokenizer( - &self, - messages: &[ChatMessage], - ) -> Result, String> { - let mut converted = Vec::new(); + /// Transform messages based on content format for ANY message type + fn transform_messages_for_content_format( + messages: &[crate::protocols::spec::ChatMessage], + content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat, + ) -> Result, String> { + messages + .iter() + .map(|message| { + let mut message_json = serde_json::to_value(message) + .map_err(|e| format!("Failed to serialize message: {}", e))?; - for message in messages { - let tokenizer_msg = match message { - ChatMessage::System { content, .. } => TokenizerChatMessage::new("system", content), - ChatMessage::User { content, .. } => { - let text_content = match content { - UserMessageContent::Text(text) => text.clone(), - UserMessageContent::Parts(parts) => { - // Simple text extraction for now - multimodal is placeholder - parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.as_str()), - ContentPart::ImageUrl { .. } => None, // Skip images for now - }) - .collect::>() - .join(" ") - } - }; - TokenizerChatMessage::new("user", text_content) + if let Some(obj) = message_json.as_object_mut() { + if let Some(content_value) = obj.get_mut("content") { + Self::transform_content_field(content_value, content_format); + } } - ChatMessage::Assistant { content, .. } => { - // Simple content extraction - no special tool/reasoning formatting - TokenizerChatMessage::new("assistant", content.as_deref().unwrap_or("")) + + Ok(message_json) + }) + .collect() + } + + /// Transform a single content field based on content format + fn transform_content_field( + content_value: &mut Value, + content_format: ChatTemplateContentFormat, + ) { + let Some(content_array) = content_value.as_array() else { + return; // Not multimodal, keep as-is + }; + + match content_format { + ChatTemplateContentFormat::String => { + // Extract and join text parts only + let text_parts: Vec = content_array + .iter() + .filter_map(|part| { + part.as_object()? + .get("type")? + .as_str() + .filter(|&t| t == "text") + .and_then(|_| part.as_object()?.get("text")?.as_str()) + .map(String::from) + }) + .collect(); + + if !text_parts.is_empty() { + *content_value = Value::String(text_parts.join(" ")); } - ChatMessage::Tool { content, .. } => TokenizerChatMessage::new("tool", content), - ChatMessage::Function { content, .. } => { - TokenizerChatMessage::new("function", content) - } - }; - converted.push(tokenizer_msg); + } + ChatTemplateContentFormat::OpenAI => { + // Replace media URLs with simple type placeholders + let processed_parts: Vec = content_array + .iter() + .map(|part| { + part.as_object() + .and_then(|obj| obj.get("type")?.as_str()) + .and_then(|type_str| match type_str { + "image_url" => Some(serde_json::json!({"type": "image"})), + "video_url" => Some(serde_json::json!({"type": "video"})), + "audio_url" => Some(serde_json::json!({"type": "audio"})), + _ => None, + }) + .unwrap_or_else(|| part.clone()) + }) + .collect(); + + *content_value = Value::Array(processed_parts); + } } - - Ok(converted) } /// Build gRPC SamplingParams from OpenAI request @@ -636,3 +671,260 @@ impl RouterTrait for GrpcRouter { (StatusCode::SERVICE_UNAVAILABLE).into_response() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent}; + use crate::tokenizer::chat_template::ChatTemplateContentFormat; + use serde_json::json; + + #[test] + fn test_transform_messages_string_format() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "Hello".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: None, + }, + }, + ContentPart::Text { + text: "World".to_string(), + }, + ]), + name: None, + }]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + let transformed_message = &result[0]; + + // Should flatten multimodal content to text only + assert_eq!( + transformed_message["content"].as_str().unwrap(), + "Hello World" + ); + assert_eq!(transformed_message["role"].as_str().unwrap(), "user"); + } + + #[test] + fn test_transform_messages_openai_format() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "Describe this image:".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: Some("high".to_string()), + }, + }, + ]), + name: None, + }]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::OpenAI, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + let transformed_message = &result[0]; + + // Should replace media URLs with simple type placeholders + let content_array = transformed_message["content"].as_array().unwrap(); + assert_eq!(content_array.len(), 2); + + // Text part should remain unchanged + assert_eq!(content_array[0]["type"], "text"); + assert_eq!(content_array[0]["text"], "Describe this image:"); + + // Image part should be replaced with simple type placeholder + assert_eq!(content_array[1], json!({"type": "image"})); + } + + #[test] + fn test_transform_messages_simple_string_content() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Simple text message".to_string()), + name: None, + }]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + let transformed_message = &result[0]; + + // Simple string content should remain unchanged + assert_eq!( + transformed_message["content"].as_str().unwrap(), + "Simple text message" + ); + } + + #[test] + fn test_transform_messages_assistant_message() { + let messages = vec![ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("Assistant response".to_string()), + name: None, + tool_calls: None, + function_call: None, + reasoning_content: None, + }]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + let transformed_message = &result[0]; + + assert_eq!(transformed_message["role"].as_str().unwrap(), "assistant"); + assert_eq!( + transformed_message["content"].as_str().unwrap(), + "Assistant response" + ); + } + + #[test] + fn test_transform_messages_multiple_messages() { + let messages = vec![ + ChatMessage::System { + role: "system".to_string(), + content: "System prompt".to_string(), + name: None, + }, + ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "User message".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: None, + }, + }, + ]), + name: None, + }, + ]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result.len(), 2); + + // System message should remain unchanged + assert_eq!(result[0]["role"].as_str().unwrap(), "system"); + assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt"); + + // User message should be flattened to text only + assert_eq!(result[1]["role"].as_str().unwrap(), "user"); + assert_eq!(result[1]["content"].as_str().unwrap(), "User message"); + } + + #[test] + fn test_transform_messages_empty_text_parts() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: None, + }, + }]), + name: None, + }]; + + let result = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + let transformed_message = &result[0]; + + // Should keep original multimodal content when no text parts exist + assert!(transformed_message["content"].is_array()); + } + + #[test] + fn test_transform_messages_mixed_content_types() { + // Test with both text and multimodal content + let messages = vec![ + ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Plain text".to_string()), + name: None, + }, + ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "With image".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: Some("low".to_string()), + }, + }, + ]), + name: None, + }, + ]; + + // Test String format + let result_string = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::String, + ) + .unwrap(); + + assert_eq!(result_string.len(), 2); + assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text"); + assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image"); + + // Test OpenAI format + let result_openai = GrpcRouter::transform_messages_for_content_format( + &messages, + ChatTemplateContentFormat::OpenAI, + ) + .unwrap(); + + assert_eq!(result_openai.len(), 2); + assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text"); + + let content_array = result_openai[1]["content"].as_array().unwrap(); + assert_eq!(content_array.len(), 2); + assert_eq!(content_array[0]["type"], "text"); + assert_eq!(content_array[1], json!({"type": "image"})); + } +} diff --git a/sgl-router/src/tokenizer/chat_template.rs b/sgl-router/src/tokenizer/chat_template.rs index 8a9a0fe1d..798ede015 100644 --- a/sgl-router/src/tokenizer/chat_template.rs +++ b/sgl-router/src/tokenizer/chat_template.rs @@ -4,39 +4,291 @@ //! similar to HuggingFace transformers' apply_chat_template method. use anyhow::{anyhow, Result}; -use minijinja::{context, Environment, Value}; -use serde::{Deserialize, Serialize}; +use minijinja::{context, machinery, Environment, Value}; use serde_json; -/// Represents a chat message with role and content -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessage { - pub role: String, - pub content: String, +/// Chat template content format +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChatTemplateContentFormat { + /// Content is a simple string + String, + /// Content is a list of structured parts (OpenAI format) + OpenAI, } -impl ChatMessage { - pub fn new(role: impl Into, content: impl Into) -> Self { - ChatMessage { - role: role.into(), - content: content.into(), +impl Default for ChatTemplateContentFormat { + fn default() -> Self { + Self::String + } +} + +impl std::fmt::Display for ChatTemplateContentFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::String => write!(f, "string"), + Self::OpenAI => write!(f, "openai"), } } +} - pub fn system(content: impl Into) -> Self { - Self::new("system", content) +/// Detect the content format expected by a Jinja2 chat template +/// +/// This implements the same detection logic as SGLang's detect_jinja_template_content_format +/// which uses AST parsing to look for content iteration patterns. +/// +/// Returns: +/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts) +/// - ChatTemplateContentFormat::String if template expects simple string content +pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat { + // Use AST-based detection (enabled by default) + if let Some(format) = detect_format_with_ast(template) { + return format; } - pub fn user(content: impl Into) -> Self { - Self::new("user", content) - } + // Default to string format if AST parsing fails + ChatTemplateContentFormat::String +} - pub fn assistant(content: impl Into) -> Self { - Self::new("assistant", content) +/// AST-based detection using minijinja's unstable machinery +/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions +fn detect_format_with_ast(template: &str) -> Option { + use minijinja::machinery::{parse, WhitespaceConfig}; + use minijinja::syntax::SyntaxConfig; + + // Parse the template into AST + let ast = match parse( + template, + "template", + SyntaxConfig {}, + WhitespaceConfig::default(), + ) { + Ok(ast) => ast, + Err(_) => return Some(ChatTemplateContentFormat::String), + }; + + // Traverse AST looking for patterns that indicate OpenAI format + let has_iteration = find_content_iteration_in_ast(&ast); + let has_structure_checks = find_content_structure_checks_in_ast(&ast); + let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast); + + if has_iteration || has_structure_checks || has_assignment_patterns { + Some(ChatTemplateContentFormat::OpenAI) + } else { + Some(ChatTemplateContentFormat::String) } } -/// Chat template processor using Jinja2 +/// Find content iteration patterns in AST +/// Implements the same logic as SGLang's AST traversal +fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool { + use machinery::ast::Stmt; + + match ast { + Stmt::Template(template) => { + // Recursively check all children + template + .children + .iter() + .any(|child| find_content_iteration_in_ast(child)) + } + Stmt::ForLoop(for_loop) => { + // Check if this for-loop iterates over message content + is_var_or_elems_access(&for_loop.iter, "message", "content") || + is_var_or_elems_access(&for_loop.iter, "msg", "content") || + is_var_or_elems_access(&for_loop.iter, "m", "content") || + // Also check the body for nested loops + for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt)) + } + Stmt::IfCond(if_cond) => { + // Check true and false branches + if_cond + .true_body + .iter() + .any(|stmt| find_content_iteration_in_ast(stmt)) + || if_cond + .false_body + .iter() + .any(|stmt| find_content_iteration_in_ast(stmt)) + } + _ => false, // Other statement types don't contain loops + } +} + +/// Check if expression accesses varname['key'] or varname.key +/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes +fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool { + use machinery::ast::Expr; + + match expr { + // Check for attribute access: varname.key + Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key, + // Check for item access: varname['key'] or varname["key"] + Expr::GetItem(getitem) => { + is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key) + } + // Handle filters and tests that might wrap the access + Expr::Filter(filter) => { + if let Some(ref expr) = filter.expr { + is_var_or_elems_access(expr, varname, key) + } else { + false + } + } + Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key), + _ => false, + } +} + +/// Check if expression is a variable access (like {{ varname }}) +/// Implements SGLang's _is_var_access logic +fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool { + matches!(expr, machinery::ast::Expr::Var(var) if var.id == varname) +} + +/// Check if expression is a constant string with the given value +fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool { + matches!(expr, machinery::ast::Expr::Const(const_expr) + if const_expr.value.as_str() == Some(value)) +} + +/// Find content structure checks in AST (like content[0], content|length) +fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool { + use machinery::ast::Stmt; + + match ast { + Stmt::Template(template) => template + .children + .iter() + .any(|child| find_content_structure_checks_in_ast(child)), + Stmt::ForLoop(for_loop) => for_loop + .body + .iter() + .any(|stmt| find_content_structure_checks_in_ast(stmt)), + Stmt::IfCond(if_cond) => { + // Check if condition has content structure checks + has_content_structure_check_expr(&if_cond.expr) + || if_cond + .true_body + .iter() + .any(|stmt| find_content_structure_checks_in_ast(stmt)) + || if_cond + .false_body + .iter() + .any(|stmt| find_content_structure_checks_in_ast(stmt)) + } + Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr), + _ => false, + } +} + +/// Find variable assignment patterns like set content = message['content'] +fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool { + use machinery::ast::Stmt; + + match ast { + Stmt::Template(template) => template + .children + .iter() + .any(|child| find_variable_assignment_patterns_in_ast(child)), + Stmt::ForLoop(for_loop) => { + // Check if this for-loop body contains both assignment and iteration + let has_assignment = for_loop + .body + .iter() + .any(|stmt| is_content_assignment_stmt(stmt)); + let has_iteration = for_loop.body.iter().any(|stmt| { + is_content_variable_iteration(stmt) + || matches!(stmt, Stmt::IfCond(if_cond) if + if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) || + if_cond.false_body.iter().any(|s| is_content_variable_iteration(s)) + ) + }); + + (has_assignment && has_iteration) + || for_loop + .body + .iter() + .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) + } + Stmt::IfCond(if_cond) => { + if_cond + .true_body + .iter() + .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) + || if_cond + .false_body + .iter() + .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) + } + _ => false, + } +} + +/// Check if expression has content structure checks (index access, length, etc.) +fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool { + use machinery::ast::Expr; + + match expr { + // Check for content[0] - index access + Expr::GetItem(getitem) => { + is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr) + } + // Check for content|length - filter with length + Expr::Filter(filter) => { + if let Some(ref filter_expr) = filter.expr { + is_content_access(filter_expr) && filter.name == "length" + } else { + false + } + } + // Check for content is sequence/iterable + Expr::Test(test) => { + is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable") + } + _ => false, + } +} + +/// Check if statement assigns message content to a variable +fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool { + use machinery::ast::Stmt; + + match stmt { + Stmt::Set(set_stmt) => { + // Check if this is setting content = message['content'] + is_var_access(&set_stmt.target, "content") + && is_var_or_elems_access(&set_stmt.expr, "message", "content") + } + _ => false, + } +} + +/// Check if statement iterates over content variable +fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool { + use machinery::ast::{Expr, Stmt}; + + match stmt { + Stmt::ForLoop(for_loop) => { + // Check if iterating over a variable named "content" + matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content") + } + _ => false, + } +} + +/// Check if expression accesses content (message.content, message['content'], etc.) +fn is_content_access(expr: &machinery::ast::Expr) -> bool { + is_var_or_elems_access(expr, "message", "content") + || is_var_or_elems_access(expr, "msg", "content") + || is_var_or_elems_access(expr, "m", "content") +} + +/// Check if expression is a numeric constant (for index access) +fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool { + matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number()) +} + +/// Chat template processor using Jinja2 - simple wrapper like HuggingFace pub struct ChatTemplateProcessor { template: String, bos_token: Option, @@ -57,9 +309,10 @@ impl ChatTemplateProcessor { /// /// This mimics the behavior of HuggingFace's apply_chat_template method /// but returns the formatted string instead of token IDs. + /// Messages should be pre-processed into the format expected by the template. pub fn apply_chat_template( &self, - messages: &[ChatMessage], + messages: &[serde_json::Value], add_generation_prompt: bool, ) -> Result { let mut env = Environment::new(); @@ -73,21 +326,13 @@ impl ChatTemplateProcessor { .get_template("chat") .map_err(|e| anyhow!("Failed to get template: {}", e))?; - // Convert messages to a format Jinja can work with - let messages_value: Vec = messages - .iter() - .map(|msg| { - context! { - role => msg.role.clone(), - content => msg.content.clone() - } - }) - .collect(); + // Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump() + let minijinja_messages: Vec = messages.iter().map(Value::from_serialize).collect(); - // Render the template + // Render the template directly with the provided values let rendered = tmpl .render(context! { - messages => messages_value, + messages => minijinja_messages, add_generation_prompt => add_generation_prompt, bos_token => self.bos_token.clone().unwrap_or_default(), eos_token => self.eos_token.clone().unwrap_or_default() @@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result".to_string()), - Some("".to_string()), - ); - - let messages = vec![ChatMessage::user("Test")]; - - let result = processor.apply_chat_template(&messages, false).unwrap(); - assert!(result.contains("")); - assert!(result.contains("")); - } -} diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index 7cb930d18..f4d926621 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -1,11 +1,14 @@ +use std::collections::HashMap; + +use anyhow::{Error, Result}; +use tokenizers::tokenizer::Tokenizer as HfTokenizer; + +use super::chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, +}; use super::traits::{ Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, }; -use anyhow::{Error, Result}; -use std::collections::HashMap; -use tokenizers::tokenizer::Tokenizer as HfTokenizer; - -use super::chat_template::{ChatMessage, ChatTemplateProcessor}; /// HuggingFace tokenizer wrapper pub struct HuggingFaceTokenizer { @@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer { vocab: HashMap, reverse_vocab: HashMap, chat_template: Option, + /// Detected chat template content format (computed once at initialization) + content_format: ChatTemplateContentFormat, } impl HuggingFaceTokenizer { @@ -49,12 +54,20 @@ impl HuggingFaceTokenizer { Self::load_chat_template(file_path) }; + // Detect content format once at initialization + let content_format = if let Some(ref template) = chat_template { + detect_chat_template_content_format(template) + } else { + ChatTemplateContentFormat::String // Default if no template + }; + Ok(HuggingFaceTokenizer { tokenizer, special_tokens, vocab, reverse_vocab, chat_template, + content_format, }) } @@ -73,6 +86,7 @@ impl HuggingFaceTokenizer { vocab, reverse_vocab, chat_template: None, + content_format: ChatTemplateContentFormat::String, // Default } } @@ -135,13 +149,22 @@ impl HuggingFaceTokenizer { /// Set or override the chat template pub fn set_chat_template(&mut self, template: String) { + // Detect format for the new template + self.content_format = detect_chat_template_content_format(&template); self.chat_template = Some(template); } + /// Get the content format expected by the chat template + pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat { + self.content_format + } + /// Apply chat template if available + /// + /// Takes transformed JSON Values (already transformed based on content format) pub fn apply_chat_template( &self, - messages: &[ChatMessage], + messages: &[serde_json::Value], add_generation_prompt: bool, ) -> Result { if let Some(ref template) = self.chat_template { @@ -150,17 +173,15 @@ impl HuggingFaceTokenizer { self.special_tokens.bos_token.clone(), self.special_tokens.eos_token.clone(), ); + processor.apply_chat_template(messages, add_generation_prompt) } else { - // Fallback to simple formatting if no template is available - let mut result = String::new(); - for msg in messages { - result.push_str(&format!("{}: {}\n", msg.role, msg.content)); - } - if add_generation_prompt { - result.push_str("assistant: "); - } - Ok(result) + Err(Error::msg( + "Cannot use chat template functions because tokenizer.chat_template is not set and no template \ + argument was passed! For information about writing templates and setting the \ + tokenizer.chat_template attribute, please see the documentation at \ + https://huggingface.co/docs/transformers/main/en/chat_templating" + )) } } } @@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer { #[cfg(test)] mod tests { - use super::ChatMessage; - - #[test] - fn test_chat_message_creation() { - let msg = ChatMessage::system("You are a helpful assistant"); - assert_eq!(msg.role, "system"); - assert_eq!(msg.content, "You are a helpful assistant"); - - let user_msg = ChatMessage::user("Hello!"); - assert_eq!(user_msg.role, "user"); - - let assistant_msg = ChatMessage::assistant("Hi there!"); - assert_eq!(assistant_msg.role, "assistant"); - } - // Note: Actual tokenizer tests would require a real tokenizer file // These would be integration tests rather than unit tests } diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs index 98a23f761..5ff4cdbf1 100644 --- a/sgl-router/src/tokenizer/mod.rs +++ b/sgl-router/src/tokenizer/mod.rs @@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz pub use huggingface::HuggingFaceTokenizer; -pub use chat_template::ChatMessage; - pub use tiktoken::{TiktokenModel, TiktokenTokenizer}; /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs new file mode 100644 index 000000000..cfca6ff8e --- /dev/null +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -0,0 +1,238 @@ +use sglang_router_rs::protocols::spec; +use sglang_router_rs::tokenizer::chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, +}; + +#[test] +fn test_detect_string_format_deepseek() { + // DeepSeek style template - expects string content + let template = r#" + {%- for message in messages %} + {%- if message['role'] == 'user' %} + User: {{ message['content'] }} + {%- elif message['role'] == 'assistant' %} + Assistant: {{ message['content'] }} + {%- endif %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::String + ); +} + +#[test] +fn test_detect_openai_format_llama4() { + // Llama4 style template - expects structured content + let template = r#" + {%- for message in messages %} + {%- if message['content'] is iterable %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{ content['text'] }} + {%- elif content['type'] == 'image' %} + + {%- endif %} + {%- endfor %} + {%- else %} + {{ message['content'] }} + {%- endif %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_dot_notation() { + // Template using dot notation + let template = r#" + {%- for message in messages %} + {%- for part in message.content %} + {%- if part.type == 'text' %} + {{ part.text }} + {%- endif %} + {%- endfor %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_variable_assignment() { + // Template that assigns content to variable then iterates + let template = r#" + {%- for message in messages %} + {%- set content = message['content'] %} + {%- if content is sequence %} + {%- for item in content %} + {{ item }} + {%- endfor %} + {%- endif %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_glm4v_style() { + // GLM4V uses 'msg' instead of 'message' + let template = r#" + {%- for msg in messages %} + {%- for part in msg.content %} + {%- if part.type == 'text' %}{{ part.text }}{%- endif %} + {%- if part.type == 'image' %}{%- endif %} + {%- endfor %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_with_length_check() { + // Template that checks content length + let template = r#" + {%- for message in messages %} + {%- if message.content|length > 0 %} + {%- for item in message.content %} + {{ item.text }} + {%- endfor %} + {%- endif %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_with_index_access() { + // Template that accesses content by index + let template = r#" + {%- for message in messages %} + {%- if message.content[0] %} + First item: {{ message.content[0].text }} + {%- endif %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_invalid_template_defaults_to_string() { + let template = "Not a valid {% jinja template"; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::String + ); +} + +#[test] +fn test_empty_template_defaults_to_string() { + assert_eq!( + detect_chat_template_content_format(""), + ChatTemplateContentFormat::String + ); +} + +#[test] +fn test_simple_chat_template_unit_test() { + let template = r#" +{%- for message in messages %} +{{ message.role }}: {{ message.content }} +{% endfor -%} +{%- if add_generation_prompt %} +assistant: +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = vec![ + spec::ChatMessage::System { + role: "system".to_string(), + content: "You are helpful".to_string(), + name: None, + }, + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Hello".to_string()), + name: None, + }, + ]; + + // Convert to JSON values like the router does + let message_values: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor + .apply_chat_template(&message_values, true) + .unwrap(); + assert!(result.contains("system: You are helpful")); + assert!(result.contains("user: Hello")); + assert!(result.contains("assistant:")); +} + +#[test] +fn test_chat_template_with_tokens_unit_test() { + // Template that uses special tokens + let template = r#" +{{ bos_token }} +{%- for message in messages -%} +{{ message.role }}: {{ message.content }}{{ eos_token }} +{% endfor -%} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + // Convert to JSON values like the router does + let message_values: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor + .apply_chat_template(&message_values, false) + .unwrap(); + assert!(result.contains("")); + assert!(result.contains("")); +} diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs new file mode 100644 index 000000000..95cea27c0 --- /dev/null +++ b/sgl-router/tests/chat_template_integration.rs @@ -0,0 +1,314 @@ +use sglang_router_rs::protocols::spec; +use sglang_router_rs::tokenizer::chat_template::{ + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, +}; + +#[test] +fn test_simple_chat_template() { + let template = r#" +{%- for message in messages %} +<|{{ message.role }}|>{{ message.content }}<|end|> +{% endfor -%} +{%- if add_generation_prompt %} +<|assistant|> +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + // Convert to JSON values like the router does + let message_values: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor + .apply_chat_template(&message_values, true) + .unwrap(); + assert!(result.contains("<|user|>Test<|end|>")); + assert!(result.contains("<|assistant|>")); +} + +#[test] +fn test_chat_template_with_tokens() { + // Template that uses special tokens + let template = r#" +{{ bos_token }} +{%- for message in messages -%} +{{ message.role }}: {{ message.content }}{{ eos_token }} +{% endfor -%} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + // Convert to JSON values like the router does + let message_values: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor + .apply_chat_template(&message_values, false) + .unwrap(); + assert!(result.contains("")); + assert!(result.contains("")); +} + +#[test] +fn test_llama_style_template() { + // Test a Llama-style chat template + let template = r#" +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {%- set system_message = '' -%} +{%- endif -%} + +{{- bos_token }} +{%- if system_message %} +{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }} +{%- endif %} + +{%- for message in messages %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("<|begin_of_text|>".to_string()), + Some("<|end_of_text|>".to_string()), + ); + + let messages = vec![ + spec::ChatMessage::System { + role: "system".to_string(), + content: "You are a helpful assistant".to_string(), + name: None, + }, + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("What is 2+2?".to_string()), + name: None, + }, + ]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor.apply_chat_template(&json_messages, true).unwrap(); + + // Check that the result contains expected markers + assert!(result.contains("<|begin_of_text|>")); + assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); + assert!(result.contains("You are a helpful assistant")); + assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); + assert!(result.contains("What is 2+2?")); + assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>")); +} + +#[test] +fn test_chatml_template() { + // Test a ChatML-style template + let template = r#" +{%- for message in messages %} + {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = vec![ + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Hello".to_string()), + name: None, + }, + spec::ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("Hi there!".to_string()), + name: None, + tool_calls: None, + function_call: None, + reasoning_content: None, + }, + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("How are you?".to_string()), + name: None, + }, + ]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor.apply_chat_template(&json_messages, true).unwrap(); + + // Check ChatML format + assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); + assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>")); + assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>")); + assert!(result.ends_with("<|im_start|>assistant\n")); +} + +#[test] +fn test_template_without_generation_prompt() { + let template = r#" +{%- for message in messages -%} +{{ message.role }}: {{ message.content }} +{% endfor -%} +{%- if add_generation_prompt -%} +assistant: +{%- endif -%} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + // Test without generation prompt + let result = processor + .apply_chat_template(&json_messages, false) + .unwrap(); + assert_eq!(result.trim(), "user: Test"); + + // Test with generation prompt + let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap(); + assert!(result_with_prompt.contains("assistant:")); +} + +#[test] +fn test_empty_messages_template() { + let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages: Vec = vec![]; + let result = processor.apply_chat_template(&messages, false).unwrap(); + assert_eq!(result, ""); +} + +#[test] +fn test_content_format_detection() { + // Test string format detection + let string_template = r#" +{%- for message in messages -%} +{{ message.role }}: {{ message.content }} +{%- endfor -%} +"#; + assert_eq!( + detect_chat_template_content_format(string_template), + ChatTemplateContentFormat::String + ); + + // Test OpenAI format detection + let openai_template = r#" +{%- for message in messages -%} + {%- for content in message.content -%} + {{ content.type }}: {{ content.text }} + {%- endfor -%} +{%- endfor -%} +"#; + assert_eq!( + detect_chat_template_content_format(openai_template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_template_with_multimodal_content() { + // Test that multimodal messages work correctly when serialized to JSON + let template = r#" +{%- for message in messages %} +{{ message.role }}: +{%- if message.content is string %} +{{ message.content }} +{%- else %} +{%- for part in message.content %} + {%- if part.type == "text" %} +{{ part.text }} + {%- elif part.type == "image_url" %} +[IMAGE] + {%- endif %} +{%- endfor %} +{%- endif %} +{% endfor %} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Parts(vec![ + spec::ContentPart::Text { + text: "Look at this:".to_string(), + }, + spec::ContentPart::ImageUrl { + image_url: spec::ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: None, + }, + }, + ]), + name: None, + }]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = processor + .apply_chat_template(&json_messages, false) + .unwrap(); + + // Should contain both text and image parts + assert!(result.contains("user:")); + assert!(result.contains("Look at this:")); + assert!(result.contains("[IMAGE]")); +} diff --git a/sgl-router/tests/test_chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs similarity index 70% rename from sgl-router/tests/test_chat_template_loading.rs rename to sgl-router/tests/chat_template_loading.rs index ad1501233..7a03337fc 100644 --- a/sgl-router/tests/test_chat_template_loading.rs +++ b/sgl-router/tests/chat_template_loading.rs @@ -1,13 +1,12 @@ #[cfg(test)] mod tests { + use sglang_router_rs::protocols::spec; + use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use std::fs; use tempfile::TempDir; #[test] fn test_load_chat_template_from_file() { - use sglang_router_rs::tokenizer::chat_template::ChatMessage; - use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; - // Create temporary directory let temp_dir = TempDir::new().unwrap(); let template_path = temp_dir.path().join("template.jinja"); @@ -59,11 +58,28 @@ mod tests { // Test that the custom template is used let messages = vec![ - ChatMessage::user("Hello"), - ChatMessage::assistant("Hi there"), + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Hello".to_string()), + name: None, + }, + spec::ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("Hi there".to_string()), + name: None, + tool_calls: None, + function_call: None, + reasoning_content: None, + }, ]; - let result = tokenizer.apply_chat_template(&messages, true).unwrap(); + // Convert to JSON values like the router does + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = tokenizer.apply_chat_template(&json_messages, true).unwrap(); // Verify the custom template format assert!(result.contains("<|user|>Hello")); @@ -73,9 +89,6 @@ mod tests { #[test] fn test_override_existing_template() { - use sglang_router_rs::tokenizer::chat_template::ChatMessage; - use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; - // Create temporary directory let temp_dir = TempDir::new().unwrap(); @@ -124,8 +137,21 @@ mod tests { ) .unwrap(); - let messages = vec![ChatMessage::user("Test")]; - let result = tokenizer.apply_chat_template(&messages, false).unwrap(); + let messages = [spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = tokenizer + .apply_chat_template(&json_messages, false) + .unwrap(); // Should use CUSTOM template, not built-in assert!(result.starts_with("CUSTOM:")); @@ -135,9 +161,6 @@ mod tests { #[test] fn test_set_chat_template_after_creation() { - use sglang_router_rs::tokenizer::chat_template::ChatMessage; - use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; - // Create temporary directory and tokenizer file let temp_dir = TempDir::new().unwrap(); let tokenizer_json = r#"{ @@ -173,8 +196,31 @@ mod tests { "NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"; tokenizer.set_chat_template(new_template.to_string()); - let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")]; - let result = tokenizer.apply_chat_template(&messages, false).unwrap(); + let messages = vec![ + spec::ChatMessage::User { + role: "user".to_string(), + content: spec::UserMessageContent::Text("Hello".to_string()), + name: None, + }, + spec::ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("World".to_string()), + name: None, + tool_calls: None, + function_call: None, + reasoning_content: None, + }, + ]; + + // Convert to JSON values + let json_messages: Vec = messages + .iter() + .map(|msg| serde_json::to_value(msg).unwrap()) + .collect(); + + let result = tokenizer + .apply_chat_template(&json_messages, false) + .unwrap(); assert!(result.starts_with("NEW:")); assert!(result.contains("user: Hello;")); diff --git a/sgl-router/tests/test_chat_template.rs b/sgl-router/tests/test_chat_template.rs deleted file mode 100644 index 4a0e73bd0..000000000 --- a/sgl-router/tests/test_chat_template.rs +++ /dev/null @@ -1,150 +0,0 @@ -#[cfg(test)] -mod tests { - use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor}; - - #[test] - fn test_chat_message_helpers() { - let system_msg = ChatMessage::system("You are a helpful assistant"); - assert_eq!(system_msg.role, "system"); - assert_eq!(system_msg.content, "You are a helpful assistant"); - - let user_msg = ChatMessage::user("Hello!"); - assert_eq!(user_msg.role, "user"); - assert_eq!(user_msg.content, "Hello!"); - - let assistant_msg = ChatMessage::assistant("Hi there!"); - assert_eq!(assistant_msg.role, "assistant"); - assert_eq!(assistant_msg.content, "Hi there!"); - } - - #[test] - fn test_llama_style_template() { - // Test a Llama-style chat template - let template = r#" -{%- if messages[0]['role'] == 'system' -%} - {%- set system_message = messages[0]['content'] -%} - {%- set messages = messages[1:] -%} -{%- else -%} - {%- set system_message = '' -%} -{%- endif -%} - -{{- bos_token }} -{%- if system_message %} -{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }} -{%- endif %} - -{%- for message in messages %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} -{%- endfor %} - -{%- if add_generation_prompt %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} -{%- endif %} -"#; - - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<|begin_of_text|>".to_string()), - Some("<|end_of_text|>".to_string()), - ); - - let messages = vec![ - ChatMessage::system("You are a helpful assistant"), - ChatMessage::user("What is 2+2?"), - ]; - - let result = processor.apply_chat_template(&messages, true).unwrap(); - - // Check that the result contains expected markers - assert!(result.contains("<|begin_of_text|>")); - assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); - assert!(result.contains("You are a helpful assistant")); - assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); - assert!(result.contains("What is 2+2?")); - assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>")); - } - - #[test] - fn test_chatml_template() { - // Test a ChatML-style template - let template = r#" -{%- for message in messages %} - {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} -{%- endif %} -"#; - - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); - - let messages = vec![ - ChatMessage::user("Hello"), - ChatMessage::assistant("Hi there!"), - ChatMessage::user("How are you?"), - ]; - - let result = processor.apply_chat_template(&messages, true).unwrap(); - - // Check ChatML format - assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); - assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>")); - assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>")); - assert!(result.ends_with("<|im_start|>assistant\n")); - } - - #[test] - fn test_template_without_generation_prompt() { - let template = r#" -{%- for message in messages -%} -{{ message.role }}: {{ message.content }} -{% endfor -%} -{%- if add_generation_prompt -%} -assistant: -{%- endif -%} -"#; - - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); - - let messages = vec![ChatMessage::user("Test")]; - - // Test without generation prompt - let result = processor.apply_chat_template(&messages, false).unwrap(); - assert_eq!(result.trim(), "user: Test"); - - // Test with generation prompt - let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap(); - assert!(result_with_prompt.contains("assistant:")); - } - - #[test] - fn test_template_with_special_tokens() { - let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#; - - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("".to_string()), - Some("".to_string()), - ); - - let messages = vec![ChatMessage::user("Hello")]; - - let result = processor.apply_chat_template(&messages, false).unwrap(); - assert_eq!(result, "Hello"); - } - - #[test] - fn test_empty_messages() { - let template = - r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#; - - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); - - let messages = vec![]; - let result = processor.apply_chat_template(&messages, false).unwrap(); - assert_eq!(result, ""); - } - - // Integration test with actual tokenizer file loading would go here - // but requires a real tokenizer_config.json file -}