[router] add tokenizer chat template support (#9370)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Type alias for token IDs
|
||||
pub type TokenIdType = u32;
|
||||
|
||||
/// Core encoding trait - separate from decoding for modularity
|
||||
pub trait Encoder: Send + Sync {
|
||||
@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
|
||||
|
||||
/// Core decoding trait - can be implemented independently
|
||||
pub trait Decoder: Send + Sync {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Combined tokenizer trait
|
||||
pub trait Tokenizer: Encoder + Decoder {
|
||||
fn vocab_size(&self) -> usize;
|
||||
fn get_special_tokens(&self) -> &SpecialTokens;
|
||||
fn token_to_id(&self, token: &str) -> Option<u32>;
|
||||
fn id_to_token(&self, id: u32) -> Option<String>;
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
|
||||
}
|
||||
|
||||
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
|
||||
@@ -25,29 +30,45 @@ pub enum Encoding {
|
||||
/// Hugging Face
|
||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||
/// Sentence Piece
|
||||
Sp(Vec<u32>),
|
||||
/// Tiktoken (for GPT models)
|
||||
Tiktoken(Vec<usize>),
|
||||
Sp(Vec<TokenIdType>),
|
||||
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
|
||||
Tiktoken(Vec<TokenIdType>),
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
pub fn token_ids(&self) -> Vec<u32> {
|
||||
/// Returns an owned copy of the token IDs; use `token_ids_ref` to borrow them without copying
|
||||
pub fn token_ids(&self) -> Vec<TokenIdType> {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
||||
Encoding::Sp(inner) => inner.clone(),
|
||||
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
|
||||
Encoding::Tiktoken(inner) => inner.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_ids_ref(&self) -> &[u32] {
|
||||
/// Returns a borrowed slice of the token IDs without copying
|
||||
pub fn token_ids_ref(&self) -> &[TokenIdType] {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids(),
|
||||
Encoding::Sp(inner) => inner,
|
||||
Encoding::Tiktoken(_) => {
|
||||
// Tiktoken uses usize, we can't return a reference to u32
|
||||
// This is a limitation - callers should use token_ids() for Tiktoken
|
||||
&[]
|
||||
}
|
||||
Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a hash of the token IDs for caching purposes
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash implementation for Encoding
|
||||
impl Hash for Encoding {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Encoding::Hf(inner) => inner.get_ids().hash(state),
|
||||
Encoding::Sp(inner) => inner.hash(state),
|
||||
Encoding::Tiktoken(inner) => inner.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user