[router] Add Configurable L0 and L1 Tokenizer Caching (#11688)
This commit is contained in:
106
sgl-router/src/tokenizer/cache/fingerprint.rs
vendored
Normal file
106
sgl-router/src/tokenizer/cache/fingerprint.rs
vendored
Normal file
@@ -0,0 +1,106 @@
|
||||
//! Tokenizer Fingerprinting for Cache Invalidation
|
||||
//!
|
||||
//! Creates a unique fingerprint of a tokenizer's configuration to detect
|
||||
//! when the tokenizer has changed and the cache needs to be cleared.
|
||||
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
|
||||
use super::super::traits::Tokenizer;
|
||||
|
||||
/// A fingerprint of a tokenizer's configuration.
///
/// Two tokenizers with identical vocabularies and special-token setups
/// produce equal fingerprints; a mismatch signals that cached encodings
/// must be discarded.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TokenizerFingerprint {
    /// Number of entries in the vocabulary.
    pub vocab_size: usize,
    /// Hash over an evenly spaced sample of vocabulary tokens (sampled for speed).
    pub vocab_hash: u64,
    /// Hash over the configured special tokens.
    pub special_tokens_hash: u64,
}
|
||||
|
||||
impl TokenizerFingerprint {
|
||||
/// Create a fingerprint from a tokenizer
|
||||
pub fn from_tokenizer(tokenizer: &dyn Tokenizer) -> Self {
|
||||
let vocab_size = tokenizer.vocab_size();
|
||||
let vocab_hash = Self::compute_vocab_hash(tokenizer);
|
||||
let special_tokens_hash = Self::compute_special_tokens_hash(tokenizer);
|
||||
|
||||
Self {
|
||||
vocab_size,
|
||||
vocab_hash,
|
||||
special_tokens_hash,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a hash of the vocabulary by sampling tokens
|
||||
fn compute_vocab_hash(tokenizer: &dyn Tokenizer) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let vocab_size = tokenizer.vocab_size();
|
||||
|
||||
// Sample up to 1000 tokens for speed
|
||||
let sample_size = vocab_size.min(1000);
|
||||
let step = if sample_size > 0 {
|
||||
vocab_size / sample_size
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
for i in (0..vocab_size).step_by(step.max(1)) {
|
||||
if let Some(token) = tokenizer.id_to_token(i as u32) {
|
||||
token.hash(&mut hasher);
|
||||
}
|
||||
}
|
||||
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Compute a hash of special tokens
|
||||
fn compute_special_tokens_hash(tokenizer: &dyn Tokenizer) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let special_tokens = tokenizer.get_special_tokens();
|
||||
|
||||
special_tokens.bos_token.hash(&mut hasher);
|
||||
special_tokens.eos_token.hash(&mut hasher);
|
||||
special_tokens.unk_token.hash(&mut hasher);
|
||||
special_tokens.sep_token.hash(&mut hasher);
|
||||
special_tokens.pad_token.hash(&mut hasher);
|
||||
special_tokens.cls_token.hash(&mut hasher);
|
||||
special_tokens.mask_token.hash(&mut hasher);
|
||||
special_tokens.additional_special_tokens.hash(&mut hasher);
|
||||
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_fingerprint_equality() {
        // Two identically configured tokenizers must fingerprint the same.
        let a = MockTokenizer::new();
        let b = MockTokenizer::new();
        assert_eq!(
            TokenizerFingerprint::from_tokenizer(&a),
            TokenizerFingerprint::from_tokenizer(&b)
        );
    }

    #[test]
    fn test_fingerprint_consistency() {
        // Fingerprinting the same tokenizer twice is deterministic.
        let tok = MockTokenizer::new();
        let first = TokenizerFingerprint::from_tokenizer(&tok);
        let second = TokenizerFingerprint::from_tokenizer(&tok);
        assert_eq!(first, second);
        assert_eq!(first.vocab_size, tok.vocab_size());
    }
}
|
||||
220
sgl-router/src/tokenizer/cache/l0.rs
vendored
Normal file
220
sgl-router/src/tokenizer/cache/l0.rs
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
//! L0 Cache: Whole-string exact match cache
|
||||
//!
|
||||
//! This is the simplest and most effective cache layer.
|
||||
//! Key: input string → Value: full encoding result
|
||||
//!
|
||||
//! Expected hit rate: 60-90% for workloads with repeated system prompts
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
use super::super::traits::Encoding;
|
||||
|
||||
/// L0 cache implementation using DashMap for lock-free reads
|
||||
pub struct L0Cache {
|
||||
/// The cache map: input string → encoding
|
||||
map: Arc<DashMap<String, Encoding>>,
|
||||
/// Maximum number of entries before eviction
|
||||
max_entries: usize,
|
||||
/// Cache hit counter
|
||||
hits: AtomicU64,
|
||||
/// Cache miss counter
|
||||
misses: AtomicU64,
|
||||
}
|
||||
|
||||
impl L0Cache {
|
||||
/// Create a new L0 cache with the specified capacity
|
||||
pub fn new(max_entries: usize) -> Self {
|
||||
Self {
|
||||
map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
|
||||
max_entries,
|
||||
hits: AtomicU64::new(0),
|
||||
misses: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an encoding from the cache
|
||||
pub fn get(&self, key: &str) -> Option<Encoding> {
|
||||
match self.map.get(key) {
|
||||
Some(entry) => {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
Some(entry.value().clone())
|
||||
}
|
||||
None => {
|
||||
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert an encoding into the cache
|
||||
pub fn insert(&self, key: String, value: Encoding) {
|
||||
// Simple eviction: if we're at capacity, remove a random entry
|
||||
// DashMap doesn't support LRU directly, so we use a simple strategy
|
||||
if self.map.len() >= self.max_entries {
|
||||
// Get the key to remove in a separate scope to ensure iterator is dropped
|
||||
let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) }; // Iterator fully dropped here, all locks released
|
||||
|
||||
// Now remove it
|
||||
if let Some(k) = key_to_remove {
|
||||
self.map.remove(&k);
|
||||
}
|
||||
}
|
||||
|
||||
self.map.insert(key, value);
|
||||
}
|
||||
|
||||
/// Get the current number of entries in the cache
|
||||
pub fn len(&self) -> usize {
|
||||
self.map.len()
|
||||
}
|
||||
|
||||
/// Check if the cache is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn stats(&self) -> CacheStats {
|
||||
let hits = self.hits.load(Ordering::Relaxed);
|
||||
let misses = self.misses.load(Ordering::Relaxed);
|
||||
let total_requests = hits + misses;
|
||||
|
||||
CacheStats {
|
||||
hits,
|
||||
misses,
|
||||
entries: self.len(),
|
||||
hit_rate: if total_requests > 0 {
|
||||
hits as f64 / total_requests as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the cache
|
||||
pub fn clear(&self) {
|
||||
self.map.clear();
|
||||
self.hits.store(0, Ordering::Relaxed);
|
||||
self.misses.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Estimate memory usage in bytes
|
||||
pub fn memory_usage(&self) -> usize {
|
||||
// Rough estimate:
|
||||
// - Each entry: key (string) + value (encoding ~250 tokens * 4 bytes) + overhead
|
||||
// - Average: ~2.2KB per entry
|
||||
self.len() * 2200
|
||||
}
|
||||
}
|
||||
|
||||
/// Point-in-time snapshot of L0 cache counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that missed.
    pub misses: u64,
    /// Current number of cached entries.
    pub entries: usize,
    /// hits / (hits + misses); 0.0 when there have been no requests.
    pub hit_rate: f64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::traits::Encoding;

    /// Build a minimal encoding from raw token ids.
    fn mock_encoding(tokens: Vec<u32>) -> Encoding {
        Encoding::Sp(tokens)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        // Lookup before insert must miss.
        assert!(cache.get("hello").is_none());

        cache.insert("hello".to_string(), mock_encoding(vec![1, 2, 3]));

        // Now it must hit with the exact tokens.
        let hit = cache.get("hello");
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        cache.insert("a".to_string(), mock_encoding(vec![1]));
        cache.insert("b".to_string(), mock_encoding(vec![2]));

        // A third insert at capacity must evict one entry.
        cache.insert("c".to_string(), mock_encoding(vec![3]));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));

        // One miss, then one hit.
        let _ = cache.get("missing");
        let _ = cache.get("test");

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);

        cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let cache = Arc::new(L0Cache::new(1000));

        // Ten threads, each inserting and reading back its own key.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let shared = cache.clone();
                thread::spawn(move || {
                    let key = format!("key_{}", i);
                    shared.insert(key.clone(), mock_encoding(vec![i as u32]));
                    assert!(shared.get(&key).is_some());
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(cache.len(), 10);
    }
}
|
||||
507
sgl-router/src/tokenizer/cache/l1.rs
vendored
Normal file
507
sgl-router/src/tokenizer/cache/l1.rs
vendored
Normal file
@@ -0,0 +1,507 @@
|
||||
//! L1 Cache: Special-token boundary prefix cache
|
||||
//!
|
||||
//! Caches tokenization results at ALL special token boundaries.
|
||||
//! Special tokens (like `<|im_start|>`, `<|im_end|>`) are atomic in BPE tokenizers (special: true, normalized: false),
|
||||
//! making them the ONLY safe split points that guarantee correctness.
|
||||
//!
|
||||
//! **Design**: Cache at every special token boundary (not at fixed granularity intervals)
|
||||
//! - Simple: No granularity parameter, no search windows
|
||||
//! - Efficient: Fewer cache entries (10 instead of 64 for typical 8KB prompt)
|
||||
//! - Natural: Aligns with actual chat template structure
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! Template: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\n{query}<|im_end|>"
|
||||
//!
|
||||
//! Request 1: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>"
|
||||
//! Request 2: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>"
|
||||
//!
|
||||
//! Cache points: After each "<|im_end|>" (atomic tokens, guaranteed safe)
|
||||
//! Result: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
|
||||
|
||||
use std::{
|
||||
mem::size_of,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use blake3;
|
||||
use dashmap::DashMap;
|
||||
|
||||
use super::super::traits::TokenIdType;
|
||||
|
||||
/// Cache key type: a raw 32-byte Blake3 digest of a prefix.
type Blake3Hash = [u8; 32];

/// Shard count for the L1 map; spreads contention across independent maps.
const NUM_SHARDS: usize = 16;
|
||||
|
||||
/// Locate every byte position immediately after a special-token occurrence.
///
/// **Only special tokens are used** — they are atomic in BPE tokenizers
/// (`special: true`, `normalized: false`), which makes them the only split
/// points guaranteeing:
/// `tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)`.
///
/// No whitespace/punctuation fallback is attempted: not caching is safer
/// than caching at a boundary that could change tokenization.
///
/// Common special tokens:
/// - ChatML: `<|im_start|>`, `<|im_end|>`
/// - Llama 3: `<|begin_of_text|>`, `<|end_of_text|>`, `<|eot_id|>`
/// - GPT: `<|endoftext|>`
/// - Custom: `<|reserved_special_token_N|>`
///
/// Returns sorted, deduplicated offsets that leave a non-empty suffix.
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries: Vec<usize> = Vec::new();

    // Collect the end offset of every occurrence of every special token.
    for token in special_tokens.iter().copied() {
        let mut search_from = 0;
        while let Some(rel) = text[search_from..].find(token) {
            let end = search_from + rel + token.len();
            // A boundary at the very end of the text is useless: no suffix
            // would remain to tokenize.
            if end < text.len() {
                boundaries.push(end);
            }
            search_from = end;
        }
    }

    // Different special tokens can end at the same offset.
    boundaries.sort_unstable();
    boundaries.dedup();

    boundaries
}
|
||||
|
||||
/// A cached prefix entry
|
||||
#[derive(Debug, Clone)]
|
||||
struct CachedPrefix {
|
||||
/// The pre-computed token IDs for this prefix
|
||||
tokens: Vec<TokenIdType>,
|
||||
/// Last access timestamp (for LRU eviction)
|
||||
last_accessed: Arc<AtomicU64>,
|
||||
/// Size in bytes (for memory tracking during eviction)
|
||||
size_bytes: usize,
|
||||
}
|
||||
|
||||
/// L1 cache implementation with special-token-boundary prefix matching
|
||||
pub struct L1Cache {
|
||||
/// Sharded maps for concurrent access
|
||||
/// Key: Blake3 hash of bytes[0..boundary]
|
||||
/// Value: Cached token IDs for that prefix
|
||||
shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
|
||||
/// Maximum memory in bytes
|
||||
max_memory: usize,
|
||||
/// Current memory usage estimate
|
||||
current_memory: AtomicU64,
|
||||
/// Cache hit counter
|
||||
hits: AtomicU64,
|
||||
/// Cache miss counter
|
||||
misses: AtomicU64,
|
||||
/// Monotonic counter for LRU timestamps
|
||||
access_counter: AtomicU64,
|
||||
}
|
||||
|
||||
impl L1Cache {
|
||||
/// Create a new L1 cache with the specified memory limit
|
||||
pub fn new(max_memory: usize) -> Self {
|
||||
let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
|
||||
|
||||
Self {
|
||||
shards,
|
||||
max_memory,
|
||||
current_memory: AtomicU64::new(0),
|
||||
hits: AtomicU64::new(0),
|
||||
misses: AtomicU64::new(0),
|
||||
access_counter: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to find the longest prefix match at special token boundaries
|
||||
/// Returns (cached_tokens, byte_offset) if found
|
||||
///
|
||||
/// Uses pre-computed tokens cached during insertion.
|
||||
pub fn longest_prefix_match(
|
||||
&self,
|
||||
input: &str,
|
||||
special_tokens: &[&str],
|
||||
) -> Option<(Vec<TokenIdType>, usize)> {
|
||||
let boundaries = find_special_token_boundaries(input, special_tokens);
|
||||
|
||||
if boundaries.is_empty() {
|
||||
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||
return None;
|
||||
}
|
||||
|
||||
// Search backwards from the longest boundary to find the best match
|
||||
for &boundary_pos in boundaries.iter().rev() {
|
||||
let prefix = &input[0..boundary_pos];
|
||||
let prefix_bytes = prefix.as_bytes();
|
||||
let hash = blake3::hash(prefix_bytes);
|
||||
let hash_bytes: Blake3Hash = *hash.as_bytes();
|
||||
|
||||
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
|
||||
|
||||
if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
|
||||
// Update last accessed timestamp for LRU
|
||||
let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
|
||||
entry.last_accessed.store(timestamp, Ordering::Relaxed);
|
||||
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
return Some((entry.tokens.clone(), boundary_pos));
|
||||
}
|
||||
}
|
||||
|
||||
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||
None
|
||||
}
|
||||
|
||||
/// Insert prefix entries at ALL special token boundaries
|
||||
///
|
||||
/// Re-tokenizes each prefix to ensure correctness (BPE tokenization is not prefix-stable).
|
||||
/// This is more expensive on cache misses but provides correct tokens for cache hits.
|
||||
///
|
||||
/// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts).
|
||||
pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
|
||||
&self,
|
||||
input: &str,
|
||||
tokenizer: &E,
|
||||
special_tokens: &[&str],
|
||||
) -> anyhow::Result<()> {
|
||||
let boundaries = find_special_token_boundaries(input, special_tokens);
|
||||
|
||||
if boundaries.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Calculate how much memory we need and tokenize each prefix
|
||||
let mut entries_to_insert = Vec::new();
|
||||
for &boundary_pos in &boundaries {
|
||||
// Extract prefix up to this special token boundary
|
||||
let prefix = &input[0..boundary_pos];
|
||||
let prefix_bytes = prefix.as_bytes();
|
||||
let hash = blake3::hash(prefix_bytes);
|
||||
let hash_bytes: Blake3Hash = *hash.as_bytes();
|
||||
|
||||
// Re-tokenize the prefix for guaranteed correctness
|
||||
// This is the only way to know the exact token boundaries
|
||||
let prefix_encoding = tokenizer.encode(prefix)?;
|
||||
let prefix_tokens = prefix_encoding.token_ids().to_vec();
|
||||
|
||||
// Size = text bytes + token storage
|
||||
let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
|
||||
|
||||
entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
|
||||
}
|
||||
|
||||
if entries_to_insert.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
|
||||
|
||||
// Evict if necessary
|
||||
let current = self.current_memory.load(Ordering::Relaxed) as usize;
|
||||
if current + total_size_needed > self.max_memory {
|
||||
self.evict_lru(total_size_needed);
|
||||
}
|
||||
|
||||
// Insert all entries
|
||||
for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
|
||||
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
|
||||
|
||||
let cached = CachedPrefix {
|
||||
tokens: prefix_tokens,
|
||||
last_accessed: Arc::new(AtomicU64::new(
|
||||
self.access_counter.load(Ordering::Relaxed),
|
||||
)),
|
||||
size_bytes,
|
||||
};
|
||||
|
||||
self.shards[shard_idx].insert(hash_bytes, cached);
|
||||
self.current_memory
|
||||
.fetch_add(size_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict least recently used entries using approximate LRU via random sampling
|
||||
///
|
||||
/// This uses an approximate LRU strategy that's much faster than true LRU:
|
||||
/// - Samples K random entries from the cache (K=32)
|
||||
/// - Evicts the oldest entry among the samples
|
||||
/// - Repeats until enough space is freed
|
||||
///
|
||||
/// This provides O(samples) complexity instead of O(total_entries * log(total_entries)),
|
||||
/// avoiding latency spikes when eviction is triggered on large caches.
|
||||
///
|
||||
/// The approximation is excellent in practice - sampling 32 entries from a large cache
|
||||
/// gives high probability of finding very old entries.
|
||||
fn evict_lru(&self, space_needed: usize) {
|
||||
const SAMPLE_SIZE: usize = 32; // Number of entries to sample per eviction round
|
||||
let mut freed = 0usize;
|
||||
let mut iteration = 0usize;
|
||||
|
||||
// Keep evicting until we have enough space
|
||||
while freed < space_needed {
|
||||
// Collect samples from shards
|
||||
let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
|
||||
|
||||
// Sample entries across different shards
|
||||
for i in 0..SAMPLE_SIZE {
|
||||
// Distribute samples across shards using iteration and index for variety
|
||||
let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
|
||||
|
||||
// Get first entry from that shard (DashMap iteration order is arbitrary)
|
||||
if let Some(entry) = self.shards[shard_idx].iter().next() {
|
||||
let hash = *entry.key();
|
||||
let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
|
||||
let size = entry.value().size_bytes;
|
||||
samples.push((shard_idx, hash, timestamp, size));
|
||||
}
|
||||
}
|
||||
|
||||
if samples.is_empty() {
|
||||
// Cache is empty, nothing to evict
|
||||
break;
|
||||
}
|
||||
|
||||
// Find the oldest entry among samples
|
||||
if let Some((shard_idx, hash, _, _)) =
|
||||
samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
|
||||
{
|
||||
// Remove it
|
||||
if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
|
||||
freed += removed.size_bytes;
|
||||
self.current_memory
|
||||
.fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
iteration += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of entries in the cache
|
||||
pub fn len(&self) -> usize {
|
||||
self.shards.iter().map(|s| s.len()).sum()
|
||||
}
|
||||
|
||||
/// Check if the cache is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.shards.iter().all(|s| s.is_empty())
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn stats(&self) -> L1CacheStats {
|
||||
let hits = self.hits.load(Ordering::Relaxed);
|
||||
let misses = self.misses.load(Ordering::Relaxed);
|
||||
let total_requests = hits + misses;
|
||||
|
||||
L1CacheStats {
|
||||
hits,
|
||||
misses,
|
||||
entries: self.len(),
|
||||
memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
|
||||
hit_rate: if total_requests > 0 {
|
||||
hits as f64 / total_requests as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the cache
|
||||
pub fn clear(&self) {
|
||||
for shard in &self.shards {
|
||||
shard.clear();
|
||||
}
|
||||
self.current_memory.store(0, Ordering::Relaxed);
|
||||
self.hits.store(0, Ordering::Relaxed);
|
||||
self.misses.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Point-in-time snapshot of L1 cache counters.
#[derive(Debug, Clone)]
pub struct L1CacheStats {
    /// Prefix lookups that matched.
    pub hits: u64,
    /// Prefix lookups that missed.
    pub misses: u64,
    /// Number of cached prefix entries.
    pub entries: usize,
    /// Estimated bytes currently cached.
    pub memory_bytes: usize,
    /// hits / (hits + misses); 0.0 when there have been no requests.
    pub hit_rate: f64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_basic_prefix_match() {
        let l1 = L1Cache::new(1024 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>"];
        let tok = MockTokenizer::new();

        // Realistic ChatML template with special tokens.
        let first = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";

        // Populating the cache re-tokenizes every boundary prefix.
        l1.insert_at_boundaries(first, &tok, specials).unwrap();
        assert!(!l1.is_empty());

        // Same system prompt, different user query: the shared prefix hits.
        let second = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let hit = l1.longest_prefix_match(second, specials);

        assert!(hit.is_some());
        let (tokens, offset) = hit.unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_short_input_with_boundaries() {
        let l1 = L1Cache::new(1024 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>"];
        let tok = MockTokenizer::new();

        // Short input that still contains special tokens.
        let input = "<|im_start|>user\nHi<|im_end|>";
        l1.insert_at_boundaries(input, &tok, specials).unwrap();

        // The <|im_start|> boundary leaves a suffix, so it gets cached.
        assert!(!l1.is_empty());
        assert!(l1.longest_prefix_match(input, specials).is_some());
    }

    #[test]
    fn test_longest_match() {
        let l1 = L1Cache::new(1024 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>"];
        let tok = MockTokenizer::new();

        // Multi-turn conversation with several special-token boundaries (~400 bytes).
        let full = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        l1.insert_at_boundaries(full, &tok, specials).unwrap();

        // Every boundary becomes its own cache entry.
        assert!(l1.len() >= 2);

        // A truncated conversation must still match at some boundary.
        let partial = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let hit = l1.longest_prefix_match(partial, specials);

        assert!(hit.is_some());
        let (_, offset) = hit.unwrap();
        assert!(offset > 0);
        assert!(offset <= partial.len());
    }

    #[test]
    fn test_stats() {
        let l1 = L1Cache::new(1024 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>"];
        let tok = MockTokenizer::new();

        let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
        l1.insert_at_boundaries(input, &tok, specials).unwrap();

        // Looking up the very input we inserted must hit at its longest boundary.
        let _ = l1.longest_prefix_match(input, specials);

        let stats = l1.stats();
        assert!(stats.hits >= 1);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[test]
    fn test_clear() {
        let l1 = L1Cache::new(1024 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>"];
        let tok = MockTokenizer::new();

        let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
        l1.insert_at_boundaries(input, &tok, specials).unwrap();
        assert!(!l1.is_empty());

        l1.clear();
        assert!(l1.is_empty());

        // Counters reset along with the entries.
        let stats = l1.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_lru_eviction() {
        // A deliberately tiny (5KB) budget forces eviction.
        let l1 = L1Cache::new(5 * 1024);
        let specials = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tok = MockTokenizer::new();

        let first = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
        l1.insert_at_boundaries(first, &tok, specials).unwrap();

        // Touch the first conversation so its LRU timestamp is fresh.
        assert!(l1.longest_prefix_match(first, specials).is_some());

        let second = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
        l1.insert_at_boundaries(second, &tok, specials).unwrap();

        // Touch the second conversation to make it the most recent.
        assert!(l1.longest_prefix_match(second, specials).is_some());

        // A third conversation pushes the budget over and evicts the oldest.
        let third = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";
        l1.insert_at_boundaries(third, &tok, specials).unwrap();

        // The memory estimate must stay within budget.
        let stats = l1.stats();
        assert!(stats.memory_bytes <= 5 * 1024);

        // The most recently inserted conversation is still resident.
        assert!(l1.longest_prefix_match(third, specials).is_some());
    }
}
|
||||
362
sgl-router/src/tokenizer/cache/mod.rs
vendored
Normal file
362
sgl-router/src/tokenizer/cache/mod.rs
vendored
Normal file
@@ -0,0 +1,362 @@
|
||||
//! Tokenizer Caching Layer
|
||||
//!
|
||||
//! Provides a caching wrapper around any tokenizer implementation to speed up
|
||||
//! repeated tokenization of the same strings (e.g., system prompts).
|
||||
//!
|
||||
//! # Architecture
|
||||
//! - **L0 Cache**: Whole-string exact match (90% of wins)
|
||||
//! - **L1 Cache**: Prefix matching at special-token boundaries (opt-in via `enable_l1`)
|
||||
//!
|
||||
//! # Usage
|
||||
//! ```ignore
|
||||
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
|
||||
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
|
||||
//! let encoding = cached.encode("Hello world")?;
|
||||
//! ```
|
||||
|
||||
mod fingerprint;
|
||||
mod l0;
|
||||
mod l1;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
pub use fingerprint::TokenizerFingerprint;
|
||||
pub use l0::{CacheStats, L0Cache};
|
||||
pub use l1::{L1Cache, L1CacheStats};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
|
||||
|
||||
/// Configuration for the tokenizer cache layers.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Enable the L0 (whole-string exact match) cache.
    pub enable_l0: bool,
    /// Entry limit for the L0 cache.
    pub l0_max_entries: usize,
    /// Enable the L1 (prefix) cache.
    pub enable_l1: bool,
    /// Memory budget for the L1 cache, in bytes.
    pub l1_max_memory: usize,
}
|
||||
|
||||
impl Default for CacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_l0: true,
|
||||
l0_max_entries: 10_000, // ~22MB memory for typical prompts
|
||||
enable_l1: false, // Opt-in for now
|
||||
l1_max_memory: 50 * 1024 * 1024, // 50MB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A caching wrapper around any tokenizer
|
||||
pub struct CachedTokenizer {
|
||||
/// The underlying tokenizer
|
||||
inner: Arc<dyn Tokenizer>,
|
||||
/// L0 cache (whole-string exact match)
|
||||
l0: Option<L0Cache>,
|
||||
/// L1 cache (prefix matching at fixed boundaries)
|
||||
l1: Option<L1Cache>,
|
||||
/// Configuration
|
||||
#[allow(dead_code)]
|
||||
config: CacheConfig,
|
||||
/// Fingerprint for cache invalidation
|
||||
fingerprint: TokenizerFingerprint,
|
||||
/// Cached special token strings (extracted once at construction)
|
||||
special_token_strings: Vec<String>,
|
||||
}
|
||||
|
||||
impl CachedTokenizer {
|
||||
/// Create a new cached tokenizer
|
||||
pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
|
||||
let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
|
||||
|
||||
let l0 = if config.enable_l0 {
|
||||
Some(L0Cache::new(config.l0_max_entries))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let l1 = if config.enable_l1 {
|
||||
Some(L1Cache::new(config.l1_max_memory))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Extract special tokens once at construction time
|
||||
let special_token_strings = Self::extract_special_token_strings(&inner);
|
||||
|
||||
Self {
|
||||
inner,
|
||||
l0,
|
||||
l1,
|
||||
config,
|
||||
fingerprint,
|
||||
special_token_strings,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all special token strings from the tokenizer (called once at construction)
|
||||
fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
|
||||
let special_tokens = tokenizer.get_special_tokens();
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
if let Some(ref token) = special_tokens.bos_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.eos_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.unk_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.sep_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.pad_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.cls_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
if let Some(ref token) = special_tokens.mask_token {
|
||||
tokens.push(token.clone());
|
||||
}
|
||||
|
||||
tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
|
||||
tokens
|
||||
}
|
||||
|
||||
/// Get L0 cache statistics
|
||||
pub fn cache_stats(&self) -> Option<CacheStats> {
|
||||
self.l0.as_ref().map(|cache| cache.stats())
|
||||
}
|
||||
|
||||
/// Get L1 cache statistics
|
||||
pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
|
||||
self.l1.as_ref().map(|cache| cache.stats())
|
||||
}
|
||||
|
||||
/// Clear the cache
|
||||
pub fn clear_cache(&self) {
|
||||
if let Some(l0) = &self.l0 {
|
||||
l0.clear();
|
||||
}
|
||||
if let Some(l1) = &self.l1 {
|
||||
l1.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the fingerprint of the underlying tokenizer
|
||||
pub fn fingerprint(&self) -> &TokenizerFingerprint {
|
||||
&self.fingerprint
|
||||
}
|
||||
|
||||
/// Get a reference to the inner (wrapped) tokenizer
|
||||
pub fn inner(&self) -> &Arc<dyn Tokenizer> {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for CachedTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
// Collect special tokens once if L1 is enabled (avoid redundant allocation)
|
||||
let special_tokens: Option<Vec<&str>> = self.l1.as_ref().map(|_| {
|
||||
self.special_token_strings
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect()
|
||||
});
|
||||
|
||||
// L0 cache lookup (exact match)
|
||||
if let Some(l0) = &self.l0 {
|
||||
if let Some(cached) = l0.get(input) {
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
|
||||
// L1 cache lookup (prefix match at special token boundaries)
|
||||
if let Some(l1) = &self.l1 {
|
||||
let tokens = special_tokens.as_ref().unwrap();
|
||||
|
||||
if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, tokens) {
|
||||
// We have a prefix match - tokenize the suffix
|
||||
let suffix = &input[prefix_len..];
|
||||
if !suffix.is_empty() {
|
||||
let suffix_encoding = self.inner.encode(suffix)?;
|
||||
|
||||
// Merge prefix tokens + suffix tokens
|
||||
// Safe because we're splitting at special token boundaries
|
||||
let mut merged_tokens = prefix_tokens;
|
||||
merged_tokens.extend_from_slice(suffix_encoding.token_ids());
|
||||
|
||||
let merged_encoding = Encoding::Sp(merged_tokens);
|
||||
|
||||
// Cache the full result in L0
|
||||
if let Some(l0) = &self.l0 {
|
||||
l0.insert(input.to_string(), merged_encoding.clone());
|
||||
}
|
||||
|
||||
return Ok(merged_encoding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Full tokenization (both L0 and L1 miss)
|
||||
let encoding = self.inner.encode(input)?;
|
||||
|
||||
// Cache in L0
|
||||
if let Some(l0) = &self.l0 {
|
||||
l0.insert(input.to_string(), encoding.clone());
|
||||
}
|
||||
|
||||
// Cache in L1 at special token boundaries
|
||||
// Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
|
||||
if let Some(l1) = &self.l1 {
|
||||
let tokens = special_tokens.as_ref().unwrap();
|
||||
let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), tokens);
|
||||
// Ignore errors in cache insertion - cache is best-effort
|
||||
}
|
||||
|
||||
Ok(encoding)
|
||||
}
|
||||
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
// Process each input in parallel, leveraging thread-safe caches
|
||||
// This maintains the parallelism from the underlying HuggingFaceTokenizer
|
||||
inputs.par_iter().map(|&input| self.encode(input)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for CachedTokenizer {
|
||||
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
|
||||
// Decoding is not cached (it's fast enough and rarely repeated)
|
||||
self.inner.decode(token_ids, skip_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
impl Tokenizer for CachedTokenizer {
|
||||
fn vocab_size(&self) -> usize {
|
||||
self.inner.vocab_size()
|
||||
}
|
||||
|
||||
fn get_special_tokens(&self) -> &SpecialTokens {
|
||||
self.inner.get_special_tokens()
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
|
||||
self.inner.token_to_id(token)
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
|
||||
self.inner.id_to_token(id)
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    /// Build a cached tokenizer over a mock with the default cache config.
    fn make_cached() -> CachedTokenizer {
        CachedTokenizer::new(Arc::new(MockTokenizer::new()), CacheConfig::default())
    }

    #[test]
    fn test_cache_hit() {
        let cached = make_cached();
        let text = "Hello world";

        // First encode misses, second encode hits the L0 cache.
        let first = cached.encode(text).unwrap();
        let second = cached.encode(text).unwrap();

        // Both calls must produce identical token ids.
        assert_eq!(first.token_ids(), second.token_ids());

        // Exactly one hit and one miss recorded.
        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(Arc::new(MockTokenizer::new()), config);

        let text = "Hello world";

        // Encoding still works with every cache turned off.
        let first = cached.encode(text).unwrap();
        let second = cached.encode(text).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        // And no statistics are reported.
        assert!(cached.cache_stats().is_none());
    }

    #[test]
    fn test_encode_batch() {
        let cached = make_cached();

        // "Hello" appears twice on purpose.
        let inputs = vec!["Hello", "world", "Hello"];
        let results = cached.encode_batch(&inputs).unwrap();

        assert_eq!(results.len(), 3);

        // Duplicates may race past each other in the parallel batch (both
        // seeing cache misses), so only check the outputs agree.
        assert_eq!(results[0].token_ids(), results[2].token_ids());

        // The batch populated the cache, so a follow-up encode must hit it.
        let _ = cached.encode("Hello").unwrap();
        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let cached = make_cached();

        // Decoding is delegated to the inner tokenizer untouched.
        let decoded = cached.decode(&[1, 2, 3], false).unwrap();
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let inner = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(inner.clone(), CacheConfig::default());

        // All trait methods forward to the wrapped tokenizer.
        assert_eq!(cached.vocab_size(), inner.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}
|
||||
@@ -407,7 +407,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_mock_tokenizer_creation() {
|
||||
let tokenizer = create_tokenizer_from_file("mock").unwrap();
|
||||
assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
|
||||
assert_eq!(tokenizer.vocab_size(), 14); // Mock tokenizer has 14 tokens
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -44,8 +44,8 @@ impl HuggingFaceTokenizer {
|
||||
// Extract special tokens
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
// Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
|
||||
let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
@@ -80,7 +80,7 @@ impl HuggingFaceTokenizer {
|
||||
/// Create from an existing HuggingFace tokenizer
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
|
||||
let reverse_vocab: HashMap<TokenIdType, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
@@ -98,8 +98,7 @@ impl HuggingFaceTokenizer {
|
||||
|
||||
/// Extract special tokens from the tokenizer
|
||||
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
|
||||
// Try to get special tokens from the tokenizer
|
||||
// This is a simplified version - actual implementation would need to handle various formats
|
||||
// Get vocab with special tokens included (added_tokens like <|im_start|>)
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
|
||||
let find_token = |patterns: &[&str]| -> Option<String> {
|
||||
@@ -111,6 +110,14 @@ impl HuggingFaceTokenizer {
|
||||
None
|
||||
};
|
||||
|
||||
// Extract additional special tokens using the tokenizers library API
|
||||
let additional_special_tokens: Vec<String> = tokenizer
|
||||
.get_added_tokens_decoder()
|
||||
.iter()
|
||||
.filter(|(_id, token)| token.special) // Only tokens marked as special: true
|
||||
.map(|(_id, token)| token.content.clone())
|
||||
.collect();
|
||||
|
||||
SpecialTokens {
|
||||
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
|
||||
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
|
||||
@@ -119,7 +126,7 @@ impl HuggingFaceTokenizer {
|
||||
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
|
||||
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
|
||||
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
|
||||
additional_special_tokens: vec![],
|
||||
additional_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,12 @@ impl MockTokenizer {
|
||||
(".", 6),
|
||||
("<eos>", 999),
|
||||
("<bos>", 1000),
|
||||
("<|im_start|>", 1001),
|
||||
("<|im_end|>", 1002),
|
||||
("<|eot_id|>", 1003),
|
||||
("system", 7),
|
||||
("user", 8),
|
||||
("assistant", 9),
|
||||
];
|
||||
|
||||
for (token, id) in tokens {
|
||||
@@ -62,7 +68,8 @@ impl MockTokenizer {
|
||||
|
||||
impl Encoder for MockTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
// Simple word-based tokenization for testing
|
||||
// Simple word-based tokenization using the vocab
|
||||
// Split by whitespace and look up each word (decoder adds spaces back)
|
||||
let tokens: Vec<u32> = input
|
||||
.split_whitespace()
|
||||
.filter_map(|word| self.vocab.get(word).copied())
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
pub mod cache;
|
||||
pub mod factory;
|
||||
pub mod hub;
|
||||
pub mod mock;
|
||||
@@ -22,6 +23,7 @@ pub mod tiktoken;
|
||||
mod tests;
|
||||
|
||||
// Re-exports
|
||||
pub use cache::{CacheConfig, CacheStats, CachedTokenizer, TokenizerFingerprint};
|
||||
pub use factory::{
|
||||
create_tokenizer, create_tokenizer_async, create_tokenizer_async_with_chat_template,
|
||||
create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
||||
|
||||
@@ -43,7 +43,7 @@ fn test_tokenizer_wrapper() {
|
||||
let text = tokenizer.decode(&[1, 2], false).unwrap();
|
||||
assert_eq!(text, "Hello world");
|
||||
|
||||
assert_eq!(tokenizer.vocab_size(), 8);
|
||||
assert_eq!(tokenizer.vocab_size(), 14);
|
||||
|
||||
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
|
||||
assert_eq!(tokenizer.token_to_id("unknown"), None);
|
||||
|
||||
Reference in New Issue
Block a user