[router] add tokenizer benchmark (#9427)
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
// These modules are used by tests and benchmarks
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod mock_worker;
|
||||
pub mod test_app;
|
||||
|
||||
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};

use sglang_router_rs::config::RouterConfig;
use sglang_router_rs::server::AppContext;
|
||||
|
||||
/// Helper function to create AppContext for tests
|
||||
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
@@ -13,3 +18,80 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
config.max_concurrent_requests,
|
||||
))
|
||||
}
|
||||
|
||||
// Tokenizer download configuration
// Source URL for the TinyLlama-1.1B-Chat tokenizer JSON hosted on HuggingFace.
const TINYLLAMA_TOKENIZER_URL: &str =
    "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
// Cache directory, relative to the process working directory.
const CACHE_DIR: &str = ".tokenizer_cache";
// File name the downloaded tokenizer is stored under inside CACHE_DIR.
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";

// Global mutex to prevent concurrent downloads
// (lazily initialized on first use via `get_or_init`; serializes callers
// within this process only — separate test processes are not coordinated).
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
|
||||
/// Returns the path to the cached tokenizer file.
|
||||
///
|
||||
/// This function is thread-safe and will only download the tokenizer once
|
||||
/// even if called from multiple threads concurrently.
|
||||
pub fn ensure_tokenizer_cached() -> PathBuf {
|
||||
// Get or initialize the mutex
|
||||
let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));
|
||||
|
||||
// Lock to ensure only one thread downloads at a time
|
||||
let _guard = mutex.lock().unwrap();
|
||||
|
||||
let cache_dir = PathBuf::from(CACHE_DIR);
|
||||
let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);
|
||||
|
||||
// Create cache directory if it doesn't exist
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
|
||||
}
|
||||
|
||||
// Download tokenizer if not already cached
|
||||
if !tokenizer_path.exists() {
|
||||
println!("Downloading TinyLlama tokenizer from HuggingFace...");
|
||||
|
||||
// Use blocking reqwest client since we're in tests/benchmarks
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(TINYLLAMA_TOKENIZER_URL)
|
||||
.send()
|
||||
.expect("Failed to download tokenizer");
|
||||
|
||||
if !response.status().is_success() {
|
||||
panic!("Failed to download tokenizer: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let content = response.bytes().expect("Failed to read tokenizer content");
|
||||
|
||||
// Verify we got actual JSON content
|
||||
if content.len() < 100 {
|
||||
panic!("Downloaded content too small: {} bytes", content.len());
|
||||
}
|
||||
|
||||
fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
|
||||
println!(
|
||||
"Tokenizer downloaded and cached successfully ({} bytes)",
|
||||
tokenizer_path.metadata().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
tokenizer_path
|
||||
}
|
||||
|
||||
/// Common test prompts for consistency across tests.
///
/// Order matters: each prompt is paired positionally with the value at the
/// same index in `EXPECTED_HASHES`. The first two prompts differ only in
/// capitalization, presumably to exercise case sensitivity — confirm with
/// the consuming test.
pub const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];
|
||||
|
||||
/// Pre-computed hashes for verification.
///
/// Entry `i` corresponds to `TEST_PROMPTS[i]`.
/// NOTE(review): the hash function that produced these values is not
/// visible in this file — verify against the test/benchmark that consumes
/// them before regenerating.
pub const EXPECTED_HASHES: [u64; 4] = [
    1209591529327510910,
    4181375434596349981,
    6245658446118930933,
    5097285695902185237,
];
|
||||
|
||||
Reference in New Issue
Block a user