[router] add tokenizer chat template support (#9370)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
188
sgl-router/src/tokenizer/chat_template.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Chat template support for tokenizers using Jinja2 templates
|
||||
//!
|
||||
//! This module provides functionality to apply chat templates to messages,
|
||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
#[cfg(feature = "huggingface")]
|
||||
use minijinja::{context, Environment, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
/// Represents a chat message with role and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders chat messages through a Jinja2-style template.
///
/// Holds the template source together with the optional BOS/EOS token
/// strings that are exposed to the template when it is rendered.
#[cfg(feature = "huggingface")]
pub struct ChatTemplateProcessor {
    /// Raw template source text.
    template: String,
    /// Beginning-of-sequence token; rendered as an empty string when `None`.
    bos_token: Option<String>,
    /// End-of-sequence token; rendered as an empty string when `None`.
    eos_token: Option<String>,
}
|
||||
|
||||
#[cfg(feature = "huggingface")]
impl ChatTemplateProcessor {
    /// Builds a processor from template source and optional BOS/EOS tokens.
    pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
        Self {
            template,
            bos_token,
            eos_token,
        }
    }

    /// Renders `messages` through the stored template and returns the text.
    ///
    /// Modelled on HuggingFace's `apply_chat_template`, except that the
    /// formatted string is returned rather than token IDs.
    ///
    /// The template sees four variables: `messages` (a list of objects with
    /// `role` and `content`), `add_generation_prompt`, `bos_token` and
    /// `eos_token` (the latter two are empty strings when unset).
    ///
    /// # Errors
    ///
    /// Returns an error if the template cannot be registered, looked up, or
    /// rendered.
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        // A fresh environment is built on every call and the template is
        // registered under the fixed name "chat".
        let mut env = Environment::new();
        env.add_template("chat", &self.template)
            .map_err(|e| anyhow!("Failed to add template: {}", e))?;
        let tmpl = env
            .get_template("chat")
            .map_err(|e| anyhow!("Failed to get template: {}", e))?;

        // Turn each message into a template value with `role` / `content` keys.
        let mut messages_value: Vec<Value> = Vec::with_capacity(messages.len());
        for msg in messages {
            messages_value.push(context! {
                role => msg.role.clone(),
                content => msg.content.clone()
            });
        }

        let bos_token = self.bos_token.clone().unwrap_or_default();
        let eos_token = self.eos_token.clone().unwrap_or_default();
        let render_ctx = context! {
            messages => messages_value,
            add_generation_prompt => add_generation_prompt,
            bos_token => bos_token,
            eos_token => eos_token
        };

        tmpl.render(render_ctx)
            .map_err(|e| anyhow!("Failed to render template: {}", e))
    }
}
|
||||
|
||||
/// Reads the `chat_template` entry from a tokenizer config JSON file.
///
/// Returns `Ok(Some(template))` when the top-level `chat_template` key holds
/// a JSON string, and `Ok(None)` when the key is missing or holds any other
/// kind of value.
///
/// # Errors
///
/// Returns an error if the file cannot be read or is not valid JSON.
#[cfg(feature = "huggingface")]
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
    let raw = std::fs::read_to_string(config_path)?;
    let parsed: serde_json::Value = serde_json::from_str(&raw)?;

    // Only a string-valued entry counts as a template.
    let template = parsed
        .get("chat_template")
        .and_then(|value| value.as_str())
        .map(str::to_owned);

    Ok(template)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The role-specific constructors stamp the expected role string.
    #[test]
    fn test_chat_message_creation() {
        let system_msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, "system");
        assert_eq!(system_msg.content, "You are a helpful assistant");

        assert_eq!(ChatMessage::user("Hello!").role, "user");
        assert_eq!(ChatMessage::assistant("Hi there!").role, "assistant");
    }

    /// A minimal role/content template renders every message plus the
    /// generation prompt.
    #[cfg(feature = "huggingface")]
    #[test]
    fn test_simple_chat_template() {
        let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;

        let conversation = [
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
        ];
        let rendered = ChatTemplateProcessor::new(template.to_string(), None, None)
            .apply_chat_template(&conversation, true)
            .unwrap();

        for fragment in ["system: You are helpful", "user: Hello", "assistant:"] {
            assert!(rendered.contains(fragment));
        }
    }

    /// BOS/EOS tokens passed to the processor are visible to the template.
    #[cfg(feature = "huggingface")]
    #[test]
    fn test_chat_template_with_tokens() {
        let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;

        let bos = Some("<s>".to_string());
        let eos = Some("</s>".to_string());
        let processor = ChatTemplateProcessor::new(template.to_string(), bos, eos);

        let conversation = [ChatMessage::user("Test")];
        let rendered = processor
            .apply_chat_template(&conversation, false)
            .unwrap();

        assert!(rendered.contains("<s>"));
        assert!(rendered.contains("</s>"));
    }
}
|
||||
@@ -26,6 +26,14 @@ pub enum TokenizerType {
|
||||
/// - json: HuggingFace tokenizer
|
||||
/// - For testing: can return mock tokenizer
|
||||
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
create_tokenizer_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn create_tokenizer_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Special case for testing
|
||||
@@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
|
||||
Some("json") => {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?;
|
||||
|
||||
TokenizerMetrics::record_factory_load("json");
|
||||
TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
/// HuggingFace tokenizer wrapper
|
||||
pub struct HuggingFaceTokenizer {
|
||||
tokenizer: HfTokenizer,
|
||||
special_tokens: SpecialTokens,
|
||||
vocab: HashMap<String, u32>,
|
||||
reverse_vocab: HashMap<u32, String>,
|
||||
vocab: HashMap<String, TokenIdType>,
|
||||
reverse_vocab: HashMap<TokenIdType, String>,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||
Self::from_file_with_chat_template(file_path, None)
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Self> {
|
||||
let tokenizer = HfTokenizer::from_file(file_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||
|
||||
@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
// Load chat template
|
||||
#[cfg(feature = "minijinja")]
|
||||
let chat_template = if let Some(template_path) = chat_template_path {
|
||||
// Load from specified .jinja file
|
||||
Self::load_chat_template_from_file(template_path)?
|
||||
} else {
|
||||
// Try to load from tokenizer_config.json
|
||||
Self::load_chat_template(file_path)
|
||||
};
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
#[cfg(feature = "minijinja")]
|
||||
chat_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Best-effort lookup of a chat template next to the tokenizer file.
///
/// Looks for `tokenizer_config.json` in the directory containing
/// `tokenizer_path` and returns its chat template if one can be loaded.
/// Every failure (no parent directory, file absent, non-UTF-8 path, read or
/// parse error, no template in the config) yields `None`.
#[cfg(feature = "minijinja")]
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
    let config_path = std::path::Path::new(tokenizer_path)
        .parent()?
        .join("tokenizer_config.json");

    if !config_path.exists() {
        return None;
    }

    // Loader errors are intentionally collapsed into "no template".
    super::chat_template::load_chat_template_from_config(config_path.to_str()?)
        .ok()
        .flatten()
}
|
||||
|
||||
/// Reads a chat template from a standalone template (`.jinja`) file.
///
/// The file content is trimmed of surrounding whitespace and every literal
/// two-character sequence `\n` (backslash followed by `n`) is replaced by a
/// real newline. On success the result is always `Some`.
///
/// # Errors
///
/// Returns an error if the file cannot be read.
#[cfg(feature = "minijinja")]
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
    let raw = match std::fs::read_to_string(template_path) {
        Ok(text) => text,
        Err(e) => {
            return Err(Error::msg(format!(
                "Failed to read chat template file: {}",
                e
            )))
        }
    };

    // Clean-up step (the original comment notes it follows the Python implementation).
    let cleaned = raw.trim().replace("\\n", "\n");

    Ok(Some(cleaned))
}
|
||||
|
||||
/// Installs `template` as the chat template, replacing any existing one.
#[cfg(feature = "minijinja")]
pub fn set_chat_template(&mut self, template: String) {
    self.chat_template = template.into();
}
|
||||
|
||||
/// Apply chat template if available
|
||||
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
|
||||
// This is a placeholder - actual implementation would handle templates
|
||||
#[cfg(feature = "minijinja")]
pub fn apply_chat_template(
    &self,
    messages: &[ChatMessage],
    add_generation_prompt: bool,
) -> Result<String> {
    match self.chat_template.as_deref() {
        // A template is configured: render through a processor that also
        // receives this tokenizer's BOS/EOS special tokens.
        Some(template) => ChatTemplateProcessor::new(
            template.to_owned(),
            self.special_tokens.bos_token.clone(),
            self.special_tokens.eos_token.clone(),
        )
        .apply_chat_template(messages, add_generation_prompt),
        // No template: emit one "role: content" line per message, then an
        // optional trailing assistant prompt.
        None => {
            let mut plain: String = messages
                .iter()
                .map(|msg| format!("{}: {}\n", msg.role, msg.content))
                .collect();
            if add_generation_prompt {
                plain.push_str("assistant: ");
            }
            Ok(plain)
        }
    }
}
|
||||
|
||||
/// Apply chat template if available (without minijinja feature)
|
||||
#[cfg(not(feature = "minijinja"))]
|
||||
pub fn apply_chat_template(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
add_generation_prompt: bool,
|
||||
) -> Result<String> {
|
||||
// Fallback to simple formatting
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
if add_generation_prompt {
|
||||
result.push_str("assistant: ");
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
|
||||
}
|
||||
|
||||
impl Decoder for HuggingFaceTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||
let start = Instant::now();
|
||||
|
||||
TokenizerMetrics::record_decode_request("huggingface");
|
||||
@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||
self.reverse_vocab.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a chat message for template application
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
ChatMessage {
|
||||
role: role.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::new("system", content)
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self::new("user", content)
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self::new("assistant", content)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[cfg(feature = "minijinja")]
|
||||
use super::ChatMessage;
|
||||
|
||||
#[cfg(feature = "minijinja")]
|
||||
#[test]
|
||||
fn test_chat_message_creation() {
|
||||
let msg = ChatMessage::system("You are a helpful assistant");
|
||||
|
||||
@@ -10,6 +10,9 @@ pub mod stream;
|
||||
pub mod traits;
|
||||
|
||||
// Feature-gated modules
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub mod chat_template;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub mod huggingface;
|
||||
|
||||
@@ -20,14 +23,20 @@ pub mod tiktoken;
|
||||
mod tests;
|
||||
|
||||
// Re-exports
|
||||
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
||||
pub use factory::{
|
||||
create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
||||
TokenizerType,
|
||||
};
|
||||
pub use sequence::Sequence;
|
||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||
pub use stream::DecodeStream;
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
||||
pub use huggingface::HuggingFaceTokenizer;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub use chat_template::ChatMessage;
|
||||
|
||||
#[cfg(feature = "tiktoken")]
|
||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||
@@ -42,6 +51,17 @@ impl Tokenizer {
|
||||
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path with an optional chat template
|
||||
pub fn from_file_with_chat_template(
|
||||
file_path: &str,
|
||||
chat_template_path: Option<&str>,
|
||||
) -> Result<Tokenizer> {
|
||||
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
||||
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
Tokenizer(tokenizer)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::traits::Tokenizer as TokenizerTrait;
|
||||
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -9,7 +9,7 @@ pub struct Sequence {
|
||||
tokenizer: Arc<dyn TokenizerTrait>,
|
||||
|
||||
/// The current sequence of token ids
|
||||
token_ids: Vec<u32>,
|
||||
token_ids: Vec<TokenIdType>,
|
||||
|
||||
/// The position in the current sequence the last decoded token completed
|
||||
prefix_offset: usize,
|
||||
@@ -54,7 +54,7 @@ impl Sequence {
|
||||
}
|
||||
|
||||
/// Create a sequence with initial tokens
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||
let len = token_ids.len();
|
||||
Self {
|
||||
tokenizer,
|
||||
@@ -90,7 +90,7 @@ impl Sequence {
|
||||
|
||||
/// Append a single token to the sequence and return newly decoded text
|
||||
/// Based on HuggingFace TGI incremental decoding
|
||||
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
|
||||
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
|
||||
// Store the old read offset before adding the new token
|
||||
let old_read_offset = self.read_offset;
|
||||
|
||||
@@ -145,7 +145,7 @@ impl Sequence {
|
||||
}
|
||||
|
||||
/// Get the current token ids
|
||||
pub fn token_ids(&self) -> &[u32] {
|
||||
pub fn token_ids(&self) -> &[TokenIdType] {
|
||||
&self.token_ids
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::traits;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
@@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StopSequenceConfig {
|
||||
/// Token IDs that trigger a stop
|
||||
pub stop_tokens: HashSet<u32>,
|
||||
pub stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences that trigger a stop
|
||||
pub stop_sequences: Vec<String>,
|
||||
/// Token IDs for visible stops (included in output)
|
||||
pub visible_stop_tokens: HashSet<u32>,
|
||||
pub visible_stop_tokens: HashSet<TokenIdType>,
|
||||
/// String sequences for visible stops (included in output)
|
||||
pub visible_stop_sequences: Vec<String>,
|
||||
}
|
||||
|
||||
impl StopSequenceConfig {
|
||||
/// Builder pattern - add a stop token
|
||||
pub fn with_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -45,7 +45,7 @@ impl StopSequenceConfig {
|
||||
}
|
||||
|
||||
/// Builder pattern - add a visible stop token
|
||||
pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -64,7 +64,7 @@ pub struct StopSequenceDecoder {
|
||||
/// Buffer for partial matches (the "jail")
|
||||
jail_buffer: String,
|
||||
/// Accumulated tokens
|
||||
token_buffer: Vec<u32>,
|
||||
token_buffer: Vec<TokenIdType>,
|
||||
/// Offset where the prefix text starts (for context)
|
||||
prefix_offset: usize,
|
||||
/// Offset marking the end of previously decoded text
|
||||
@@ -94,7 +94,7 @@ impl StopSequenceDecoder {
|
||||
}
|
||||
|
||||
/// Process a single token
|
||||
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
|
||||
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
if self.stopped {
|
||||
@@ -252,7 +252,10 @@ impl StopSequenceDecoder {
|
||||
}
|
||||
|
||||
/// Process multiple tokens
|
||||
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
pub fn process_tokens(
|
||||
&mut self,
|
||||
token_ids: &[TokenIdType],
|
||||
) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
let mut outputs = Vec::new();
|
||||
for &token_id in token_ids {
|
||||
outputs.push(self.process_token(token_id)?);
|
||||
@@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
@@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
|
||||
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||||
self.config.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// src/tokenizer/stream.rs
|
||||
|
||||
use super::traits;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use crate::metrics::TokenizerMetrics;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
@@ -18,7 +18,7 @@ pub struct DecodeStream {
|
||||
|
||||
/// A temporary buffer of the necessary token_ids needed
|
||||
/// to produce valid string chunks
|
||||
all_token_ids: Vec<u32>,
|
||||
all_token_ids: Vec<TokenIdType>,
|
||||
|
||||
prefix_offset: usize,
|
||||
read_offset: usize,
|
||||
@@ -27,7 +27,7 @@ pub struct DecodeStream {
|
||||
impl DecodeStream {
|
||||
pub fn new(
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
prompt_token_ids: &[u32],
|
||||
prompt_token_ids: &[TokenIdType],
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
let num_input_tokens = prompt_token_ids.len();
|
||||
@@ -44,7 +44,7 @@ impl DecodeStream {
|
||||
|
||||
/// Step appends a token_id to the internal state and tries to produce a text chunk.
|
||||
/// Returning `None` means the given id is not enough to produce a chunk.
|
||||
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
|
||||
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
|
||||
let start = Instant::now();
|
||||
|
||||
self.all_token_ids.push(id);
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use super::traits::{
|
||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||
};
|
||||
use anyhow::{Error, Result};
|
||||
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||
|
||||
@@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer {
|
||||
}
|
||||
|
||||
impl Decoder for TiktokenTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
|
||||
// Convert u32 to usize for tiktoken-rs
|
||||
let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
|
||||
|
||||
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
|
||||
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
|
||||
self.tokenizer
|
||||
.decode(tokens)
|
||||
.decode(token_ids.to_vec())
|
||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||
}
|
||||
}
|
||||
@@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer {
|
||||
&self.special_tokens
|
||||
}
|
||||
|
||||
fn token_to_id(&self, _token: &str) -> Option<u32> {
|
||||
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
|
||||
// Tiktoken doesn't provide direct token-to-id mapping
|
||||
// We'd need to encode the token and check if it produces a single ID
|
||||
None
|
||||
}
|
||||
|
||||
fn id_to_token(&self, _id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
|
||||
// Tiktoken doesn't provide direct id-to-token mapping
|
||||
// We can only decode IDs to text
|
||||
None
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Type alias for token IDs
|
||||
pub type TokenIdType = u32;
|
||||
|
||||
/// Core encoding trait - separate from decoding for modularity
|
||||
pub trait Encoder: Send + Sync {
|
||||
@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
|
||||
|
||||
/// Core decoding trait - can be implemented independently
|
||||
pub trait Decoder: Send + Sync {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Combined tokenizer trait
|
||||
pub trait Tokenizer: Encoder + Decoder {
|
||||
fn vocab_size(&self) -> usize;
|
||||
fn get_special_tokens(&self) -> &SpecialTokens;
|
||||
fn token_to_id(&self, token: &str) -> Option<u32>;
|
||||
fn id_to_token(&self, id: u32) -> Option<String>;
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
|
||||
}
|
||||
|
||||
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
||||
@@ -25,29 +30,45 @@ pub enum Encoding {
|
||||
/// Hugging Face
|
||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||
/// Sentence Piece
|
||||
Sp(Vec<u32>),
|
||||
/// Tiktoken (for GPT models)
|
||||
Tiktoken(Vec<usize>),
|
||||
Sp(Vec<TokenIdType>),
|
||||
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
|
||||
Tiktoken(Vec<TokenIdType>),
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
pub fn token_ids(&self) -> Vec<u32> {
|
||||
/// Returns an owned copy of the token IDs for every encoding variant
|
||||
pub fn token_ids(&self) -> Vec<TokenIdType> {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
||||
Encoding::Sp(inner) => inner.clone(),
|
||||
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
|
||||
Encoding::Tiktoken(inner) => inner.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_ids_ref(&self) -> &[u32] {
|
||||
/// Returns a borrowed slice of the token IDs (no copy) for every encoding variant
|
||||
pub fn token_ids_ref(&self) -> &[TokenIdType] {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids(),
|
||||
Encoding::Sp(inner) => inner,
|
||||
Encoding::Tiktoken(_) => {
|
||||
// Tiktoken uses usize, we can't return a reference to u32
|
||||
// This is a limitation - callers should use token_ids() for Tiktoken
|
||||
&[]
|
||||
}
|
||||
Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a hash of the token IDs for caching purposes
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash implementation for Encoding
|
||||
impl Hash for Encoding {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().hash(state),
|
||||
Encoding::Sp(inner) => inner.hash(state),
|
||||
Encoding::Tiktoken(inner) => inner.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user