[router] add tokenizer chat template support (#9370)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -5,7 +5,7 @@ edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["huggingface", "grpc-client"]
|
||||
huggingface = ["tokenizers"]
|
||||
huggingface = ["tokenizers", "minijinja"]
|
||||
tiktoken = ["tiktoken-rs"]
|
||||
grpc-client = []
|
||||
grpc-server = []
|
||||
@@ -52,7 +52,8 @@ url = "2.5.4"
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
tokenizers = { version = "0.21.4", optional = true }
|
||||
tiktoken-rs = { version = "0.5", optional = true }
|
||||
tiktoken-rs = { version = "0.7.0", optional = true }
|
||||
minijinja = { version = "2.0", optional = true }
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||
@@ -71,6 +72,7 @@ criterion = { version = "0.5", features = ["html_reports"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
http-body-util = "0.1"
|
||||
portpicker = "0.1"
|
||||
tempfile = "3.8"
|
||||
|
||||
[[bench]]
|
||||
name = "request_processing"
|
||||
|
||||
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Chat template support for tokenizers using Jinja2 templates
|
||||
//!
|
||||
//! This module provides functionality to apply chat templates to messages,
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
#[cfg(feature = "huggingface")]
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
/// Represents a chat message with role and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat template processor using Jinja2
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub struct ChatTemplateProcessor {
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
impl ChatTemplateProcessor {
|
||||
/// Create a new chat template processor
|
||||
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
ChatTemplateProcessor {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply the chat template to a list of messages
|
||||
///
|
||||
/// This mimics the behavior of HuggingFace's apply_chat_template method
|
||||
/// but returns the formatted string instead of token IDs.
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
|
||||
// Register the template
|
||||
env.add_template("chat", &self.template)
|
||||
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
|
||||
|
||||
// Get the template
|
||||
let tmpl = env
|
||||
.get_template("chat")
|
||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||
|
||||
// Convert messages to a format Jinja can work with
|
||||
let messages_value: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
context! {
|
||||
role => msg.role.clone(),
|
||||
content => msg.content.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Render the template
|
||||
let rendered = tmpl
|
||||
.render(context! {
|
||||
messages => messages_value,
|
||||
add_generation_prompt => add_generation_prompt,
|
||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
||||
})
|
||||
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
|
||||
|
||||
Ok(rendered)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load chat template from tokenizer config JSON
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
||||
use std::fs;
|
||||
|
||||
let content = fs::read_to_string(config_path)?;
|
||||
let config: serde_json::Value = serde_json::from_str(&content)?;
|
||||
|
||||
// Look for chat_template in the config
|
||||
if let Some(template) = config.get("chat_template") {
|
||||
if let Some(template_str) = template.as_str() {
|
||||
return Ok(Some(template_str.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(msg.role, "system");
|
||||
assert_eq!(msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
}
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
// Simple template that formats messages
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,14 @@ pub enum TokenizerType {
|
||||
/// - json: HuggingFace tokenizer
|
||||
/// - For testing: can return mock tokenizer
|
||||
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
create_tokenizer_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn create_tokenizer_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Special case for testing
|
||||
@@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
|
||||
Some("json") => {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?;
|
||||
|
||||
TokenizerMetrics::record_factory_load("json");
|
||||
TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
tokenizer: HfTokenizer,
|
||||
special_tokens: SpecialTokens,
|
||||
vocab: HashMap<String, u32>,
|
||||
reverse_vocab: HashMap<u32, String>,
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||
Self::from_file_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Self> {
|
||||
let tokenizer = HfTokenizer::from_file(file_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||
|
||||
@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
// Load chat template
|
||||
#[cfg(feature = "minijinja")]
|
||||
let chat_template = if let Some(template_path) = chat_template_path {
|
||||
// Load from specified .jinja file
|
||||
Self::load_chat_template_from_file(template_path)?
|
||||
} else {
|
||||
// Try to load from tokenizer_config.json
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to load chat template from tokenizer_config.json
|
||||
#[cfg(feature = "minijinja")]
|
||||
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
|
||||
// Try to find tokenizer_config.json in the same directory
|
||||
let path = std::path::Path::new(tokenizer_path);
|
||||
let dir = path.parent()?;
|
||||
let config_path = dir.join("tokenizer_config.json");
|
||||
|
||||
if config_path.exists() {
|
||||
if let Ok(template) =
|
||||
super::chat_template::load_chat_template_from_config(config_path.to_str()?)
|
||||
{
|
||||
return template;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Load chat template from a .jinja file
|
||||
#[cfg(feature = "minijinja")]
|
||||
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
|
||||
use std::fs;
|
||||
|
||||
let content = fs::read_to_string(template_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
|
||||
|
||||
// Clean up the template (similar to Python implementation)
|
||||
let template = content.trim().replace("\\n", "\n");
|
||||
|
||||
Ok(Some(template))
|
||||
}
|
||||
|
||||
/// Set or override the chat template
|
||||
#[cfg(feature = "minijinja")]
|
||||
pub fn set_chat_template(&mut self, template: String) {
|
||||
self.chat_template = Some(template);
|
||||
}
|
||||
|
||||
/// Apply chat template if available
|
||||
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
|
||||
// This is a placeholder - actual implementation would handle templates
|
||||
#[cfg(feature = "minijinja")]
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
if let Some(ref template) = self.chat_template {
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.clone(),
|
||||
self.special_tokens.bos_token.clone(),
|
||||
self.special_tokens.eos_token.clone(),
|
||||
);
|
||||
processor.apply_chat_template(messages, add_generation_prompt)
|
||||
} else {
|
||||
// Fallback to simple formatting if no template is available
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply chat template if available (without minijinja feature)
|
||||
#[cfg(not(feature = "minijinja"))]
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
// Fallback to simple formatting
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
|
||||
}
|
||||
|
||||
impl Decoder for HuggingFaceTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||
let start = Instant::now();
|
||||
|
||||
TokenizerMetrics::record_decode_request("huggingface");
|
||||
@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||
self.reverse_vocab.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a chat message for template application
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::ChatMessage;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
|
||||
@@ -10,6 +10,9 @@ pub mod stream;
|
||||
pub mod traits;
|
||||
|
||||
// Feature-gated modules
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub mod chat_template;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub mod huggingface;
|
||||
|
||||
@@ -20,14 +23,20 @@ pub mod tiktoken;
|
||||
mod tests;
|
||||
|
||||
// Re-exports
|
||||
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
||||
pub use factory::{
|
||||
create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
||||
TokenizerType,
|
||||
};
|
||||
pub use sequence::Sequence;
|
||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||
pub use stream::DecodeStream;
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub use chat_template::ChatMessage;
|
||||
|
||||
#[cfg(feature = "tiktoken")]
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
@@ -42,6 +51,17 @@ impl Tokenizer {
|
||||
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Tokenizer> {
|
||||
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
||||
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
Tokenizer(tokenizer)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::traits::Tokenizer as TokenizerTrait;
|
||||
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -9,7 +9,7 @@ pub struct Sequence {
|
||||
tokenizer: Arc<dyn TokenizerTrait>,
|
||||
|
||||
/// The current sequence of token ids
|
||||
token_ids: Vec<u32>,
|
||||
token_ids: Vec<TokenIdType>,
|
||||
|
||||
/// The position in the current sequence the last decoded token completed
|
||||
prefix_offset: usize,
|
||||
@@ -54,7 +54,7 @@ impl Sequence {
|
||||
}
|
||||
|
||||
/// Create a sequence with initial tokens
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||
let len = token_ids.len();
|
||||
Self {
|
||||
tokenizer,
|
||||
@@ -90,7 +90,7 @@ impl Sequence {
|
||||
|
||||
/// Append a single token to the sequence and return newly decoded text
|
||||
/// Based on HuggingFace TGI incremental decoding
|
||||
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
|
||||
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
|
||||
// Store the old read offset before adding the new token
|
||||
let old_read_offset = self.read_offset;
|
||||
|
||||
@@ -145,7 +145,7 @@ impl Sequence {
|
||||
}
|
||||
|
||||
/// Get the current token ids
|
||||
pub fn token_ids(&self) -> &[u32] {
|
||||
pub fn token_ids(&self) -> &[TokenIdType] {
|
||||
&self.token_ids
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::traits;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
@@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StopSequenceConfig {
|
||||
/// Token IDs that trigger a stop
|
||||
pub stop_tokens: HashSet<u32>,
|
||||
pub stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences that trigger a stop
|
||||
pub stop_sequences: Vec<String>,
|
||||
/// Token IDs for visible stops (included in output)
|
||||
pub visible_stop_tokens: HashSet<u32>,
|
||||
pub visible_stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences for visible stops (included in output)
|
||||
pub visible_stop_sequences: Vec<String>,
|
||||
}
|
||||
|
||||
impl StopSequenceConfig {
|
||||
/// Builder pattern - add a stop token
|
||||
pub fn with_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -45,7 +45,7 @@ impl StopSequenceConfig {
|
||||
}
|
||||
|
||||
/// Builder pattern - add a visible stop token
|
||||
pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -64,7 +64,7 @@ pub struct StopSequenceDecoder {
|
||||
/// Buffer for partial matches (the "jail")
|
||||
jail_buffer: String,
|
||||
/// Accumulated tokens
|
||||
token_buffer: Vec<u32>,
|
||||
token_buffer: Vec<TokenIdType>,
|
||||
/// Offset where the prefix text starts (for context)
|
||||
prefix_offset: usize,
|
||||
/// Offset marking the end of previously decoded text
|
||||
@@ -94,7 +94,7 @@ impl StopSequenceDecoder {
|
||||
}
|
||||
|
||||
/// Process a single token
|
||||
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
|
||||
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
if self.stopped {
|
||||
@@ -252,7 +252,10 @@ impl StopSequenceDecoder {
|
||||
}
|
||||
|
||||
/// Process multiple tokens
|
||||
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
pub fn process_tokens(
|
||||
&mut self,
|
||||
token_ids: &[TokenIdType],
|
||||
) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
let mut outputs = Vec::new();
|
||||
for &token_id in token_ids {
|
||||
outputs.push(self.process_token(token_id)?);
|
||||
@@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// src/tokenizer/stream.rs
|
||||
|
||||
use super::traits;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
@@ -18,7 +18,7 @@ pub struct DecodeStream {
|
||||
|
||||
/// A temporary buffer of the necessary token_ids needed
|
||||
/// to produce valid string chunks
|
||||
all_token_ids: Vec<u32>,
|
||||
all_token_ids: Vec<TokenIdType>,
|
||||
|
||||
prefix_offset: usize,
|
||||
read_offset: usize,
|
||||
@@ -27,7 +27,7 @@ pub struct DecodeStream {
|
||||
impl DecodeStream {
|
||||
pub fn new(
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
prompt_token_ids: &[u32],
|
||||
prompt_token_ids: &[TokenIdType],
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
let num_input_tokens = prompt_token_ids.len();
|
||||
@@ -44,7 +44,7 @@ impl DecodeStream {
|
||||
|
||||
/// Step appends a token_id to the internal state and tries to produce a text chunk.
|
||||
/// Returning `None` means the given id is not enough to produce a chunk.
|
||||
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
|
||||
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
|
||||
let start = Instant::now();
|
||||
|
||||
self.all_token_ids.push(id);
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||
|
||||
@@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer {
|
||||
}
|
||||
|
||||
impl Decoder for TiktokenTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
|
||||
// Convert u32 to usize for tiktoken-rs
|
||||
let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
|
||||
|
||||
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
|
||||
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
|
||||
self.tokenizer
|
||||
.decode(tokens)
|
||||
.decode(token_ids.to_vec())
|
||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||
}
|
||||
}
|
||||
@@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, _token: &str) -> Option<u32> {
|
||||
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
|
||||
// Tiktoken doesn't provide direct token-to-id mapping
|
||||
// We'd need to encode the token and check if it produces a single ID
|
||||
None
|
||||
}
|
||||
|
||||
fn id_to_token(&self, _id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
|
||||
// Tiktoken doesn't provide direct id-to-token mapping
|
||||
// We can only decode IDs to text
|
||||
None
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Type alias for token IDs
|
||||
pub type TokenIdType = u32;
|
||||
|
||||
/// Core encoding trait - separate from decoding for modularity
|
||||
pub trait Encoder: Send + Sync {
|
||||
@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
|
||||
|
||||
/// Core decoding trait - can be implemented independently
|
||||
pub trait Decoder: Send + Sync {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Combined tokenizer trait
|
||||
pub trait Tokenizer: Encoder + Decoder {
|
||||
fn vocab_size(&self) -> usize;
|
||||
fn get_special_tokens(&self) -> &SpecialTokens;
|
||||
fn token_to_id(&self, token: &str) -> Option<u32>;
|
||||
fn id_to_token(&self, id: u32) -> Option<String>;
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
|
||||
}
|
||||
|
||||
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
||||
@@ -25,29 +30,45 @@ pub enum Encoding {
|
||||
/// Hugging Face
|
||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||
/// Sentence Piece
|
||||
Sp(Vec<u32>),
|
||||
/// Tiktoken (for GPT models)
|
||||
Tiktoken(Vec<usize>),
|
||||
Sp(Vec<TokenIdType>),
|
||||
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
|
||||
Tiktoken(Vec<TokenIdType>),
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
pub fn token_ids(&self) -> Vec<u32> {
|
||||
/// Returns a reference to token IDs when possible, owned Vec for compatibility
|
||||
pub fn token_ids(&self) -> Vec<TokenIdType> {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
||||
Encoding::Sp(inner) => inner.clone(),
|
||||
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
|
||||
Encoding::Tiktoken(inner) => inner.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_ids_ref(&self) -> &[u32] {
|
||||
/// Returns a reference to token IDs where possible
|
||||
pub fn token_ids_ref(&self) -> &[TokenIdType] {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids(),
|
||||
Encoding::Sp(inner) => inner,
|
||||
Encoding::Tiktoken(_) => {
|
||||
// Tiktoken uses usize, we can't return a reference to u32
|
||||
// This is a limitation - callers should use token_ids() for Tiktoken
|
||||
&[]
|
||||
}
|
||||
Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a hash of the token IDs for caching purposes
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash implementation for Encoding
|
||||
impl Hash for Encoding {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().hash(state),
|
||||
Encoding::Sp(inner) => inner.hash(state),
|
||||
Encoding::Tiktoken(inner) => inner.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
156
sgl-router/tests/test_chat_template.rs
Normal file
156
sgl-router/tests/test_chat_template.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "huggingface")]
|
||||
fn test_chat_message_helpers() {
|
||||
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(system_msg.role, "system");
|
||||
assert_eq!(system_msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
assert_eq!(user_msg.content, "Hello!");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
assert_eq!(assistant_msg.content, "Hi there!");
|
||||
}
|
||||
|
||||
/// Render a Llama-style chat template (system header, per-message headers,
/// generation prompt) and verify all expected markers appear in the output.
#[test]
#[cfg(feature = "huggingface")]
fn test_llama_style_template() {
    let template = r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}

{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}

{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;

    let renderer = ChatTemplateProcessor::new(
        template.to_string(),
        Some("<|begin_of_text|>".to_string()),
        Some("<|end_of_text|>".to_string()),
    );

    let conversation = vec![
        ChatMessage::system("You are a helpful assistant"),
        ChatMessage::user("What is 2+2?"),
    ];

    // add_generation_prompt = true so the trailing assistant header is emitted.
    let rendered = renderer.apply_chat_template(&conversation, true).unwrap();

    let expected_markers = [
        "<|begin_of_text|>",
        "<|start_header_id|>system<|end_header_id|>",
        "You are a helpful assistant",
        "<|start_header_id|>user<|end_header_id|>",
        "What is 2+2?",
        "<|start_header_id|>assistant<|end_header_id|>",
    ];
    for marker in expected_markers {
        assert!(rendered.contains(marker), "output missing marker: {marker}");
    }
}
|
||||
|
||||
/// Render a ChatML-style template and verify each turn is wrapped in
/// `<|im_start|>role ... <|im_end|>` and the generation prompt is appended.
#[test]
#[cfg(feature = "huggingface")]
fn test_chatml_template() {
    let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;

    // No bos/eos tokens for ChatML.
    let renderer = ChatTemplateProcessor::new(template.to_string(), None, None);

    let conversation = vec![
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi there!"),
        ChatMessage::user("How are you?"),
    ];

    let rendered = renderer.apply_chat_template(&conversation, true).unwrap();

    // Each turn must be wrapped in ChatML delimiters.
    for turn in [
        "<|im_start|>user\nHello<|im_end|>",
        "<|im_start|>assistant\nHi there!<|im_end|>",
        "<|im_start|>user\nHow are you?<|im_end|>",
    ] {
        assert!(rendered.contains(turn), "output missing turn: {turn:?}");
    }
    // The generation prompt must be the final content of the rendering.
    assert!(rendered.ends_with("<|im_start|>assistant\n"));
}
|
||||
|
||||
/// The `add_generation_prompt` flag must control whether the trailing
/// `assistant:` prompt is emitted.
#[test]
#[cfg(feature = "huggingface")]
fn test_template_without_generation_prompt() {
    let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;

    let renderer = ChatTemplateProcessor::new(template.to_string(), None, None);
    let conversation = vec![ChatMessage::user("Test")];

    // Flag off: only the message itself is rendered.
    let without_prompt = renderer.apply_chat_template(&conversation, false).unwrap();
    assert_eq!(without_prompt.trim(), "user: Test");

    // Flag on: the assistant prompt is appended.
    let with_prompt = renderer.apply_chat_template(&conversation, true).unwrap();
    assert!(with_prompt.contains("assistant:"));
}
|
||||
|
||||
/// `bos_token` and `eos_token` passed to the processor must be available
/// as variables inside the template.
#[test]
#[cfg(feature = "huggingface")]
fn test_template_with_special_tokens() {
    let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;

    let renderer = ChatTemplateProcessor::new(
        template.to_string(),
        Some("<s>".to_string()),
        Some("</s>".to_string()),
    );

    let conversation = vec![ChatMessage::user("Hello")];
    let rendered = renderer.apply_chat_template(&conversation, false).unwrap();

    // The content is framed by exactly the configured special tokens.
    assert_eq!(rendered, "<s>Hello</s>");
}
|
||||
|
||||
/// An empty message list must render to an empty string (the loop body
/// never executes).
#[test]
#[cfg(feature = "huggingface")]
fn test_empty_messages() {
    let template =
        r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;

    let renderer = ChatTemplateProcessor::new(template.to_string(), None, None);

    let no_messages: Vec<ChatMessage> = vec![];
    let rendered = renderer.apply_chat_template(&no_messages, false).unwrap();
    assert_eq!(rendered, "");
}
|
||||
|
||||
// Integration test with actual tokenizer file loading would go here
|
||||
// but requires a real tokenizer_config.json file
|
||||
}
|
||||
186
sgl-router/tests/test_chat_template_loading.rs
Normal file
186
sgl-router/tests/test_chat_template_loading.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// A chat template loaded from a standalone `.jinja` file must be applied
/// when the tokenizer is constructed via `from_file_with_chat_template`.
#[test]
#[cfg(feature = "huggingface")]
fn test_load_chat_template_from_file() {
    use sglang_router_rs::tokenizer::chat_template::ChatMessage;
    use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

    // Scratch directory for the template and tokenizer fixtures.
    let workdir = TempDir::new().unwrap();

    // Write the custom template to disk.
    let template_path = workdir.path().join("template.jinja");
    let template_content = r#"
{%- for message in messages %}
{{- '<|' + message['role'] + '|>' + message['content'] }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>' }}
{%- endif %}
"#;
    fs::write(&template_path, template_content).unwrap();

    // Minimal-but-valid HF tokenizer definition (tiny BPE vocab).
    let tokenizer_config = r#"{
    "version": "1.0",
    "truncation": null,
    "padding": null,
    "added_tokens": [],
    "normalizer": null,
    "pre_tokenizer": {
        "type": "Whitespace"
    },
    "post_processor": null,
    "decoder": null,
    "model": {
        "type": "BPE",
        "vocab": {
            "hello": 0,
            "world": 1,
            "<s>": 2,
            "</s>": 3
        },
        "merges": []
    }
}"#;
    let tokenizer_path = workdir.path().join("tokenizer.json");
    fs::write(&tokenizer_path, tokenizer_config).unwrap();

    // Construct the tokenizer with the on-disk template.
    let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
        tokenizer_path.to_str().unwrap(),
        Some(template_path.to_str().unwrap()),
    )
    .unwrap();

    let conversation = vec![
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi there"),
    ];
    let rendered = tokenizer.apply_chat_template(&conversation, true).unwrap();

    // Output must follow the custom `<|role|>content` format.
    for fragment in ["<|user|>Hello", "<|assistant|>Hi there"] {
        assert!(rendered.contains(fragment), "missing fragment: {fragment}");
    }
    // The generation prompt is the last thing rendered.
    assert!(rendered.ends_with("<|assistant|>"));
}
|
||||
|
||||
/// A template supplied via `from_file_with_chat_template` must take
/// precedence over a `chat_template` found in `tokenizer_config.json`.
#[test]
#[cfg(feature = "huggingface")]
fn test_override_existing_template() {
    use sglang_router_rs::tokenizer::chat_template::ChatMessage;
    use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

    let workdir = TempDir::new().unwrap();

    // A sidecar config carrying a built-in template that should LOSE.
    let tokenizer_config_path = workdir.path().join("tokenizer_config.json");
    let config_with_template = r#"{
    "chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
}"#;
    fs::write(&tokenizer_config_path, config_with_template).unwrap();

    // Minimal-but-valid HF tokenizer definition (tiny BPE vocab).
    let tokenizer_json = r#"{
    "version": "1.0",
    "truncation": null,
    "padding": null,
    "added_tokens": [],
    "normalizer": null,
    "pre_tokenizer": {
        "type": "Whitespace"
    },
    "post_processor": null,
    "decoder": null,
    "model": {
        "type": "BPE",
        "vocab": {
            "test": 0,
            "<s>": 1,
            "</s>": 2
        },
        "merges": []
    }
}"#;
    let tokenizer_path = workdir.path().join("tokenizer.json");
    fs::write(&tokenizer_path, tokenizer_json).unwrap();

    // The explicit custom template that should WIN.
    let custom_template_path = workdir.path().join("custom.jinja");
    let custom_template =
        r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
    fs::write(&custom_template_path, custom_template).unwrap();

    let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
        tokenizer_path.to_str().unwrap(),
        Some(custom_template_path.to_str().unwrap()),
    )
    .unwrap();

    let conversation = vec![ChatMessage::user("Test")];
    let rendered = tokenizer.apply_chat_template(&conversation, false).unwrap();

    // The custom template format is used; the built-in one is ignored.
    assert!(rendered.starts_with("CUSTOM:"));
    assert!(rendered.contains("[user]: Test"));
    assert!(!rendered.contains("built-in:"));
}
|
||||
|
||||
/// `set_chat_template` must install a template on an already-constructed
/// tokenizer (mirrors assigning `tokenizer.chat_template` in Python).
#[test]
#[cfg(feature = "huggingface")]
fn test_set_chat_template_after_creation() {
    use sglang_router_rs::tokenizer::chat_template::ChatMessage;
    use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

    // Minimal-but-valid HF tokenizer definition (tiny BPE vocab).
    let workdir = TempDir::new().unwrap();
    let tokenizer_json = r#"{
    "version": "1.0",
    "truncation": null,
    "padding": null,
    "added_tokens": [],
    "normalizer": null,
    "pre_tokenizer": {
        "type": "Whitespace"
    },
    "post_processor": null,
    "decoder": null,
    "model": {
        "type": "BPE",
        "vocab": {
            "test": 0,
            "<s>": 1,
            "</s>": 2
        },
        "merges": []
    }
}"#;
    let tokenizer_path = workdir.path().join("tokenizer.json");
    fs::write(&tokenizer_path, tokenizer_json).unwrap();

    // Construct without any chat template, then install one afterwards.
    let mut tokenizer =
        HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
    let new_template =
        "NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
    tokenizer.set_chat_template(new_template.to_string());

    let conversation = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
    let rendered = tokenizer.apply_chat_template(&conversation, false).unwrap();

    // The late-installed template is the one applied.
    assert!(rendered.starts_with("NEW:"));
    for fragment in ["user: Hello;", "assistant: World;"] {
        assert!(rendered.contains(fragment), "missing fragment: {fragment}");
    }
}
|
||||
}
|
||||
Reference in New Issue
Block a user