diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 529d40d16..a54f6ff83 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -27,7 +27,7 @@ use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserRegistry; use uuid::Uuid; -use crate::tokenizer::chat_template::ChatTemplateContentFormat; +use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use serde_json::Value; // Data structures for processing @@ -300,12 +300,87 @@ impl GrpcRouter { { // Get content format and transform messages accordingly let content_format = hf_tokenizer.chat_template_content_format(); - let transformed_messages = - Self::transform_messages_for_content_format(&request.messages, content_format)?; + let mut transformed_messages = + Self::process_content_format(&request.messages, content_format)?; - hf_tokenizer - .apply_chat_template(&transformed_messages, true) - .map_err(|e| format!("Failed to apply chat template: {}", e))? 
+ // Process tool call arguments in assistant messages + Self::process_tool_call_arguments(&mut transformed_messages)?; + + // Convert tools to JSON values for template processing + let tools_json: Option<Vec<serde_json::Value>> = request + .tools + .as_ref() + .map(|tools| { + tools + .iter() + .map(serde_json::to_value) + .collect::<Result<Vec<_>, _>>() + }) + .transpose() + .map_err(|e| format!("Failed to serialize tools: {}", e))?; + + // Build template kwargs, merging reasoning_effort if present + let mut combined_template_kwargs = std::collections::HashMap::new(); + + // Add reasoning_effort if present (like Python does) + if let Some(reasoning_effort) = &request.reasoning_effort { + combined_template_kwargs.insert( + "reasoning_effort".to_string(), + serde_json::Value::String(reasoning_effort.clone()), + ); + } + + // Add any additional template kwargs from request + if let Some(template_kwargs) = &request.chat_template_kwargs { + for (key, value) in template_kwargs { + combined_template_kwargs.insert(key.clone(), value.clone()); + } + } + + let final_template_kwargs = if combined_template_kwargs.is_empty() { + None + } else { + Some(&combined_template_kwargs) + }; + + let params = ChatTemplateParams { + add_generation_prompt: true, + continue_final_message: request.continue_final_message, + tools: tools_json.as_deref(), + template_kwargs: final_template_kwargs, + ..Default::default() + }; + + // Handle assistant prefix for continue_final_message + let assistant_prefix = if request.continue_final_message + && !transformed_messages.is_empty() + && transformed_messages + .last() + .and_then(|msg| msg.get("role")) + .and_then(|v| v.as_str()) + == Some("assistant") + { + // Pop the last message to handle it separately + let last_msg = transformed_messages.pop().unwrap(); + last_msg + .get("content") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + } else { + None + }; + + // Apply chat template with the (now possibly shorter) list of messages + let rendered = hf_tokenizer + 
.apply_chat_template(&transformed_messages, params) + .map_err(|e| format!("Failed to apply chat template: {}", e))?; + + // Append assistant prefix if we have one + if let Some(prefix) = assistant_prefix { + format!("{}{}", rendered, prefix) + } else { + rendered + } } else { return Err( "gRPC router requires HuggingFace tokenizer with chat template support".to_string(), ); @@ -322,8 +397,8 @@ impl GrpcRouter { }) } - /// Transform messages based on content format for ANY message type - fn transform_messages_for_content_format( + /// Process messages based on content format for ANY message type + fn process_content_format( messages: &[crate::protocols::spec::ChatMessage], content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat, ) -> Result<Vec<serde_json::Value>, String> { @@ -394,6 +469,49 @@ } } + /// Process tool call arguments in messages + /// Per Transformers docs, tool call arguments in assistant messages should be dicts + fn process_tool_call_arguments(messages: &mut [serde_json::Value]) -> Result<(), String> { + for msg in messages { + // Early return if not assistant message + let role = msg.get("role").and_then(|v| v.as_str()); + if role != Some("assistant") { + continue; + } + + // Early return if no tool_calls + let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) + else { + continue; + }; + + // Process each tool call's arguments + for call in tool_calls { + let Some(function) = call.get_mut("function") else { + continue; + }; + let Some(args) = function.get_mut("arguments") else { + continue; + }; + let Some(args_str) = args.as_str() else { + continue; + }; + + // Parse JSON string to object (like Python json.loads) + match serde_json::from_str::<serde_json::Value>(args_str) { + Ok(parsed) => *args = parsed, + Err(e) => { + return Err(format!( + "Failed to parse tool call arguments as JSON: '{}'. 
Error: {}", + args_str, e + )) + } + } + } + } + Ok(()) + } + /// Build gRPC SamplingParams from OpenAI request fn build_grpc_sampling_params( &self, @@ -410,6 +528,19 @@ impl GrpcRouter { .or(request.max_tokens) .map(|v| v as i32); + // Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none" + let skip_special_tokens = if request.tools.is_some() { + match &request.tool_choice { + Some(crate::protocols::spec::ToolChoice::Value( + crate::protocols::spec::ToolChoiceValue::None, + )) => request.skip_special_tokens, + Some(_) => false, // tool_choice is not "none" + None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present + } + } else { + request.skip_special_tokens + }; + #[allow(deprecated)] Ok(proto::SamplingParams { temperature: request.temperature.unwrap_or(1.0), @@ -422,7 +553,7 @@ impl GrpcRouter { max_new_tokens, stop: stop_sequences, stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens: request.skip_special_tokens, + skip_special_tokens, n: request.n.unwrap_or(1) as i32, structural_tag: structural_tag.unwrap_or_default(), constraint: self.build_constraint(request)?, @@ -700,11 +831,9 @@ mod tests { name: None, }]; - let result = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result.len(), 1); let transformed_message = &result[0]; @@ -735,11 +864,9 @@ mod tests { name: None, }]; - let result = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::OpenAI, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI) + .unwrap(); assert_eq!(result.len(), 1); let transformed_message = &result[0]; @@ -764,11 +891,9 @@ mod tests { name: None, }]; - let result = 
GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result.len(), 1); let transformed_message = &result[0]; @@ -791,11 +916,9 @@ mod tests { reasoning_content: None, }]; - let result = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result.len(), 1); let transformed_message = &result[0]; @@ -832,11 +955,9 @@ mod tests { }, ]; - let result = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result.len(), 2); @@ -862,11 +983,9 @@ mod tests { name: None, }]; - let result = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result.len(), 1); let transformed_message = &result[0]; @@ -902,22 +1021,18 @@ mod tests { ]; // Test String format - let result_string = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::String, - ) - .unwrap(); + let result_string = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) + .unwrap(); assert_eq!(result_string.len(), 2); assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text"); assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image"); // Test OpenAI format - let result_openai = GrpcRouter::transform_messages_for_content_format( - &messages, - ChatTemplateContentFormat::OpenAI, - ) - 
.unwrap(); + let result_openai = + GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI) + .unwrap(); assert_eq!(result_openai.len(), 2); assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text"); diff --git a/sgl-router/src/tokenizer/chat_template.rs b/sgl-router/src/tokenizer/chat_template.rs index 798ede015..dec38cf59 100644 --- a/sgl-router/src/tokenizer/chat_template.rs +++ b/sgl-router/src/tokenizer/chat_template.rs @@ -6,6 +6,7 @@ use anyhow::{anyhow, Result}; use minijinja::{context, machinery, Environment, Value}; use serde_json; +use std::collections::HashMap; /// Chat template content format #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -288,21 +289,25 @@ fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool { matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number()) } +/// Parameters for chat template application +#[derive(Default)] +pub struct ChatTemplateParams<'a> { + pub add_generation_prompt: bool, + pub continue_final_message: bool, + pub tools: Option<&'a [serde_json::Value]>, + pub documents: Option<&'a [serde_json::Value]>, + pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>, +} + /// Chat template processor using Jinja2 - simple wrapper like HuggingFace pub struct ChatTemplateProcessor { template: String, - bos_token: Option<String>, - eos_token: Option<String>, } impl ChatTemplateProcessor { /// Create a new chat template processor - pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self { - ChatTemplateProcessor { - template, - bos_token, - eos_token, - } + pub fn new(template: String) -> Self { + ChatTemplateProcessor { template } } /// Apply the chat template to a list of messages @@ -313,8 +318,12 @@ pub fn apply_chat_template( &self, messages: &[serde_json::Value], - add_generation_prompt: bool, + params: ChatTemplateParams, ) -> Result<String> { + // Validate incompatible options + if params.continue_final_message && 
params.add_generation_prompt { + return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead.")); + } let mut env = Environment::new(); // Register the template @@ -326,17 +335,29 @@ .get_template("chat") .map_err(|e| anyhow!("Failed to get template: {}", e))?; - // Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump() + // Convert messages to minijinja::Value (messages already processed by router) let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect(); - // Render the template directly with the provided values + let base_context = context! { + messages => &minijinja_messages, + add_generation_prompt => params.add_generation_prompt, + tools => params.tools, + documents => params.documents, + }; + + // Merge with template_kwargs if provided + let ctx = if let Some(kwargs) = params.template_kwargs { + context! { + ..base_context, + ..Value::from_serialize(kwargs) + } + } else { + base_context + }; + + // Render the template let rendered = tmpl - .render(context! 
{ - messages => minijinja_messages, - add_generation_prompt => add_generation_prompt, - bos_token => self.bos_token.clone().unwrap_or_default(), - eos_token => self.eos_token.clone().unwrap_or_default() - }) + .render(&ctx) .map_err(|e| anyhow!("Failed to render template: {}", e))?; Ok(rendered) diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index f4d926621..396ccdf60 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -4,7 +4,8 @@ use anyhow::{Error, Result}; use tokenizers::tokenizer::Tokenizer as HfTokenizer; use super::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, }; use super::traits::{ Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, @@ -165,16 +166,11 @@ impl HuggingFaceTokenizer { pub fn apply_chat_template( &self, messages: &[serde_json::Value], - add_generation_prompt: bool, + params: ChatTemplateParams, ) -> Result<String> { if let Some(ref template) = self.chat_template { - let processor = ChatTemplateProcessor::new( - template.clone(), - self.special_tokens.bos_token.clone(), - self.special_tokens.eos_token.clone(), - ); - - processor.apply_chat_template(messages, add_generation_prompt) + let processor = ChatTemplateProcessor::new(template.clone()); + processor.apply_chat_template(messages, params) } else { Err(Error::msg( "Cannot use chat template functions because tokenizer.chat_template is not set and no template \ diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs index cfca6ff8e..7a1ffa0fa 100644 --- a/sgl-router/tests/chat_template_format_detection.rs +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -1,6 +1,7 @@ use sglang_router_rs::protocols::spec; use 
sglang_router_rs::tokenizer::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, }; #[test] @@ -169,11 +170,7 @@ assistant: {%- endif %} "#; - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<s>".to_string()), - Some("</s>".to_string()), - ); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = vec![ spec::ChatMessage::System { @@ -194,8 +191,12 @@ .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); + let params = ChatTemplateParams { + add_generation_prompt: true, + ..Default::default() + }; let result = processor - .apply_chat_template(&message_values, true) + .apply_chat_template(&message_values, params) .unwrap(); assert!(result.contains("system: You are helpful")); assert!(result.contains("user: Hello")); @@ -204,19 +205,15 @@ #[test] fn test_chat_template_with_tokens_unit_test() { - // Template that uses special tokens + // Template that uses template kwargs for tokens (more realistic) let template = r#" -{{ bos_token }} +{%- if start_token -%}{{ start_token }}{%- endif -%} {%- for message in messages -%} -{{ message.role }}: {{ message.content }}{{ eos_token }} +{{ message.role }}: {{ message.content }}{%- if end_token -%}{{ end_token }}{%- endif -%} {% endfor -%} "#; - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<s>".to_string()), - Some("</s>".to_string()), - ); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { role: "user".to_string(), @@ -230,8 +227,24 @@ fn test_chat_template_with_tokens_unit_test() { .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); + // Use template_kwargs to pass tokens + let mut template_kwargs = std::collections::HashMap::new(); + template_kwargs.insert( + "start_token".to_string(), + 
serde_json::Value::String("<s>".to_string()), + ); + template_kwargs.insert( + "end_token".to_string(), + serde_json::Value::String("</s>".to_string()), + ); + + let params = ChatTemplateParams { + template_kwargs: Some(&template_kwargs), + ..Default::default() + }; + let result = processor - .apply_chat_template(&message_values, false) + .apply_chat_template(&message_values, params) .unwrap(); assert!(result.contains("<s>")); assert!(result.contains("</s>")); diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs index 95cea27c0..4bdfe2008 100644 --- a/sgl-router/tests/chat_template_integration.rs +++ b/sgl-router/tests/chat_template_integration.rs @@ -1,6 +1,7 @@ use sglang_router_rs::protocols::spec; use sglang_router_rs::tokenizer::chat_template::{ - detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, + detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, + ChatTemplateProcessor, }; #[test] @@ -14,11 +15,7 @@ fn test_simple_chat_template() { {%- endif %} "#; - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<s>".to_string()), - Some("</s>".to_string()), - ); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { role: "user".to_string(), @@ -32,8 +29,12 @@ .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); + let params = ChatTemplateParams { + add_generation_prompt: true, + ..Default::default() + }; let result = processor - .apply_chat_template(&message_values, true) + .apply_chat_template(&message_values, params) .unwrap(); assert!(result.contains("<|user|>Test<|end|>")); assert!(result.contains("<|assistant|>")); } @@ -41,19 +42,15 @@ #[test] fn test_chat_template_with_tokens() { - // Template that uses special tokens + // Template that uses template kwargs for tokens let template = r#" -{{ bos_token }} +{%- if 
bos_token -%}{{ bos_token }}{%- endif -%} {%- for message in messages -%} -{{ message.role }}: {{ message.content }}{{ eos_token }} +{{ message.role }}: {{ message.content }}{%- if eos_token -%}{{ eos_token }}{%- endif -%} {% endfor -%} "#; - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<s>".to_string()), - Some("</s>".to_string()), - ); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { role: "user".to_string(), @@ -67,8 +64,24 @@ .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); + // Use template_kwargs to pass tokens + let mut template_kwargs = std::collections::HashMap::new(); + template_kwargs.insert( + "bos_token".to_string(), + serde_json::Value::String("<s>".to_string()), + ); + template_kwargs.insert( + "eos_token".to_string(), + serde_json::Value::String("</s>".to_string()), + ); + + let params = ChatTemplateParams { + template_kwargs: Some(&template_kwargs), + ..Default::default() + }; + let result = processor - .apply_chat_template(&message_values, false) + .apply_chat_template(&message_values, params) .unwrap(); assert!(result.contains("<s>")); assert!(result.contains("</s>")); @@ -85,7 +98,7 @@ fn test_llama_style_template() { {%- set system_message = '' -%} {%- endif -%} -{{- bos_token }} +{{- bos_token if bos_token else '<|begin_of_text|>' }} {%- if system_message %} {{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }} {%- endif %} @@ -99,11 +112,7 @@ {%- endif %} "#; - let processor = ChatTemplateProcessor::new( - template.to_string(), - Some("<|begin_of_text|>".to_string()), - Some("<|end_of_text|>".to_string()), - ); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = vec![ spec::ChatMessage::System { @@ -124,7 +133,21 @@ .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); - let result = 
processor.apply_chat_template(&json_messages, true).unwrap(); + // Use template_kwargs to pass the token + let mut template_kwargs = std::collections::HashMap::new(); + template_kwargs.insert( + "bos_token".to_string(), + serde_json::Value::String("<|begin_of_text|>".to_string()), + ); + + let params = ChatTemplateParams { + add_generation_prompt: true, + template_kwargs: Some(&template_kwargs), + ..Default::default() + }; + let result = processor + .apply_chat_template(&json_messages, params) + .unwrap(); // Check that the result contains expected markers assert!(result.contains("<|begin_of_text|>")); @@ -147,7 +170,7 @@ fn test_chatml_template() { {%- endif %} "#; - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = vec![ spec::ChatMessage::User { @@ -176,7 +199,15 @@ fn test_chatml_template() { .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); - let result = processor.apply_chat_template(&json_messages, true).unwrap(); + let result = processor + .apply_chat_template( + &json_messages, + ChatTemplateParams { + add_generation_prompt: true, + ..Default::default() + }, + ) + .unwrap(); // Check ChatML format assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); @@ -196,7 +227,7 @@ assistant: {%- endif -%} "#; - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { role: "user".to_string(), @@ -212,12 +243,20 @@ assistant: // Test without generation prompt let result = processor - .apply_chat_template(&json_messages, false) + .apply_chat_template(&json_messages, ChatTemplateParams::default()) .unwrap(); assert_eq!(result.trim(), "user: Test"); // Test with generation prompt - let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap(); + let result_with_prompt = processor + 
.apply_chat_template( + &json_messages, + ChatTemplateParams { + add_generation_prompt: true, + ..Default::default() + }, + ) + .unwrap(); assert!(result_with_prompt.contains("assistant:")); } @@ -225,10 +264,12 @@ fn test_empty_messages_template() { let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#; - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages: Vec<serde_json::Value> = vec![]; - let result = processor.apply_chat_template(&messages, false).unwrap(); + let result = processor + .apply_chat_template(&messages, ChatTemplateParams::default()) + .unwrap(); assert_eq!(result, ""); } @@ -279,7 +320,7 @@ fn test_template_with_multimodal_content() { {% endfor %} "#; - let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { role: "user".to_string(), @@ -304,7 +345,7 @@ .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); let result = processor - .apply_chat_template(&json_messages, false) + .apply_chat_template(&json_messages, ChatTemplateParams::default()) .unwrap(); // Should contain both text and image parts diff --git a/sgl-router/tests/chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs index 7a03337fc..ac60a6867 100644 --- a/sgl-router/tests/chat_template_loading.rs +++ b/sgl-router/tests/chat_template_loading.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests { use sglang_router_rs::protocols::spec; + use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams; use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use std::fs; use tempfile::TempDir; @@ -79,7 +80,14 @@ mod tests { .map(|msg| serde_json::to_value(msg).unwrap()) .collect(); - let result = tokenizer.apply_chat_template(&json_messages, true).unwrap(); + use 
sglang_router_rs::tokenizer::chat_template::ChatTemplateParams; + let params = ChatTemplateParams { + add_generation_prompt: true, + ..Default::default() + }; + let result = tokenizer + .apply_chat_template(&json_messages, params) + .unwrap(); // Verify the custom template format assert!(result.contains("<|user|>Hello")); @@ -150,7 +158,7 @@ mod tests { .collect(); let result = tokenizer - .apply_chat_template(&json_messages, false) + .apply_chat_template(&json_messages, ChatTemplateParams::default()) .unwrap(); // Should use CUSTOM template, not built-in @@ -219,7 +227,7 @@ mod tests { .collect(); let result = tokenizer - .apply_chat_template(&json_messages, false) + .apply_chat_template(&json_messages, ChatTemplateParams::default()) .unwrap(); assert!(result.starts_with("NEW:"));