[router] add tokenizer chat template support (#9370)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,21 +1,36 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
tokenizer: HfTokenizer,
|
||||
special_tokens: SpecialTokens,
|
||||
vocab: HashMap<String, u32>,
|
||||
reverse_vocab: HashMap<u32, String>,
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||
Self::from_file_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Self> {
|
||||
let tokenizer = HfTokenizer::from_file(file_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||
|
||||
@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
// Load chat template
|
||||
#[cfg(feature = "minijinja")]
|
||||
let chat_template = if let Some(template_path) = chat_template_path {
|
||||
// Load from specified .jinja file
|
||||
Self::load_chat_template_from_file(template_path)?
|
||||
} else {
|
||||
// Try to load from tokenizer_config.json
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Look for a chat template in the `tokenizer_config.json` located next to the tokenizer file.
///
/// Returns `None` when the tokenizer path has no parent directory, the config
/// file does not exist, its path is not valid UTF-8, or the loader returns an
/// error or `None`.
#[cfg(feature = "minijinja")]
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
    let config_path = std::path::Path::new(tokenizer_path)
        .parent()?
        .join("tokenizer_config.json");

    // No config file next to the tokenizer: nothing to load.
    if !config_path.exists() {
        return None;
    }

    let config_str = config_path.to_str()?;

    // Errors from the loader are mapped to `None` rather than propagated.
    super::chat_template::load_chat_template_from_config(config_str)
        .ok()
        .flatten()
}
|
||||
|
||||
/// Read a chat template from a `.jinja` file.
///
/// Leading and trailing whitespace is trimmed, and every literal two-character
/// `\n` sequence (backslash followed by `n`) is replaced with a real newline.
/// On success the result is always `Some`.
///
/// # Errors
/// Returns an error if the file cannot be read.
#[cfg(feature = "minijinja")]
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
    let raw = std::fs::read_to_string(template_path)
        .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

    // Clean-up step, similar to the Python implementation.
    let cleaned = raw.trim().replace("\\n", "\n");

    Ok(Some(cleaned))
}
|
||||
|
||||
/// Install `template` as the active chat template, replacing any template
/// that was previously stored on this tokenizer.
#[cfg(feature = "minijinja")]
pub fn set_chat_template(&mut self, template: String) {
    let replacement = Some(template);
    self.chat_template = replacement;
}
|
||||
|
||||
/// Apply chat template if available
|
||||
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
|
||||
// This is a placeholder - actual implementation would handle templates
|
||||
/// Render `messages` into a single prompt string.
///
/// When a chat template is stored on this tokenizer, it is handed to a
/// `ChatTemplateProcessor` together with the tokenizer's BOS and EOS tokens.
/// Otherwise each message is written as a `role: content` line, followed by
/// `assistant: ` when `add_generation_prompt` is `true`.
///
/// # Errors
/// Propagates any error returned by the template processor. The fallback
/// path always returns `Ok`.
#[cfg(feature = "minijinja")]
pub fn apply_chat_template(
    &self,
    messages: &[ChatMessage],
    add_generation_prompt: bool,
) -> Result<String> {
    match self.chat_template.as_ref() {
        Some(template) => {
            let processor = ChatTemplateProcessor::new(
                template.clone(),
                self.special_tokens.bos_token.clone(),
                self.special_tokens.eos_token.clone(),
            );
            processor.apply_chat_template(messages, add_generation_prompt)
        }
        None => {
            // No template available: fall back to simple line-per-message formatting.
            let mut prompt: String = messages
                .iter()
                .map(|message| format!("{}: {}\n", message.role, message.content))
                .collect();
            if add_generation_prompt {
                prompt.push_str("assistant: ");
            }
            Ok(prompt)
        }
    }
}
|
||||
|
||||
/// Apply chat template if available (without minijinja feature)
|
||||
#[cfg(not(feature = "minijinja"))]
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
// Fallback to simple formatting
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
|
||||
}
|
||||
|
||||
impl Decoder for HuggingFaceTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||
let start = Instant::now();
|
||||
|
||||
TokenizerMetrics::record_decode_request("huggingface");
|
||||
@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||
self.reverse_vocab.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a chat message for template application.
///
/// Derives `PartialEq`/`Eq` so messages can be compared directly (for example
/// in assertions) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Role of the speaker; the provided constructors use `"system"`,
    /// `"user"` and `"assistant"`, but any string is accepted.
    pub role: String,
    /// Text content of the message.
    pub content: String,
}

impl ChatMessage {
    /// Create a message with an arbitrary `role` and `content`.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Create a message whose role is `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    /// Create a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    /// Create a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::ChatMessage;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
|
||||
Reference in New Issue
Block a user