From 5fbad308cdbc9702ee1c4e8843016a5c2716bcc1 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 19 Aug 2025 20:14:02 -0700 Subject: [PATCH] [router] add tokenizer chat template support (#9370) Co-authored-by: Chang Su --- sgl-router/Cargo.toml | 6 +- sgl-router/src/tokenizer/chat_template.rs | 188 ++++++++++++++++++ sgl-router/src/tokenizer/factory.rs | 13 +- sgl-router/src/tokenizer/huggingface.rs | 154 ++++++++++---- sgl-router/src/tokenizer/mod.rs | 24 ++- sgl-router/src/tokenizer/sequence.rs | 10 +- sgl-router/src/tokenizer/stop.rs | 23 ++- sgl-router/src/tokenizer/stream.rs | 8 +- sgl-router/src/tokenizer/tiktoken.rs | 16 +- sgl-router/src/tokenizer/traits.rs | 49 +++-- sgl-router/tests/test_chat_template.rs | 156 +++++++++++++++ .../tests/test_chat_template_loading.rs | 186 +++++++++++++++++ 12 files changed, 748 insertions(+), 85 deletions(-) create mode 100644 sgl-router/src/tokenizer/chat_template.rs create mode 100644 sgl-router/tests/test_chat_template.rs create mode 100644 sgl-router/tests/test_chat_template_loading.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 1b20f3cba..3a1e8292e 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [features] default = ["huggingface", "grpc-client"] -huggingface = ["tokenizers"] +huggingface = ["tokenizers", "minijinja"] tiktoken = ["tiktoken-rs"] grpc-client = [] grpc-server = [] @@ -52,7 +52,8 @@ url = "2.5.4" tokio-stream = { version = "0.1", features = ["sync"] } anyhow = "1.0" tokenizers = { version = "0.21.4", optional = true } -tiktoken-rs = { version = "0.5", optional = true } +tiktoken-rs = { version = "0.7.0", optional = true } +minijinja = { version = "2.0", optional = true } # gRPC and Protobuf dependencies tonic = { version = "0.12", features = ["tls", "gzip", "transport"] } @@ -71,6 +72,7 @@ criterion = { version = "0.5", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } http-body-util = "0.1" portpicker = "0.1" +tempfile = "3.8" [[bench]] name = "request_processing" diff --git a/sgl-router/src/tokenizer/chat_template.rs b/sgl-router/src/tokenizer/chat_template.rs new file mode 100644 index 000000000..91ba55f60 --- /dev/null +++ b/sgl-router/src/tokenizer/chat_template.rs @@ -0,0 +1,188 @@ +//! Chat template support for tokenizers using Jinja2 templates +//! +//! This module provides functionality to apply chat templates to messages, +//! similar to HuggingFace transformers' apply_chat_template method. + +use anyhow::{anyhow, Result}; +#[cfg(feature = "huggingface")] +use minijinja::{context, Environment, Value}; +use serde::{Deserialize, Serialize}; +use serde_json; + +/// Represents a chat message with role and content +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +impl ChatMessage { + pub fn new(role: impl Into, content: impl Into) -> Self { + ChatMessage { + role: role.into(), + content: content.into(), + } + } + + pub fn system(content: impl Into) -> Self { + Self::new("system", content) + } + + pub fn user(content: impl Into) -> Self { + Self::new("user", content) + } + + pub fn assistant(content: impl Into) -> Self { + Self::new("assistant", content) + } +} + +/// Chat template processor using Jinja2 +#[cfg(feature = "huggingface")] +pub struct ChatTemplateProcessor { + template: String, + bos_token: Option, + eos_token: Option, +} + +#[cfg(feature = "huggingface")] +impl ChatTemplateProcessor { + /// Create a new chat template processor + pub fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + ChatTemplateProcessor { + template, + bos_token, + eos_token, + } + } + + /// Apply the chat template to a list of messages + /// + /// This mimics the behavior of HuggingFace's apply_chat_template method + /// but returns the formatted string instead of token IDs. + pub fn apply_chat_template( + &self, + messages: &[ChatMessage], + add_generation_prompt: bool, + ) -> Result { + let mut env = Environment::new(); + + // Register the template + env.add_template("chat", &self.template) + .map_err(|e| anyhow!("Failed to add template: {}", e))?; + + // Get the template + let tmpl = env + .get_template("chat") + .map_err(|e| anyhow!("Failed to get template: {}", e))?; + + // Convert messages to a format Jinja can work with + let messages_value: Vec = messages + .iter() + .map(|msg| { + context! { + role => msg.role.clone(), + content => msg.content.clone() + } + }) + .collect(); + + // Render the template + let rendered = tmpl + .render(context! { + messages => messages_value, + add_generation_prompt => add_generation_prompt, + bos_token => self.bos_token.clone().unwrap_or_default(), + eos_token => self.eos_token.clone().unwrap_or_default() + }) + .map_err(|e| anyhow!("Failed to render template: {}", e))?; + + Ok(rendered) + } +} + +/// Load chat template from tokenizer config JSON +#[cfg(feature = "huggingface")] +pub fn load_chat_template_from_config(config_path: &str) -> Result> { + use std::fs; + + let content = fs::read_to_string(config_path)?; + let config: serde_json::Value = serde_json::from_str(&content)?; + + // Look for chat_template in the config + if let Some(template) = config.get("chat_template") { + if let Some(template_str) = template.as_str() { + return Ok(Some(template_str.to_string())); + } + } + + Ok(None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_chat_message_creation() { + let msg = ChatMessage::system("You are a helpful assistant"); + assert_eq!(msg.role, "system"); + assert_eq!(msg.content, "You are a helpful assistant"); + + let user_msg = ChatMessage::user("Hello!"); + assert_eq!(user_msg.role, "user"); + + let assistant_msg = ChatMessage::assistant("Hi there!"); + assert_eq!(assistant_msg.role, "assistant"); + } + + #[cfg(feature = "huggingface")] + #[test] + fn test_simple_chat_template() { + // Simple template that formats messages + let template = r#" +{%- for message in messages -%} +{{ message.role }}: {{ message.content }} +{% endfor -%} +{%- if add_generation_prompt -%} +assistant: +{%- endif -%} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = vec![ + ChatMessage::system("You are helpful"), + ChatMessage::user("Hello"), + ]; + + let result = processor.apply_chat_template(&messages, true).unwrap(); + assert!(result.contains("system: You are helpful")); + assert!(result.contains("user: Hello")); + assert!(result.contains("assistant:")); + } + + #[cfg(feature = "huggingface")] + #[test] + fn test_chat_template_with_tokens() { + // Template that uses special tokens + let template = r#" +{{ bos_token }} +{%- for message in messages -%} +{{ message.role }}: {{ message.content }}{{ eos_token }} +{% endfor -%} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = vec![ChatMessage::user("Test")]; + + let result = processor.apply_chat_template(&messages, false).unwrap(); + assert!(result.contains("")); + assert!(result.contains("")); + } +} diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs index e339140e7..fb6bef510 100644 --- a/sgl-router/src/tokenizer/factory.rs +++ b/sgl-router/src/tokenizer/factory.rs @@ -26,6 +26,14 @@ pub enum TokenizerType { /// - json: HuggingFace tokenizer /// - For testing: can return mock tokenizer pub fn create_tokenizer_from_file(file_path: &str) -> Result> { + create_tokenizer_with_chat_template(file_path, None) +} + +/// Create a tokenizer from a file path with an optional chat template +pub fn create_tokenizer_with_chat_template( + file_path: &str, + chat_template_path: Option<&str>, +) -> Result> { let start_time = Instant::now(); // Special case for testing @@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result { #[cfg(feature = "huggingface")] { - let tokenizer = HuggingFaceTokenizer::from_file(file_path)?; + let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template( + file_path, + chat_template_path, + )?; TokenizerMetrics::record_factory_load("json"); TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size()); diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs index ec07ce6d8..d6ccc0de1 100644 --- a/sgl-router/src/tokenizer/huggingface.rs +++ b/sgl-router/src/tokenizer/huggingface.rs @@ -1,21 +1,36 @@ -use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; +use super::traits::{ + Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, +}; use crate::metrics::TokenizerMetrics; use anyhow::{Error, Result}; use std::collections::HashMap; use std::time::Instant; use tokenizers::tokenizer::Tokenizer as HfTokenizer; +#[cfg(feature = "minijinja")] +use super::chat_template::{ChatMessage, ChatTemplateProcessor}; + /// HuggingFace tokenizer wrapper pub struct HuggingFaceTokenizer { tokenizer: HfTokenizer, special_tokens: SpecialTokens, - vocab: HashMap, - reverse_vocab: HashMap, + vocab: HashMap, + reverse_vocab: HashMap, + #[cfg(feature = "minijinja")] + chat_template: Option, } impl HuggingFaceTokenizer { /// Create a tokenizer from a HuggingFace tokenizer JSON file pub fn from_file(file_path: &str) -> Result { + Self::from_file_with_chat_template(file_path, None) + } + + /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template + pub fn from_file_with_chat_template( + file_path: &str, + chat_template_path: Option<&str>, + ) -> Result { let tokenizer = HfTokenizer::from_file(file_path) .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?; @@ -24,16 +39,28 @@ impl HuggingFaceTokenizer { // Build vocab mappings let vocab = tokenizer.get_vocab(false); - let reverse_vocab: HashMap = vocab + let reverse_vocab: HashMap = vocab .iter() .map(|(token, &id)| (id, token.clone())) .collect(); + // Load chat template + #[cfg(feature = "minijinja")] + let chat_template = if let Some(template_path) = chat_template_path { + // Load from specified .jinja file + Self::load_chat_template_from_file(template_path)? + } else { + // Try to load from tokenizer_config.json + Self::load_chat_template(file_path) + }; + Ok(HuggingFaceTokenizer { tokenizer, special_tokens, vocab, reverse_vocab, + #[cfg(feature = "minijinja")] + chat_template, }) } @@ -41,7 +68,7 @@ impl HuggingFaceTokenizer { pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self { let special_tokens = Self::extract_special_tokens(&tokenizer); let vocab = tokenizer.get_vocab(false); - let reverse_vocab: HashMap = vocab + let reverse_vocab: HashMap = vocab .iter() .map(|(token, &id)| (id, token.clone())) .collect(); @@ -51,6 +78,8 @@ impl HuggingFaceTokenizer { special_tokens, vocab, reverse_vocab, + #[cfg(feature = "minijinja")] + chat_template: None, } } @@ -81,13 +110,86 @@ impl HuggingFaceTokenizer { } } + /// Try to load chat template from tokenizer_config.json + #[cfg(feature = "minijinja")] + fn load_chat_template(tokenizer_path: &str) -> Option { + // Try to find tokenizer_config.json in the same directory + let path = std::path::Path::new(tokenizer_path); + let dir = path.parent()?; + let config_path = dir.join("tokenizer_config.json"); + + if config_path.exists() { + if let Ok(template) = + super::chat_template::load_chat_template_from_config(config_path.to_str()?) + { + return template; + } + } + None + } + + /// Load chat template from a .jinja file + #[cfg(feature = "minijinja")] + fn load_chat_template_from_file(template_path: &str) -> Result> { + use std::fs; + + let content = fs::read_to_string(template_path) + .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?; + + // Clean up the template (similar to Python implementation) + let template = content.trim().replace("\\n", "\n"); + + Ok(Some(template)) + } + + /// Set or override the chat template + #[cfg(feature = "minijinja")] + pub fn set_chat_template(&mut self, template: String) { + self.chat_template = Some(template); + } + /// Apply chat template if available - pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result { - // This is a placeholder - actual implementation would handle templates + #[cfg(feature = "minijinja")] + pub fn apply_chat_template( + &self, + messages: &[ChatMessage], + add_generation_prompt: bool, + ) -> Result { + if let Some(ref template) = self.chat_template { + let processor = ChatTemplateProcessor::new( + template.clone(), + self.special_tokens.bos_token.clone(), + self.special_tokens.eos_token.clone(), + ); + processor.apply_chat_template(messages, add_generation_prompt) + } else { + // Fallback to simple formatting if no template is available + let mut result = String::new(); + for msg in messages { + result.push_str(&format!("{}: {}\n", msg.role, msg.content)); + } + if add_generation_prompt { + result.push_str("assistant: "); + } + Ok(result) + } + } + + /// Apply chat template if available (without minijinja feature) + #[cfg(not(feature = "minijinja"))] + pub fn apply_chat_template( + &self, + messages: &[ChatMessage], + add_generation_prompt: bool, + ) -> Result { + // Fallback to simple formatting let mut result = String::new(); for msg in messages { result.push_str(&format!("{}: {}\n", msg.role, msg.content)); } + if add_generation_prompt { + result.push_str("assistant: "); + } Ok(result) } } @@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer { } impl Decoder for HuggingFaceTokenizer { - fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result { let start = Instant::now(); TokenizerMetrics::record_decode_request("huggingface"); @@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer { &self.special_tokens } - fn token_to_id(&self, token: &str) -> Option { + fn token_to_id(&self, token: &str) -> Option { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: TokenIdType) -> Option { self.reverse_vocab.get(&id).cloned() } } -/// Represents a chat message for template application -#[derive(Debug, Clone)] -pub struct ChatMessage { - pub role: String, - pub content: String, -} - -impl ChatMessage { - pub fn new(role: impl Into, content: impl Into) -> Self { - ChatMessage { - role: role.into(), - content: content.into(), - } - } - - pub fn system(content: impl Into) -> Self { - Self::new("system", content) - } - - pub fn user(content: impl Into) -> Self { - Self::new("user", content) - } - - pub fn assistant(content: impl Into) -> Self { - Self::new("assistant", content) - } -} - #[cfg(test)] mod tests { - use super::*; + #[cfg(feature = "minijinja")] + use super::ChatMessage; + #[cfg(feature = "minijinja")] #[test] fn test_chat_message_creation() { let msg = ChatMessage::system("You are a helpful assistant"); diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs index 7d7f87aed..78632062b 100644 --- a/sgl-router/src/tokenizer/mod.rs +++ b/sgl-router/src/tokenizer/mod.rs @@ -10,6 +10,9 @@ pub mod stream; pub mod traits; // Feature-gated modules +#[cfg(feature = "huggingface")] +pub mod chat_template; + #[cfg(feature = "huggingface")] pub mod huggingface; @@ -20,14 +23,20 @@ pub mod tiktoken; mod tests; // Re-exports -pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType}; +pub use factory::{ + create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template, + TokenizerType, +}; pub use sequence::Sequence; pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder}; pub use stream::DecodeStream; pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; #[cfg(feature = "huggingface")] -pub use huggingface::{ChatMessage, HuggingFaceTokenizer}; +pub use huggingface::HuggingFaceTokenizer; + +#[cfg(feature = "huggingface")] +pub use chat_template::ChatMessage; #[cfg(feature = "tiktoken")] pub use tiktoken::{TiktokenModel, TiktokenTokenizer}; @@ -42,6 +51,17 @@ impl Tokenizer { Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?)) } + /// Create a tokenizer from a file path with an optional chat template + pub fn from_file_with_chat_template( + file_path: &str, + chat_template_path: Option<&str>, + ) -> Result { + Ok(Tokenizer(factory::create_tokenizer_with_chat_template( + file_path, + chat_template_path, + )?)) + } + /// Create a tokenizer from an Arc pub fn from_arc(tokenizer: Arc) -> Self { Tokenizer(tokenizer) diff --git a/sgl-router/src/tokenizer/sequence.rs b/sgl-router/src/tokenizer/sequence.rs index 816d3cc59..99801438d 100644 --- a/sgl-router/src/tokenizer/sequence.rs +++ b/sgl-router/src/tokenizer/sequence.rs @@ -1,4 +1,4 @@ -use super::traits::Tokenizer as TokenizerTrait; +use super::traits::{TokenIdType, Tokenizer as TokenizerTrait}; use anyhow::Result; use std::sync::Arc; @@ -9,7 +9,7 @@ pub struct Sequence { tokenizer: Arc, /// The current sequence of token ids - token_ids: Vec, + token_ids: Vec, /// The position in the current sequence the last decoded token completed prefix_offset: usize, @@ -54,7 +54,7 @@ impl Sequence { } /// Create a sequence with initial tokens - pub fn with_tokens(tokenizer: Arc, token_ids: Vec) -> Self { + pub fn with_tokens(tokenizer: Arc, token_ids: Vec) -> Self { let len = token_ids.len(); Self { tokenizer, @@ -90,7 +90,7 @@ impl Sequence { /// Append a single token to the sequence and return newly decoded text /// Based on HuggingFace TGI incremental decoding - pub fn append_token(&mut self, token_id: u32) -> Result { + pub fn append_token(&mut self, token_id: TokenIdType) -> Result { // Store the old read offset before adding the new token let old_read_offset = self.read_offset; @@ -145,7 +145,7 @@ impl Sequence { } /// Get the current token ids - pub fn token_ids(&self) -> &[u32] { + pub fn token_ids(&self) -> &[TokenIdType] { &self.token_ids } diff --git a/sgl-router/src/tokenizer/stop.rs b/sgl-router/src/tokenizer/stop.rs index 96a6d4c9e..69376e20b 100644 --- a/sgl-router/src/tokenizer/stop.rs +++ b/sgl-router/src/tokenizer/stop.rs @@ -1,4 +1,4 @@ -use super::traits; +use super::traits::{self, TokenIdType}; use crate::metrics::TokenizerMetrics; use anyhow::Result; use std::collections::HashSet; @@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput { #[derive(Debug, Clone, Default)] pub struct StopSequenceConfig { /// Token IDs that trigger a stop - pub stop_tokens: HashSet, + pub stop_tokens: HashSet, /// String sequences that trigger a stop pub stop_sequences: Vec, /// Token IDs for visible stops (included in output) - pub visible_stop_tokens: HashSet, + pub visible_stop_tokens: HashSet, /// String sequences for visible stops (included in output) pub visible_stop_sequences: Vec, } impl StopSequenceConfig { /// Builder pattern - add a stop token - pub fn with_stop_token(mut self, token_id: u32) -> Self { + pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self { self.stop_tokens.insert(token_id); self } @@ -45,7 +45,7 @@ impl StopSequenceConfig { } /// Builder pattern - add a visible stop token - pub fn with_visible_stop_token(mut self, token_id: u32) -> Self { + pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self { self.visible_stop_tokens.insert(token_id); self } @@ -64,7 +64,7 @@ pub struct StopSequenceDecoder { /// Buffer for partial matches (the "jail") jail_buffer: String, /// Accumulated tokens - token_buffer: Vec, + token_buffer: Vec, /// Offset where the prefix text starts (for context) prefix_offset: usize, /// Offset marking the end of previously decoded text @@ -94,7 +94,7 @@ impl StopSequenceDecoder { } /// Process a single token - pub fn process_token(&mut self, token_id: u32) -> Result { + pub fn process_token(&mut self, token_id: TokenIdType) -> Result { let start = Instant::now(); if self.stopped { @@ -252,7 +252,10 @@ impl StopSequenceDecoder { } /// Process multiple tokens - pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result> { + pub fn process_tokens( + &mut self, + token_ids: &[TokenIdType], + ) -> Result> { let mut outputs = Vec::new(); for &token_id in token_ids { outputs.push(self.process_token(token_id)?); @@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder { } } - pub fn stop_token(mut self, token_id: u32) -> Self { + pub fn stop_token(mut self, token_id: TokenIdType) -> Self { self.config.stop_tokens.insert(token_id); self } @@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder { self } - pub fn visible_stop_token(mut self, token_id: u32) -> Self { + pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self { self.config.visible_stop_tokens.insert(token_id); self } diff --git a/sgl-router/src/tokenizer/stream.rs b/sgl-router/src/tokenizer/stream.rs index 8ff3abe28..bea7ede8d 100644 --- a/sgl-router/src/tokenizer/stream.rs +++ b/sgl-router/src/tokenizer/stream.rs @@ -1,6 +1,6 @@ // src/tokenizer/stream.rs -use super::traits; +use super::traits::{self, TokenIdType}; use crate::metrics::TokenizerMetrics; use anyhow::Result; use std::sync::Arc; @@ -18,7 +18,7 @@ pub struct DecodeStream { /// A temporary buffer of the necessary token_ids needed /// to produce valid string chunks - all_token_ids: Vec, + all_token_ids: Vec, prefix_offset: usize, read_offset: usize, @@ -27,7 +27,7 @@ pub struct DecodeStream { impl DecodeStream { pub fn new( tokenizer: Arc, - prompt_token_ids: &[u32], + prompt_token_ids: &[TokenIdType], skip_special_tokens: bool, ) -> Self { let num_input_tokens = prompt_token_ids.len(); @@ -44,7 +44,7 @@ impl DecodeStream { /// Step appends a token_id to the internal state and tries to produce a text chunk. /// Returning `None` means the given id is not enough to produce a chunk. - pub fn step(&mut self, id: u32) -> Result> { + pub fn step(&mut self, id: TokenIdType) -> Result> { let start = Instant::now(); self.all_token_ids.push(id); diff --git a/sgl-router/src/tokenizer/tiktoken.rs b/sgl-router/src/tokenizer/tiktoken.rs index 4cf0ea9f1..9ba49ec9a 100644 --- a/sgl-router/src/tokenizer/tiktoken.rs +++ b/sgl-router/src/tokenizer/tiktoken.rs @@ -1,4 +1,6 @@ -use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; +use super::traits::{ + Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, +}; use anyhow::{Error, Result}; use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE}; @@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer { } impl Decoder for TiktokenTokenizer { - fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result { - // Convert u32 to usize for tiktoken-rs - let tokens: Vec = token_ids.iter().map(|&id| id as usize).collect(); - + fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result { + // tiktoken-rs 0.7.0 now uses u32 (Rank type) self.tokenizer - .decode(tokens) + .decode(token_ids.to_vec()) .map_err(|e| Error::msg(format!("Decoding failed: {}", e))) } } @@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer { &self.special_tokens } - fn token_to_id(&self, _token: &str) -> Option { + fn token_to_id(&self, _token: &str) -> Option { // Tiktoken doesn't provide direct token-to-id mapping // We'd need to encode the token and check if it produces a single ID None } - fn id_to_token(&self, _id: u32) -> Option { + fn id_to_token(&self, _id: TokenIdType) -> Option { // Tiktoken doesn't provide direct id-to-token mapping // We can only decode IDs to text None diff --git a/sgl-router/src/tokenizer/traits.rs b/sgl-router/src/tokenizer/traits.rs index e0153704a..5bf68c240 100644 --- a/sgl-router/src/tokenizer/traits.rs +++ b/sgl-router/src/tokenizer/traits.rs @@ -1,4 +1,9 @@ use anyhow::Result; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +/// Type alias for token IDs +pub type TokenIdType = u32; /// Core encoding trait - separate from decoding for modularity pub trait Encoder: Send + Sync { @@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync { /// Core decoding trait - can be implemented independently pub trait Decoder: Send + Sync { - fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result; + fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result; } /// Combined tokenizer trait pub trait Tokenizer: Encoder + Decoder { fn vocab_size(&self) -> usize; fn get_special_tokens(&self) -> &SpecialTokens; - fn token_to_id(&self, token: &str) -> Option; - fn id_to_token(&self, id: u32) -> Option; + fn token_to_id(&self, token: &str) -> Option; + fn id_to_token(&self, id: TokenIdType) -> Option; } /// Contains the results of tokenizing text: token IDs, string tokens, and their spans @@ -25,29 +30,45 @@ pub enum Encoding { /// Hugging Face Hf(Box), /// Sentence Piece - Sp(Vec), - /// Tiktoken (for GPT models) - Tiktoken(Vec), + Sp(Vec), + /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0 + Tiktoken(Vec), } impl Encoding { - pub fn token_ids(&self) -> Vec { + /// Returns a reference to token IDs when possible, owned Vec for compatibility + pub fn token_ids(&self) -> Vec { match self { Encoding::Hf(inner) => inner.get_ids().to_vec(), Encoding::Sp(inner) => inner.clone(), - Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(), + Encoding::Tiktoken(inner) => inner.clone(), } } - pub fn token_ids_ref(&self) -> &[u32] { + /// Returns a reference to token IDs where possible + pub fn token_ids_ref(&self) -> &[TokenIdType] { match self { Encoding::Hf(inner) => inner.get_ids(), Encoding::Sp(inner) => inner, - Encoding::Tiktoken(_) => { - // Tiktoken uses usize, we can't return a reference to u32 - // This is a limitation - callers should use token_ids() for Tiktoken - &[] - } + Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0! + } + } + + /// Get a hash of the token IDs for caching purposes + pub fn get_hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); + hasher.finish() + } +} + +/// Hash implementation for Encoding +impl Hash for Encoding { + fn hash(&self, state: &mut H) { + match self { + Encoding::Hf(inner) => inner.get_ids().hash(state), + Encoding::Sp(inner) => inner.hash(state), + Encoding::Tiktoken(inner) => inner.hash(state), } } } diff --git a/sgl-router/tests/test_chat_template.rs b/sgl-router/tests/test_chat_template.rs new file mode 100644 index 000000000..c9fea45ed --- /dev/null +++ b/sgl-router/tests/test_chat_template.rs @@ -0,0 +1,156 @@ +#[cfg(test)] +mod tests { + use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor}; + + #[test] + #[cfg(feature = "huggingface")] + fn test_chat_message_helpers() { + let system_msg = ChatMessage::system("You are a helpful assistant"); + assert_eq!(system_msg.role, "system"); + assert_eq!(system_msg.content, "You are a helpful assistant"); + + let user_msg = ChatMessage::user("Hello!"); + assert_eq!(user_msg.role, "user"); + assert_eq!(user_msg.content, "Hello!"); + + let assistant_msg = ChatMessage::assistant("Hi there!"); + assert_eq!(assistant_msg.role, "assistant"); + assert_eq!(assistant_msg.content, "Hi there!"); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_llama_style_template() { + // Test a Llama-style chat template + let template = r#" +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {%- set system_message = '' -%} +{%- endif -%} + +{{- bos_token }} +{%- if system_message %} +{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }} +{%- endif %} + +{%- for message in messages %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("<|begin_of_text|>".to_string()), + Some("<|end_of_text|>".to_string()), + ); + + let messages = vec![ + ChatMessage::system("You are a helpful assistant"), + ChatMessage::user("What is 2+2?"), + ]; + + let result = processor.apply_chat_template(&messages, true).unwrap(); + + // Check that the result contains expected markers + assert!(result.contains("<|begin_of_text|>")); + assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); + assert!(result.contains("You are a helpful assistant")); + assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); + assert!(result.contains("What is 2+2?")); + assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>")); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_chatml_template() { + // Test a ChatML-style template + let template = r#" +{%- for message in messages %} + {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = vec![ + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there!"), + ChatMessage::user("How are you?"), + ]; + + let result = processor.apply_chat_template(&messages, true).unwrap(); + + // Check ChatML format + assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); + assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>")); + assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>")); + assert!(result.ends_with("<|im_start|>assistant\n")); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_template_without_generation_prompt() { + let template = r#" +{%- for message in messages -%} +{{ message.role }}: {{ message.content }} +{% endfor -%} +{%- if add_generation_prompt -%} +assistant: +{%- endif -%} +"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = vec![ChatMessage::user("Test")]; + + // Test without generation prompt + let result = processor.apply_chat_template(&messages, false).unwrap(); + assert_eq!(result.trim(), "user: Test"); + + // Test with generation prompt + let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap(); + assert!(result_with_prompt.contains("assistant:")); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_template_with_special_tokens() { + let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#; + + let processor = ChatTemplateProcessor::new( + template.to_string(), + Some("".to_string()), + Some("".to_string()), + ); + + let messages = vec![ChatMessage::user("Hello")]; + + let result = processor.apply_chat_template(&messages, false).unwrap(); + assert_eq!(result, "Hello"); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_empty_messages() { + let template = + r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#; + + let processor = ChatTemplateProcessor::new(template.to_string(), None, None); + + let messages = vec![]; + let result = processor.apply_chat_template(&messages, false).unwrap(); + assert_eq!(result, ""); + } + + // Integration test with actual tokenizer file loading would go here + // but requires a real tokenizer_config.json file +} diff --git a/sgl-router/tests/test_chat_template_loading.rs b/sgl-router/tests/test_chat_template_loading.rs new file mode 100644 index 000000000..235c608e8 --- /dev/null +++ b/sgl-router/tests/test_chat_template_loading.rs @@ -0,0 +1,186 @@ +#[cfg(test)] +mod tests { + use std::fs; + use tempfile::TempDir; + + #[test] + #[cfg(feature = "huggingface")] + fn test_load_chat_template_from_file() { + use sglang_router_rs::tokenizer::chat_template::ChatMessage; + use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; + + // Create temporary directory + let temp_dir = TempDir::new().unwrap(); + let template_path = temp_dir.path().join("template.jinja"); + + // Write a test template + let template_content = r#" +{%- for message in messages %} + {{- '<|' + message['role'] + '|>' + message['content'] }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|assistant|>' }} +{%- endif %} +"#; + fs::write(&template_path, template_content).unwrap(); + + // Create a mock tokenizer config + let tokenizer_config = r#"{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "vocab": { + "hello": 0, + "world": 1, + "": 2, + "": 3 + }, + "merges": [] + } + }"#; + + let tokenizer_path = temp_dir.path().join("tokenizer.json"); + fs::write(&tokenizer_path, tokenizer_config).unwrap(); + + // Load tokenizer with custom chat template + let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template( + tokenizer_path.to_str().unwrap(), + Some(template_path.to_str().unwrap()), + ) + .unwrap(); + + // Test that the custom template is used + let messages = vec![ + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there"), + ]; + + let result = tokenizer.apply_chat_template(&messages, true).unwrap(); + + // Verify the custom template format + assert!(result.contains("<|user|>Hello")); + assert!(result.contains("<|assistant|>Hi there")); + assert!(result.ends_with("<|assistant|>")); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_override_existing_template() { + use sglang_router_rs::tokenizer::chat_template::ChatMessage; + use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; + + // Create temporary directory + let temp_dir = TempDir::new().unwrap(); + + // Create tokenizer config with a built-in template + let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json"); + let config_with_template = r#"{ + "chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}" + }"#; + fs::write(&tokenizer_config_path, config_with_template).unwrap(); + + // Create the actual tokenizer file + let tokenizer_json = r#"{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "vocab": { + "test": 0, + "": 1, + "": 2 + }, + "merges": [] + } + }"#; + let tokenizer_path = temp_dir.path().join("tokenizer.json"); + fs::write(&tokenizer_path, tokenizer_json).unwrap(); + + // Create custom template that should override + let custom_template_path = temp_dir.path().join("custom.jinja"); + let custom_template = + r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#; + fs::write(&custom_template_path, custom_template).unwrap(); + + // Load with custom template - should override the built-in one + let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template( + tokenizer_path.to_str().unwrap(), + Some(custom_template_path.to_str().unwrap()), + ) + .unwrap(); + + let messages = vec![ChatMessage::user("Test")]; + let result = tokenizer.apply_chat_template(&messages, false).unwrap(); + + // Should use CUSTOM template, not built-in + assert!(result.starts_with("CUSTOM:")); + assert!(result.contains("[user]: Test")); + assert!(!result.contains("built-in:")); + } + + #[test] + #[cfg(feature = "huggingface")] + fn test_set_chat_template_after_creation() { + use sglang_router_rs::tokenizer::chat_template::ChatMessage; + use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; + + // Create temporary directory and tokenizer file + let temp_dir = TempDir::new().unwrap(); + let tokenizer_json = r#"{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "vocab": { + "test": 0, + "": 1, + "": 2 + }, + "merges": [] + } + }"#; + let tokenizer_path = temp_dir.path().join("tokenizer.json"); + fs::write(&tokenizer_path, tokenizer_json).unwrap(); + + // Load tokenizer without custom template + let mut tokenizer = + HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap(); + + // Set a template after creation (mimics Python's behavior) + let new_template = + "NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"; + tokenizer.set_chat_template(new_template.to_string()); + + let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")]; + let result = tokenizer.apply_chat_template(&messages, false).unwrap(); + + assert!(result.starts_with("NEW:")); + assert!(result.contains("user: Hello;")); + assert!(result.contains("assistant: World;")); + } +}