[router] add tokenizer chat template support (#9370)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -5,7 +5,7 @@ edition = "2021"
|
|||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["huggingface", "grpc-client"]
|
default = ["huggingface", "grpc-client"]
|
||||||
huggingface = ["tokenizers"]
|
huggingface = ["tokenizers", "minijinja"]
|
||||||
tiktoken = ["tiktoken-rs"]
|
tiktoken = ["tiktoken-rs"]
|
||||||
grpc-client = []
|
grpc-client = []
|
||||||
grpc-server = []
|
grpc-server = []
|
||||||
@@ -52,7 +52,8 @@ url = "2.5.4"
|
|||||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
tokenizers = { version = "0.21.4", optional = true }
|
tokenizers = { version = "0.21.4", optional = true }
|
||||||
tiktoken-rs = { version = "0.5", optional = true }
|
tiktoken-rs = { version = "0.7.0", optional = true }
|
||||||
|
minijinja = { version = "2.0", optional = true }
|
||||||
|
|
||||||
# gRPC and Protobuf dependencies
|
# gRPC and Protobuf dependencies
|
||||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||||
@@ -71,6 +72,7 @@ criterion = { version = "0.5", features = ["html_reports"] }
|
|||||||
tower = { version = "0.5", features = ["util"] }
|
tower = { version = "0.5", features = ["util"] }
|
||||||
http-body-util = "0.1"
|
http-body-util = "0.1"
|
||||||
portpicker = "0.1"
|
portpicker = "0.1"
|
||||||
|
tempfile = "3.8"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "request_processing"
|
name = "request_processing"
|
||||||
|
|||||||
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
//! Chat template support for tokenizers using Jinja2 templates
|
||||||
|
//!
|
||||||
|
//! This module provides functionality to apply chat templates to messages,
|
||||||
|
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
use minijinja::{context, Environment, Value};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json;
|
||||||
|
|
||||||
|
/// Represents a chat message with role and content
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatMessage {
|
||||||
|
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
|
ChatMessage {
|
||||||
|
role: role.into(),
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn system(content: impl Into<String>) -> Self {
|
||||||
|
Self::new("system", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user(content: impl Into<String>) -> Self {
|
||||||
|
Self::new("user", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant(content: impl Into<String>) -> Self {
|
||||||
|
Self::new("assistant", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Renders a Jinja2 chat template against a list of messages.
///
/// Holds the template source together with the optional BOS/EOS token
/// strings, which are exposed to the template as `bos_token` and `eos_token`.
#[cfg(feature = "huggingface")]
#[derive(Debug, Clone)]
pub struct ChatTemplateProcessor {
    /// Jinja2 source of the chat template.
    template: String,
    /// Beginning-of-sequence token text; rendered as an empty string when absent.
    bos_token: Option<String>,
    /// End-of-sequence token text; rendered as an empty string when absent.
    eos_token: Option<String>,
}

#[cfg(feature = "huggingface")]
impl ChatTemplateProcessor {
    /// Create a new chat template processor from template source and optional special tokens.
    pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
        Self {
            template,
            bos_token,
            eos_token,
        }
    }

    /// Apply the chat template to a list of messages.
    ///
    /// This mimics the behavior of HuggingFace's `apply_chat_template` method
    /// but returns the formatted string instead of token IDs.
    ///
    /// The template sees four variables: `messages` (each with `role` and
    /// `content`), `add_generation_prompt`, `bos_token` and `eos_token`.
    ///
    /// # Errors
    ///
    /// Returns an error when the template source cannot be compiled or when
    /// rendering fails.
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        let mut env = Environment::new();

        // Compile the template under a fixed name so it can be looked up below.
        env.add_template("chat", &self.template)
            .map_err(|e| anyhow!("Failed to add template: {}", e))?;

        let tmpl = env
            .get_template("chat")
            .map_err(|e| anyhow!("Failed to get template: {}", e))?;

        // `ChatMessage` is `Serialize`, so the slice converts to a sequence of
        // `{role, content}` maps without cloning each field by hand.
        let messages_value = Value::from_serialize(messages);

        // Missing special tokens render as empty strings; borrow rather than clone.
        tmpl.render(context! {
            messages => messages_value,
            add_generation_prompt => add_generation_prompt,
            bos_token => self.bos_token.as_deref().unwrap_or_default(),
            eos_token => self.eos_token.as_deref().unwrap_or_default()
        })
        .map_err(|e| anyhow!("Failed to render template: {}", e))
    }
}
|
||||||
|
|
||||||
|
/// Load the chat template stored in a HuggingFace `tokenizer_config.json`.
///
/// The `chat_template` entry is accepted in two shapes:
/// - a plain string, which is returned as-is;
/// - a list of `{ "name": ..., "template": ... }` objects, in which case the
///   template named `"default"` is returned.
///
/// Returns `Ok(None)` when the key is missing, has another shape, or the list
/// contains no `"default"` entry.
///
/// # Errors
///
/// Returns an error (annotated with `config_path`) when the file cannot be
/// read or does not contain valid JSON.
#[cfg(feature = "huggingface")]
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
    use anyhow::Context as _;
    use std::fs;

    let content = fs::read_to_string(config_path)
        .with_context(|| format!("Failed to read tokenizer config: {}", config_path))?;
    let config: serde_json::Value = serde_json::from_str(&content)
        .with_context(|| format!("Failed to parse tokenizer config: {}", config_path))?;

    let template = match config.get("chat_template") {
        Some(serde_json::Value::String(template)) => Some(template.clone()),
        // Multi-template configs: pick the entry HuggingFace uses by default.
        Some(serde_json::Value::Array(named_templates)) => named_templates
            .iter()
            .find(|entry| entry.get("name").and_then(|name| name.as_str()) == Some("default"))
            .and_then(|entry| entry.get("template"))
            .and_then(|template| template.as_str())
            .map(str::to_owned),
        _ => None,
    };

    Ok(template)
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The role helpers must tag messages with the matching role string.
    #[test]
    fn test_chat_message_creation() {
        let system = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system.role, "system");
        assert_eq!(system.content, "You are a helpful assistant");

        assert_eq!(ChatMessage::user("Hello!").role, "user");
        assert_eq!(ChatMessage::assistant("Hi there!").role, "assistant");
    }

    /// A plain role/content template renders every message and, when asked,
    /// the generation prompt.
    #[cfg(feature = "huggingface")]
    #[test]
    fn test_simple_chat_template() {
        let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;

        let conversation = [
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
        ];

        let rendered = ChatTemplateProcessor::new(template.to_string(), None, None)
            .apply_chat_template(&conversation, true)
            .unwrap();

        for expected in ["system: You are helpful", "user: Hello", "assistant:"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The BOS/EOS tokens handed to the processor are visible to the template.
    #[cfg(feature = "huggingface")]
    #[test]
    fn test_chat_template_with_tokens() {
        let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;

        let bos = Some("<s>".to_string());
        let eos = Some("</s>".to_string());
        let processor = ChatTemplateProcessor::new(template.to_string(), bos, eos);

        let conversation = [ChatMessage::user("Test")];
        let rendered = processor
            .apply_chat_template(&conversation, false)
            .unwrap();

        assert!(rendered.contains("<s>"));
        assert!(rendered.contains("</s>"));
    }
}
|
||||||
@@ -26,6 +26,14 @@ pub enum TokenizerType {
|
|||||||
/// - json: HuggingFace tokenizer
|
/// - json: HuggingFace tokenizer
|
||||||
/// - For testing: can return mock tokenizer
|
/// - For testing: can return mock tokenizer
|
||||||
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||||
|
create_tokenizer_with_chat_template(file_path, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a tokenizer from a file path with an optional chat template
|
||||||
|
pub fn create_tokenizer_with_chat_template(
|
||||||
|
file_path: &str,
|
||||||
|
chat_template_path: Option<&str>,
|
||||||
|
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
// Special case for testing
|
// Special case for testing
|
||||||
@@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
|
|||||||
Some("json") => {
|
Some("json") => {
|
||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
{
|
{
|
||||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||||
|
file_path,
|
||||||
|
chat_template_path,
|
||||||
|
)?;
|
||||||
|
|
||||||
TokenizerMetrics::record_factory_load("json");
|
TokenizerMetrics::record_factory_load("json");
|
||||||
TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
|
TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
|
||||||
|
|||||||
@@ -1,21 +1,36 @@
|
|||||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
use super::traits::{
|
||||||
|
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||||
|
};
|
||||||
use crate::metrics::TokenizerMetrics;
|
use crate::metrics::TokenizerMetrics;
|
||||||
use anyhow::{Error, Result};
|
use anyhow::{Error, Result};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||||
|
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
|
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||||
|
|
||||||
/// HuggingFace tokenizer wrapper
|
/// HuggingFace tokenizer wrapper
|
||||||
pub struct HuggingFaceTokenizer {
|
pub struct HuggingFaceTokenizer {
|
||||||
tokenizer: HfTokenizer,
|
tokenizer: HfTokenizer,
|
||||||
special_tokens: SpecialTokens,
|
special_tokens: SpecialTokens,
|
||||||
vocab: HashMap<String, u32>,
|
vocab: HashMap<String, TokenIdType>,
|
||||||
reverse_vocab: HashMap<u32, String>,
|
reverse_vocab: HashMap<TokenIdType, String>,
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
|
chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HuggingFaceTokenizer {
|
impl HuggingFaceTokenizer {
|
||||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||||
|
Self::from_file_with_chat_template(file_path, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
|
||||||
|
pub fn from_file_with_chat_template(
|
||||||
|
file_path: &str,
|
||||||
|
chat_template_path: Option<&str>,
|
||||||
|
) -> Result<Self> {
|
||||||
let tokenizer = HfTokenizer::from_file(file_path)
|
let tokenizer = HfTokenizer::from_file(file_path)
|
||||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||||
|
|
||||||
@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
|
|||||||
|
|
||||||
// Build vocab mappings
|
// Build vocab mappings
|
||||||
let vocab = tokenizer.get_vocab(false);
|
let vocab = tokenizer.get_vocab(false);
|
||||||
let reverse_vocab: HashMap<u32, String> = vocab
|
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(token, &id)| (id, token.clone()))
|
.map(|(token, &id)| (id, token.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Load chat template
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
|
let chat_template = if let Some(template_path) = chat_template_path {
|
||||||
|
// Load from specified .jinja file
|
||||||
|
Self::load_chat_template_from_file(template_path)?
|
||||||
|
} else {
|
||||||
|
// Try to load from tokenizer_config.json
|
||||||
|
Self::load_chat_template(file_path)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(HuggingFaceTokenizer {
|
Ok(HuggingFaceTokenizer {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
special_tokens,
|
special_tokens,
|
||||||
vocab,
|
vocab,
|
||||||
reverse_vocab,
|
reverse_vocab,
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
|
chat_template,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
|
|||||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||||
let vocab = tokenizer.get_vocab(false);
|
let vocab = tokenizer.get_vocab(false);
|
||||||
let reverse_vocab: HashMap<u32, String> = vocab
|
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(token, &id)| (id, token.clone()))
|
.map(|(token, &id)| (id, token.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
|
|||||||
special_tokens,
|
special_tokens,
|
||||||
vocab,
|
vocab,
|
||||||
reverse_vocab,
|
reverse_vocab,
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
|
chat_template: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Best-effort lookup of a chat template in the `tokenizer_config.json`
/// located in the same directory as `tokenizer_path`.
///
/// Returns `None` when the path has no parent directory, the config file does
/// not exist, its path is not valid UTF-8, it cannot be loaded, or it holds
/// no chat template.
#[cfg(feature = "minijinja")]
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
    let config_file = std::path::Path::new(tokenizer_path)
        .parent()?
        .join("tokenizer_config.json");

    if !config_file.exists() {
        return None;
    }

    // Load failures are deliberately treated the same as "no template".
    let config_file = config_file.to_str()?;
    super::chat_template::load_chat_template_from_config(config_file)
        .ok()
        .flatten()
}
|
||||||
|
|
||||||
|
/// Read a chat template from a standalone template file (e.g. `.jinja`).
///
/// The file content is trimmed and every literal backslash-`n` sequence is
/// turned into a real newline. On success the result is always `Some`.
///
/// # Errors
///
/// Returns an error when the file cannot be read.
#[cfg(feature = "minijinja")]
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
    let raw = match std::fs::read_to_string(template_path) {
        Ok(text) => text,
        Err(e) => {
            return Err(Error::msg(format!(
                "Failed to read chat template file: {}",
                e
            )))
        }
    };

    // Same clean-up as the Python implementation: strip the edges, then
    // unescape newlines.
    let cleaned = raw.trim().replace("\\n", "\n");

    Ok(Some(cleaned))
}
|
||||||
|
|
||||||
|
/// Install `template` as the chat template, replacing any template that was
/// set before.
#[cfg(feature = "minijinja")]
pub fn set_chat_template(&mut self, template: String) {
    self.chat_template = template.into();
}
|
||||||
|
|
||||||
/// Apply chat template if available
|
/// Apply chat template if available
|
||||||
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
|
#[cfg(feature = "minijinja")]
|
||||||
// This is a placeholder - actual implementation would handle templates
|
pub fn apply_chat_template(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
add_generation_prompt: bool,
|
||||||
|
) -> Result<String> {
|
||||||
|
if let Some(ref template) = self.chat_template {
|
||||||
|
let processor = ChatTemplateProcessor::new(
|
||||||
|
template.clone(),
|
||||||
|
self.special_tokens.bos_token.clone(),
|
||||||
|
self.special_tokens.eos_token.clone(),
|
||||||
|
);
|
||||||
|
processor.apply_chat_template(messages, add_generation_prompt)
|
||||||
|
} else {
|
||||||
|
// Fallback to simple formatting if no template is available
|
||||||
|
let mut result = String::new();
|
||||||
|
for msg in messages {
|
||||||
|
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||||
|
}
|
||||||
|
if add_generation_prompt {
|
||||||
|
result.push_str("assistant: ");
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply chat template if available (without minijinja feature)
|
||||||
|
#[cfg(not(feature = "minijinja"))]
|
||||||
|
pub fn apply_chat_template(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
add_generation_prompt: bool,
|
||||||
|
) -> Result<String> {
|
||||||
|
// Fallback to simple formatting
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
for msg in messages {
|
for msg in messages {
|
||||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||||
}
|
}
|
||||||
|
if add_generation_prompt {
|
||||||
|
result.push_str("assistant: ");
|
||||||
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Decoder for HuggingFaceTokenizer {
|
impl Decoder for HuggingFaceTokenizer {
|
||||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
TokenizerMetrics::record_decode_request("huggingface");
|
TokenizerMetrics::record_decode_request("huggingface");
|
||||||
@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
|||||||
&self.special_tokens
|
&self.special_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||||
self.vocab.get(token).copied()
|
self.vocab.get(token).copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||||
self.reverse_vocab.get(&id).cloned()
|
self.reverse_vocab.get(&id).cloned()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a chat message for template application
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ChatMessage {
|
|
||||||
pub role: String,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatMessage {
|
|
||||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
|
||||||
ChatMessage {
|
|
||||||
role: role.into(),
|
|
||||||
content: content.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn system(content: impl Into<String>) -> Self {
|
|
||||||
Self::new("system", content)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn user(content: impl Into<String>) -> Self {
|
|
||||||
Self::new("user", content)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn assistant(content: impl Into<String>) -> Self {
|
|
||||||
Self::new("assistant", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
#[cfg(feature = "minijinja")]
|
||||||
|
use super::ChatMessage;
|
||||||
|
|
||||||
|
#[cfg(feature = "minijinja")]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_message_creation() {
|
fn test_chat_message_creation() {
|
||||||
let msg = ChatMessage::system("You are a helpful assistant");
|
let msg = ChatMessage::system("You are a helpful assistant");
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ pub mod stream;
|
|||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
// Feature-gated modules
|
// Feature-gated modules
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
pub mod chat_template;
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
pub mod huggingface;
|
pub mod huggingface;
|
||||||
|
|
||||||
@@ -20,14 +23,20 @@ pub mod tiktoken;
|
|||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
// Re-exports
|
// Re-exports
|
||||||
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
pub use factory::{
|
||||||
|
create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
||||||
|
TokenizerType,
|
||||||
|
};
|
||||||
pub use sequence::Sequence;
|
pub use sequence::Sequence;
|
||||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||||
pub use stream::DecodeStream;
|
pub use stream::DecodeStream;
|
||||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
pub use huggingface::HuggingFaceTokenizer;
|
||||||
|
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
pub use chat_template::ChatMessage;
|
||||||
|
|
||||||
#[cfg(feature = "tiktoken")]
|
#[cfg(feature = "tiktoken")]
|
||||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||||
@@ -42,6 +51,17 @@ impl Tokenizer {
|
|||||||
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a tokenizer from a file path with an optional chat template
|
||||||
|
pub fn from_file_with_chat_template(
|
||||||
|
file_path: &str,
|
||||||
|
chat_template_path: Option<&str>,
|
||||||
|
) -> Result<Tokenizer> {
|
||||||
|
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
|
||||||
|
file_path,
|
||||||
|
chat_template_path,
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
||||||
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||||
Tokenizer(tokenizer)
|
Tokenizer(tokenizer)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use super::traits::Tokenizer as TokenizerTrait;
|
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -9,7 +9,7 @@ pub struct Sequence {
|
|||||||
tokenizer: Arc<dyn TokenizerTrait>,
|
tokenizer: Arc<dyn TokenizerTrait>,
|
||||||
|
|
||||||
/// The current sequence of token ids
|
/// The current sequence of token ids
|
||||||
token_ids: Vec<u32>,
|
token_ids: Vec<TokenIdType>,
|
||||||
|
|
||||||
/// The position in the current sequence the last decoded token completed
|
/// The position in the current sequence the last decoded token completed
|
||||||
prefix_offset: usize,
|
prefix_offset: usize,
|
||||||
@@ -54,7 +54,7 @@ impl Sequence {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a sequence with initial tokens
|
/// Create a sequence with initial tokens
|
||||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
|
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||||
let len = token_ids.len();
|
let len = token_ids.len();
|
||||||
Self {
|
Self {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -90,7 +90,7 @@ impl Sequence {
|
|||||||
|
|
||||||
/// Append a single token to the sequence and return newly decoded text
|
/// Append a single token to the sequence and return newly decoded text
|
||||||
/// Based on HuggingFace TGI incremental decoding
|
/// Based on HuggingFace TGI incremental decoding
|
||||||
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
|
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
|
||||||
// Store the old read offset before adding the new token
|
// Store the old read offset before adding the new token
|
||||||
let old_read_offset = self.read_offset;
|
let old_read_offset = self.read_offset;
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ impl Sequence {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the current token ids
|
/// Get the current token ids
|
||||||
pub fn token_ids(&self) -> &[u32] {
|
pub fn token_ids(&self) -> &[TokenIdType] {
|
||||||
&self.token_ids
|
&self.token_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use super::traits;
|
use super::traits::{self, TokenIdType};
|
||||||
use crate::metrics::TokenizerMetrics;
|
use crate::metrics::TokenizerMetrics;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
@@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput {
|
|||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct StopSequenceConfig {
|
pub struct StopSequenceConfig {
|
||||||
/// Token IDs that trigger a stop
|
/// Token IDs that trigger a stop
|
||||||
pub stop_tokens: HashSet<u32>,
|
pub stop_tokens: HashSet<TokenIdType>,
|
||||||
/// String sequences that trigger a stop
|
/// String sequences that trigger a stop
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
/// Token IDs for visible stops (included in output)
|
/// Token IDs for visible stops (included in output)
|
||||||
pub visible_stop_tokens: HashSet<u32>,
|
pub visible_stop_tokens: HashSet<TokenIdType>,
|
||||||
/// String sequences for visible stops (included in output)
|
/// String sequences for visible stops (included in output)
|
||||||
pub visible_stop_sequences: Vec<String>,
|
pub visible_stop_sequences: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StopSequenceConfig {
|
impl StopSequenceConfig {
|
||||||
/// Builder pattern - add a stop token
|
/// Builder pattern - add a stop token
|
||||||
pub fn with_stop_token(mut self, token_id: u32) -> Self {
|
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||||
self.stop_tokens.insert(token_id);
|
self.stop_tokens.insert(token_id);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@@ -45,7 +45,7 @@ impl StopSequenceConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Builder pattern - add a visible stop token
|
/// Builder pattern - add a visible stop token
|
||||||
pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
|
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||||
self.visible_stop_tokens.insert(token_id);
|
self.visible_stop_tokens.insert(token_id);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@@ -64,7 +64,7 @@ pub struct StopSequenceDecoder {
|
|||||||
/// Buffer for partial matches (the "jail")
|
/// Buffer for partial matches (the "jail")
|
||||||
jail_buffer: String,
|
jail_buffer: String,
|
||||||
/// Accumulated tokens
|
/// Accumulated tokens
|
||||||
token_buffer: Vec<u32>,
|
token_buffer: Vec<TokenIdType>,
|
||||||
/// Offset where the prefix text starts (for context)
|
/// Offset where the prefix text starts (for context)
|
||||||
prefix_offset: usize,
|
prefix_offset: usize,
|
||||||
/// Offset marking the end of previously decoded text
|
/// Offset marking the end of previously decoded text
|
||||||
@@ -94,7 +94,7 @@ impl StopSequenceDecoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Process a single token
|
/// Process a single token
|
||||||
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
|
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
if self.stopped {
|
if self.stopped {
|
||||||
@@ -252,7 +252,10 @@ impl StopSequenceDecoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Process multiple tokens
|
/// Process multiple tokens
|
||||||
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
|
pub fn process_tokens(
|
||||||
|
&mut self,
|
||||||
|
token_ids: &[TokenIdType],
|
||||||
|
) -> Result<Vec<SequenceDecoderOutput>> {
|
||||||
let mut outputs = Vec::new();
|
let mut outputs = Vec::new();
|
||||||
for &token_id in token_ids {
|
for &token_id in token_ids {
|
||||||
outputs.push(self.process_token(token_id)?);
|
outputs.push(self.process_token(token_id)?);
|
||||||
@@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stop_token(mut self, token_id: u32) -> Self {
|
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||||
self.config.stop_tokens.insert(token_id);
|
self.config.stop_tokens.insert(token_id);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
|
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||||
self.config.visible_stop_tokens.insert(token_id);
|
self.config.visible_stop_tokens.insert(token_id);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// src/tokenizer/stream.rs
|
// src/tokenizer/stream.rs
|
||||||
|
|
||||||
use super::traits;
|
use super::traits::{self, TokenIdType};
|
||||||
use crate::metrics::TokenizerMetrics;
|
use crate::metrics::TokenizerMetrics;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -18,7 +18,7 @@ pub struct DecodeStream {
|
|||||||
|
|
||||||
/// A temporary buffer of the necessary token_ids needed
|
/// A temporary buffer of the necessary token_ids needed
|
||||||
/// to produce valid string chunks
|
/// to produce valid string chunks
|
||||||
all_token_ids: Vec<u32>,
|
all_token_ids: Vec<TokenIdType>,
|
||||||
|
|
||||||
prefix_offset: usize,
|
prefix_offset: usize,
|
||||||
read_offset: usize,
|
read_offset: usize,
|
||||||
@@ -27,7 +27,7 @@ pub struct DecodeStream {
|
|||||||
impl DecodeStream {
|
impl DecodeStream {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||||
prompt_token_ids: &[u32],
|
prompt_token_ids: &[TokenIdType],
|
||||||
skip_special_tokens: bool,
|
skip_special_tokens: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let num_input_tokens = prompt_token_ids.len();
|
let num_input_tokens = prompt_token_ids.len();
|
||||||
@@ -44,7 +44,7 @@ impl DecodeStream {
|
|||||||
|
|
||||||
/// Step appends a token_id to the internal state and tries to produce a text chunk.
|
/// Step appends a token_id to the internal state and tries to produce a text chunk.
|
||||||
/// Returning `None` means the given id is not enough to produce a chunk.
|
/// Returning `None` means the given id is not enough to produce a chunk.
|
||||||
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
|
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
self.all_token_ids.push(id);
|
self.all_token_ids.push(id);
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
use super::traits::{
|
||||||
|
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||||
|
};
|
||||||
use anyhow::{Error, Result};
|
use anyhow::{Error, Result};
|
||||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||||
|
|
||||||
@@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Decoder for TiktokenTokenizer {
|
impl Decoder for TiktokenTokenizer {
|
||||||
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
|
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
|
||||||
// Convert u32 to usize for tiktoken-rs
|
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
|
||||||
let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
|
|
||||||
|
|
||||||
self.tokenizer
|
self.tokenizer
|
||||||
.decode(tokens)
|
.decode(token_ids.to_vec())
|
||||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer {
|
|||||||
&self.special_tokens
|
&self.special_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
fn token_to_id(&self, _token: &str) -> Option<u32> {
|
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
|
||||||
// Tiktoken doesn't provide direct token-to-id mapping
|
// Tiktoken doesn't provide direct token-to-id mapping
|
||||||
// We'd need to encode the token and check if it produces a single ID
|
// We'd need to encode the token and check if it produces a single ID
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn id_to_token(&self, _id: u32) -> Option<String> {
|
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
|
||||||
// Tiktoken doesn't provide direct id-to-token mapping
|
// Tiktoken doesn't provide direct id-to-token mapping
|
||||||
// We can only decode IDs to text
|
// We can only decode IDs to text
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
|
||||||
|
/// Type alias for token IDs
|
||||||
|
pub type TokenIdType = u32;
|
||||||
|
|
||||||
/// Core encoding trait - separate from decoding for modularity
|
/// Core encoding trait - separate from decoding for modularity
|
||||||
pub trait Encoder: Send + Sync {
|
pub trait Encoder: Send + Sync {
|
||||||
@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
|
|||||||
|
|
||||||
/// Core decoding trait - can be implemented independently
|
/// Core decoding trait - can be implemented independently
|
||||||
pub trait Decoder: Send + Sync {
|
pub trait Decoder: Send + Sync {
|
||||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
|
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Combined tokenizer trait
|
/// Combined tokenizer trait
|
||||||
pub trait Tokenizer: Encoder + Decoder {
|
pub trait Tokenizer: Encoder + Decoder {
|
||||||
fn vocab_size(&self) -> usize;
|
fn vocab_size(&self) -> usize;
|
||||||
fn get_special_tokens(&self) -> &SpecialTokens;
|
fn get_special_tokens(&self) -> &SpecialTokens;
|
||||||
fn token_to_id(&self, token: &str) -> Option<u32>;
|
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
|
||||||
fn id_to_token(&self, id: u32) -> Option<String>;
|
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
||||||
@@ -25,29 +30,45 @@ pub enum Encoding {
|
|||||||
/// Hugging Face
|
/// Hugging Face
|
||||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||||
/// Sentence Piece
|
/// Sentence Piece
|
||||||
Sp(Vec<u32>),
|
Sp(Vec<TokenIdType>),
|
||||||
/// Tiktoken (for GPT models)
|
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
|
||||||
Tiktoken(Vec<usize>),
|
Tiktoken(Vec<TokenIdType>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encoding {
|
impl Encoding {
|
||||||
pub fn token_ids(&self) -> Vec<u32> {
|
/// Returns a reference to token IDs when possible, owned Vec for compatibility
|
||||||
|
pub fn token_ids(&self) -> Vec<TokenIdType> {
|
||||||
match self {
|
match self {
|
||||||
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
||||||
Encoding::Sp(inner) => inner.clone(),
|
Encoding::Sp(inner) => inner.clone(),
|
||||||
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
|
Encoding::Tiktoken(inner) => inner.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn token_ids_ref(&self) -> &[u32] {
|
/// Returns a reference to token IDs where possible
|
||||||
|
pub fn token_ids_ref(&self) -> &[TokenIdType] {
|
||||||
match self {
|
match self {
|
||||||
Encoding::Hf(inner) => inner.get_ids(),
|
Encoding::Hf(inner) => inner.get_ids(),
|
||||||
Encoding::Sp(inner) => inner,
|
Encoding::Sp(inner) => inner,
|
||||||
Encoding::Tiktoken(_) => {
|
Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
|
||||||
// Tiktoken uses usize, we can't return a reference to u32
|
}
|
||||||
// This is a limitation - callers should use token_ids() for Tiktoken
|
}
|
||||||
&[]
|
|
||||||
}
|
/// Get a hash of the token IDs for caching purposes
|
||||||
|
pub fn get_hash(&self) -> u64 {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
self.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hash implementation for Encoding
|
||||||
|
impl Hash for Encoding {
|
||||||
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
|
match self {
|
||||||
|
Encoding::Hf(inner) => inner.get_ids().hash(state),
|
||||||
|
Encoding::Sp(inner) => inner.hash(state),
|
||||||
|
Encoding::Tiktoken(inner) => inner.hash(state),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
156
sgl-router/tests/test_chat_template.rs
Normal file
156
sgl-router/tests/test_chat_template.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_chat_message_helpers() {
|
||||||
|
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||||
|
assert_eq!(system_msg.role, "system");
|
||||||
|
assert_eq!(system_msg.content, "You are a helpful assistant");
|
||||||
|
|
||||||
|
let user_msg = ChatMessage::user("Hello!");
|
||||||
|
assert_eq!(user_msg.role, "user");
|
||||||
|
assert_eq!(user_msg.content, "Hello!");
|
||||||
|
|
||||||
|
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||||
|
assert_eq!(assistant_msg.role, "assistant");
|
||||||
|
assert_eq!(assistant_msg.content, "Hi there!");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_llama_style_template() {
|
||||||
|
// Test a Llama-style chat template
|
||||||
|
let template = r#"
|
||||||
|
{%- if messages[0]['role'] == 'system' -%}
|
||||||
|
{%- set system_message = messages[0]['content'] -%}
|
||||||
|
{%- set messages = messages[1:] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set system_message = '' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{{- bos_token }}
|
||||||
|
{%- if system_message %}
|
||||||
|
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let processor = ChatTemplateProcessor::new(
|
||||||
|
template.to_string(),
|
||||||
|
Some("<|begin_of_text|>".to_string()),
|
||||||
|
Some("<|end_of_text|>".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("You are a helpful assistant"),
|
||||||
|
ChatMessage::user("What is 2+2?"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||||
|
|
||||||
|
// Check that the result contains expected markers
|
||||||
|
assert!(result.contains("<|begin_of_text|>"));
|
||||||
|
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||||
|
assert!(result.contains("You are a helpful assistant"));
|
||||||
|
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||||
|
assert!(result.contains("What is 2+2?"));
|
||||||
|
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_chatml_template() {
|
||||||
|
// Test a ChatML-style template
|
||||||
|
let template = r#"
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::user("Hello"),
|
||||||
|
ChatMessage::assistant("Hi there!"),
|
||||||
|
ChatMessage::user("How are you?"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||||
|
|
||||||
|
// Check ChatML format
|
||||||
|
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||||
|
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||||
|
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||||
|
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_template_without_generation_prompt() {
|
||||||
|
let template = r#"
|
||||||
|
{%- for message in messages -%}
|
||||||
|
{{ message.role }}: {{ message.content }}
|
||||||
|
{% endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
assistant:
|
||||||
|
{%- endif -%}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::user("Test")];
|
||||||
|
|
||||||
|
// Test without generation prompt
|
||||||
|
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||||
|
assert_eq!(result.trim(), "user: Test");
|
||||||
|
|
||||||
|
// Test with generation prompt
|
||||||
|
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
|
||||||
|
assert!(result_with_prompt.contains("assistant:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_template_with_special_tokens() {
|
||||||
|
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
||||||
|
|
||||||
|
let processor = ChatTemplateProcessor::new(
|
||||||
|
template.to_string(),
|
||||||
|
Some("<s>".to_string()),
|
||||||
|
Some("</s>".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::user("Hello")];
|
||||||
|
|
||||||
|
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||||
|
assert_eq!(result, "<s>Hello</s>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_empty_messages() {
|
||||||
|
let template =
|
||||||
|
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||||
|
|
||||||
|
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||||
|
|
||||||
|
let messages = vec![];
|
||||||
|
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||||
|
assert_eq!(result, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integration test with actual tokenizer file loading would go here
|
||||||
|
// but requires a real tokenizer_config.json file
|
||||||
|
}
|
||||||
186
sgl-router/tests/test_chat_template_loading.rs
Normal file
186
sgl-router/tests/test_chat_template_loading.rs
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::fs;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_load_chat_template_from_file() {
|
||||||
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
|
|
||||||
|
// Create temporary directory
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let template_path = temp_dir.path().join("template.jinja");
|
||||||
|
|
||||||
|
// Write a test template
|
||||||
|
let template_content = r#"
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- '<|' + message['role'] + '|>' + message['content'] }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|assistant|>' }}
|
||||||
|
{%- endif %}
|
||||||
|
"#;
|
||||||
|
fs::write(&template_path, template_content).unwrap();
|
||||||
|
|
||||||
|
// Create a mock tokenizer config
|
||||||
|
let tokenizer_config = r#"{
|
||||||
|
"version": "1.0",
|
||||||
|
"truncation": null,
|
||||||
|
"padding": null,
|
||||||
|
"added_tokens": [],
|
||||||
|
"normalizer": null,
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Whitespace"
|
||||||
|
},
|
||||||
|
"post_processor": null,
|
||||||
|
"decoder": null,
|
||||||
|
"model": {
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": {
|
||||||
|
"hello": 0,
|
||||||
|
"world": 1,
|
||||||
|
"<s>": 2,
|
||||||
|
"</s>": 3
|
||||||
|
},
|
||||||
|
"merges": []
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||||
|
fs::write(&tokenizer_path, tokenizer_config).unwrap();
|
||||||
|
|
||||||
|
// Load tokenizer with custom chat template
|
||||||
|
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||||
|
tokenizer_path.to_str().unwrap(),
|
||||||
|
Some(template_path.to_str().unwrap()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Test that the custom template is used
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::user("Hello"),
|
||||||
|
ChatMessage::assistant("Hi there"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
|
||||||
|
|
||||||
|
// Verify the custom template format
|
||||||
|
assert!(result.contains("<|user|>Hello"));
|
||||||
|
assert!(result.contains("<|assistant|>Hi there"));
|
||||||
|
assert!(result.ends_with("<|assistant|>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_override_existing_template() {
|
||||||
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
|
|
||||||
|
// Create temporary directory
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create tokenizer config with a built-in template
|
||||||
|
let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json");
|
||||||
|
let config_with_template = r#"{
|
||||||
|
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
|
||||||
|
}"#;
|
||||||
|
fs::write(&tokenizer_config_path, config_with_template).unwrap();
|
||||||
|
|
||||||
|
// Create the actual tokenizer file
|
||||||
|
let tokenizer_json = r#"{
|
||||||
|
"version": "1.0",
|
||||||
|
"truncation": null,
|
||||||
|
"padding": null,
|
||||||
|
"added_tokens": [],
|
||||||
|
"normalizer": null,
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Whitespace"
|
||||||
|
},
|
||||||
|
"post_processor": null,
|
||||||
|
"decoder": null,
|
||||||
|
"model": {
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": {
|
||||||
|
"test": 0,
|
||||||
|
"<s>": 1,
|
||||||
|
"</s>": 2
|
||||||
|
},
|
||||||
|
"merges": []
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||||
|
fs::write(&tokenizer_path, tokenizer_json).unwrap();
|
||||||
|
|
||||||
|
// Create custom template that should override
|
||||||
|
let custom_template_path = temp_dir.path().join("custom.jinja");
|
||||||
|
let custom_template =
|
||||||
|
r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
|
||||||
|
fs::write(&custom_template_path, custom_template).unwrap();
|
||||||
|
|
||||||
|
// Load with custom template - should override the built-in one
|
||||||
|
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||||
|
tokenizer_path.to_str().unwrap(),
|
||||||
|
Some(custom_template_path.to_str().unwrap()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::user("Test")];
|
||||||
|
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||||
|
|
||||||
|
// Should use CUSTOM template, not built-in
|
||||||
|
assert!(result.starts_with("CUSTOM:"));
|
||||||
|
assert!(result.contains("[user]: Test"));
|
||||||
|
assert!(!result.contains("built-in:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "huggingface")]
|
||||||
|
fn test_set_chat_template_after_creation() {
|
||||||
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
|
|
||||||
|
// Create temporary directory and tokenizer file
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let tokenizer_json = r#"{
|
||||||
|
"version": "1.0",
|
||||||
|
"truncation": null,
|
||||||
|
"padding": null,
|
||||||
|
"added_tokens": [],
|
||||||
|
"normalizer": null,
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Whitespace"
|
||||||
|
},
|
||||||
|
"post_processor": null,
|
||||||
|
"decoder": null,
|
||||||
|
"model": {
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": {
|
||||||
|
"test": 0,
|
||||||
|
"<s>": 1,
|
||||||
|
"</s>": 2
|
||||||
|
},
|
||||||
|
"merges": []
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||||
|
fs::write(&tokenizer_path, tokenizer_json).unwrap();
|
||||||
|
|
||||||
|
// Load tokenizer without custom template
|
||||||
|
let mut tokenizer =
|
||||||
|
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
|
||||||
|
|
||||||
|
// Set a template after creation (mimics Python's behavior)
|
||||||
|
let new_template =
|
||||||
|
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
|
||||||
|
tokenizer.set_chat_template(new_template.to_string());
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
|
||||||
|
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||||
|
|
||||||
|
assert!(result.starts_with("NEW:"));
|
||||||
|
assert!(result.contains("user: Hello;"));
|
||||||
|
assert!(result.contains("assistant: World;"));
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user