[router] Add Configurable L0 and L1 Tokenizer Caching (#11688)

This commit is contained in:
Simo Lin
2025-10-18 18:33:53 -07:00
committed by GitHub
parent fda0cb2a30
commit a7ae61ed77
22 changed files with 2385 additions and 24 deletions

View File

@@ -14,19 +14,34 @@ use std::{
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::tokenizer::{
huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*,
cache::{CacheConfig, CachedTokenizer},
huggingface::HuggingFaceTokenizer,
sequence::Sequence,
stop::*,
stream::DecodeStream,
traits::*,
};
// Include the common test utilities
#[path = "../tests/common/mod.rs"]
mod common;
use common::ensure_tokenizer_cached;
// Cache the tokenizer path for the entire benchmark run
static TOKENIZER_PATH: OnceLock<PathBuf> = OnceLock::new();
fn get_tokenizer_path() -> &'static PathBuf {
TOKENIZER_PATH.get_or_init(ensure_tokenizer_cached)
TOKENIZER_PATH.get_or_init(|| {
// Use Qwen3-4B-Instruct which has ChatML special tokens (<|im_start|>, <|im_end|>)
// with special: true, normalized: false - perfect for demonstrating L1 cache
let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
let tokenizer_dir = rt.block_on(async {
sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf(
"Qwen/Qwen3-4B-Instruct-2507",
)
.await
.expect("Failed to download Qwen3-4B-Instruct tokenizer from HuggingFace")
});
// The download_tokenizer_from_hf returns the directory containing tokenizer.json
// We need to construct the full path to tokenizer.json
tokenizer_dir.join("tokenizer.json")
})
}
// Production target: 100k tokens per second
@@ -1253,6 +1268,468 @@ fn bench_scaling_characteristics(c: &mut Criterion) {
group.finish();
}
fn bench_l1_cache_chat_template(c: &mut Criterion) {
let tokenizer_path = get_tokenizer_path();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
let mut group = c.benchmark_group("l1_cache_chat");
// ============================================================================
// SCENARIO 1: High Prefix Reuse (95%+ - Realistic Chat Application)
// ============================================================================
// Most realistic: Same system prompt across 95%+ of requests, different user queries
// This is typical for chat applications where the same system context is reused
let system_prompt = generate_system_prompt(8000);
// Generate 100 different user queries of varying lengths (realistic distribution)
let user_queries: Vec<String> = (0..100)
.map(|i| {
let base_queries = [
"What is the capital of France?",
"Explain quantum mechanics in simple terms.",
"How do I sort an array in Python?",
"What are the benefits of exercise?",
"Can you help me write a resume?",
];
let query = base_queries[i % base_queries.len()];
// Add variation to make each unique
format!("{} (Query #{})", query, i)
})
.collect();
// Create prompts with ChatML format (same system prefix, different queries)
let realistic_prompts: Vec<String> = user_queries
.iter()
.map(|query| {
format!(
"<|im_start|>system\n{}<|im_end|><|im_start|>user\n{}<|im_end|><|im_start|>assistant\n",
system_prompt, query
)
})
.collect();
// Baseline: No cache
let printed_baseline = Arc::new(AtomicBool::new(false));
group.bench_function("realistic_chat_uncached", |b| {
let printed = printed_baseline.clone();
let tokenizer = tokenizer.clone();
let test_prompts = realistic_prompts.clone();
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
// Simulate 100 requests with different queries (realistic workload)
for prompt in &test_prompts {
black_box(tokenizer.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | {:>20}",
"Uncached (baseline)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
"N/A"
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// L0-only: Should have 0% hit rate (all queries are unique)
let l0_only_config = CacheConfig {
enable_l0: true,
l0_max_entries: 10_000,
enable_l1: false,
l1_max_memory: 0,
};
let cached_l0_only = Arc::new(CachedTokenizer::new(tokenizer.clone(), l0_only_config));
let printed_l0 = Arc::new(AtomicBool::new(false));
group.bench_function("realistic_chat_l0_only", |b| {
let printed = printed_l0.clone();
let cached = cached_l0_only.clone();
let test_prompts = realistic_prompts.clone();
b.iter_custom(|iters| {
cached.clear_cache(); // Start fresh each iteration
let start = Instant::now();
for _ in 0..iters {
for prompt in &test_prompts {
black_box(cached.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let stats = cached.cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6}",
"L0-only (no benefit)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
stats.hit_rate * 100.0,
"N/A"
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// L0+L1: Should show significant speedup from prefix caching
let l0_l1_config = CacheConfig {
enable_l0: true,
l0_max_entries: 10_000,
enable_l1: true,
l1_max_memory: 50 * 1024 * 1024,
};
let cached_l0_l1 = Arc::new(CachedTokenizer::new(
tokenizer.clone(),
l0_l1_config.clone(),
));
let printed_l0_l1 = Arc::new(AtomicBool::new(false));
group.bench_function("realistic_chat_l0_l1", |b| {
let printed = printed_l0_l1.clone();
let cached = cached_l0_l1.clone();
let test_prompts = realistic_prompts.clone();
b.iter_custom(|iters| {
cached.clear_cache(); // Start fresh
// Prime with first request to populate L1 with system prefix
cached.encode(&test_prompts[0]).unwrap();
let start = Instant::now();
for _ in 0..iters {
// All subsequent requests benefit from L1 prefix cache
for prompt in &test_prompts {
black_box(cached.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let stats = cached.cache_stats().unwrap();
let l1_stats = cached.l1_cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%",
"L0+L1 (95%+ prefix reuse)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
stats.hit_rate * 100.0,
l1_stats.hit_rate * 100.0
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// ============================================================================
// SCENARIO 2: Customer Service Bot (100% prefix reuse)
// ============================================================================
// Identical greeting/instructions with different customer queries
let service_system = "You are a helpful customer service assistant for TechCorp. \
Always be polite, professional, and helpful. Our business hours are 9 AM to 5 PM EST. \
We offer a 30-day return policy on all products. For technical issues, escalate to technical support. \
For billing issues, escalate to accounting department.".repeat(20); // ~2KB
let customer_queries = [
"I need to return my laptop",
"My order hasn't arrived yet",
"How do I reset my password?",
"What's your return policy?",
"I was charged twice for my order",
"Can I change my shipping address?",
"Is my product under warranty?",
"I need help installing the software",
];
let service_prompts: Vec<String> = customer_queries
.iter()
.map(|query| {
format!(
"<|im_start|>system\n{}<|im_end|><|im_start|>user\n{}<|im_end|><|im_start|>assistant\n",
service_system, query
)
})
.collect();
// Service bot with L1-only (to compare against L0+L1)
let l1_only_config = CacheConfig {
enable_l0: false,
l0_max_entries: 0,
enable_l1: true,
l1_max_memory: 50 * 1024 * 1024,
};
let service_cached_l1 = Arc::new(CachedTokenizer::new(tokenizer.clone(), l1_only_config));
let printed_service_l1 = Arc::new(AtomicBool::new(false));
group.bench_function("customer_service_l1_only", |b| {
let printed = printed_service_l1.clone();
let cached = service_cached_l1.clone();
let test_prompts = service_prompts.clone();
b.iter_custom(|iters| {
cached.clear_cache();
cached.encode(&test_prompts[0]).unwrap(); // Prime cache
let start = Instant::now();
for _ in 0..iters {
for prompt in &test_prompts {
black_box(cached.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let l1_stats = cached.l1_cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6} L1:{:>6.1}%",
"Customer Service (L1-only)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
"N/A",
l1_stats.hit_rate * 100.0
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// Service bot with L0+L1
let service_cached = Arc::new(CachedTokenizer::new(
tokenizer.clone(),
l0_l1_config.clone(),
));
let printed_service = Arc::new(AtomicBool::new(false));
group.bench_function("customer_service_l0_l1", |b| {
let printed = printed_service.clone();
let cached = service_cached.clone();
let test_prompts = service_prompts.clone();
b.iter_custom(|iters| {
cached.clear_cache();
cached.encode(&test_prompts[0]).unwrap(); // Prime cache
let start = Instant::now();
for _ in 0..iters {
for prompt in &test_prompts {
black_box(cached.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let stats = cached.cache_stats().unwrap();
let l1_stats = cached.l1_cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%",
"Customer Service (100% reuse)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
stats.hit_rate * 100.0,
l1_stats.hit_rate * 100.0
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// ============================================================================
// SCENARIO 3: Multi-Turn Conversation (Progressive context building)
// ============================================================================
// Each turn builds on previous context (common in chat applications)
let conversation_system =
"You are a helpful coding assistant. Help users write better code.".repeat(10);
// Simulate a 5-turn conversation where context grows
let conversation_turns = vec![
format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\n", conversation_system),
format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\nYou can use the sorted() function or list.sort() method.<|im_end|><|im_start|>user\nWhat's the difference between them?<|im_end|><|im_start|>assistant\n", conversation_system),
format!("<|im_start|>system\n{}<|im_end|><|im_start|>user\nHow do I sort an array in Python?<|im_end|><|im_start|>assistant\nYou can use the sorted() function or list.sort() method.<|im_end|><|im_start|>user\nWhat's the difference between them?<|im_end|><|im_start|>assistant\nsorted() creates a new list, sort() modifies in place.<|im_end|><|im_start|>user\nCan I sort by a custom key?<|im_end|><|im_start|>assistant\n", conversation_system),
];
let conv_cached = Arc::new(CachedTokenizer::new(
tokenizer.clone(),
l0_l1_config.clone(),
));
let printed_conv = Arc::new(AtomicBool::new(false));
group.bench_function("multi_turn_conversation", |b| {
let printed = printed_conv.clone();
let cached = conv_cached.clone();
let test_turns = conversation_turns.clone();
b.iter_custom(|iters| {
cached.clear_cache();
let start = Instant::now();
for _ in 0..iters {
// Simulate progressive conversation (each turn shares prefix with previous)
for turn in &test_turns {
black_box(cached.encode(turn).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_turns.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let stats = cached.cache_stats().unwrap();
let l1_stats = cached.l1_cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%",
"Multi-turn Conversation",
test_turns[0].len(),
ops_per_sec,
avg_time_us,
stats.hit_rate * 100.0,
l1_stats.hit_rate * 100.0
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
// ============================================================================
// SCENARIO 4: Code Review Assistant (Same guidelines, different code snippets)
// ============================================================================
let review_system = "You are a code review assistant. Check for: \
1) Code quality and readability \
2) Performance issues \
3) Security vulnerabilities \
4) Best practices \
5) Documentation completeness"
.repeat(15);
let code_snippets = [
"function add(a, b) { return a + b; }",
"def factorial(n): return 1 if n <= 1 else n * factorial(n-1)",
"SELECT * FROM users WHERE id = $_GET['id']", // Security issue
"for (var i = 0; i < 10; i++) { setTimeout(() => console.log(i), 100); }", // Closure issue
];
let review_prompts: Vec<String> = code_snippets
.iter()
.map(|code| {
format!(
"<|im_start|>system\n{}<|im_end|><|im_start|>user\nReview this code:\n```\n{}\n```<|im_end|><|im_start|>assistant\n",
review_system, code
)
})
.collect();
let review_cached = Arc::new(CachedTokenizer::new(tokenizer.clone(), l0_l1_config));
let printed_review = Arc::new(AtomicBool::new(false));
group.bench_function("code_review_assistant", |b| {
let printed = printed_review.clone();
let cached = review_cached.clone();
let test_prompts = review_prompts.clone();
b.iter_custom(|iters| {
cached.clear_cache();
cached.encode(&test_prompts[0]).unwrap(); // Prime cache
let start = Instant::now();
for _ in 0..iters {
for prompt in &test_prompts {
black_box(cached.encode(prompt).unwrap());
}
}
let duration = start.elapsed();
if !printed.load(Ordering::Relaxed) {
let total_ops = iters * test_prompts.len() as u64;
let ops_per_sec = total_ops as f64 / duration.as_secs_f64();
let avg_time_us = duration.as_micros() as f64 / total_ops as f64;
let stats = cached.cache_stats().unwrap();
let l1_stats = cached.l1_cache_stats().unwrap();
let result = format!(
"{:<30} | {:>8} | {:>12.0} | {:>12.1} | L0:{:>6.1}% L1:{:>6.1}%",
"Code Review (high reuse)",
test_prompts[0].len(),
ops_per_sec,
avg_time_us,
stats.hit_rate * 100.0,
l1_stats.hit_rate * 100.0
);
add_result("l1_cache", result);
printed.store(true, Ordering::Relaxed);
}
duration
});
});
group.finish();
}
// Print final summary table
fn print_summary() {
println!("\n{}", "=".repeat(120));
@@ -1372,6 +1849,13 @@ fn print_summary() {
"Operation", "Calls/sec", "Time/call", "Improvement"
);
}
"l1_cache" => {
println!("L1 CACHE (PREFIX MATCHING) - REALISTIC WORKLOADS");
println!(
"{:<30} | {:>8} | {:>12} | {:>12} | {:>20}",
"Scenario", "Size(B)", "Ops/sec", "Time(µs)", "Hit Rates"
);
}
_ => {}
}
println!("{}", "-".repeat(120));
@@ -1396,6 +1880,7 @@ fn run_benchmarks(c: &mut Criterion) {
bench_latency_distribution(c);
bench_scaling_characteristics(c);
bench_memory_efficiency(c);
bench_l1_cache_chat_template(c);
// Print summary at the end
print_summary();