[router] add tokenizer benchmark (#9427)
This commit is contained in:
@@ -3,20 +3,14 @@
|
||||
//! These tests download the TinyLlama tokenizer from HuggingFace to verify our tokenizer
|
||||
//! implementation works correctly with real-world tokenizer files.
|
||||
|
||||
mod common;
|
||||
use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS};
|
||||
|
||||
use sglang_router_rs::tokenizer::{
|
||||
factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream,
|
||||
traits::*,
|
||||
};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
const TEST_PROMPTS: [&str; 4] = [
|
||||
"deep learning is",
|
||||
"Deep learning is",
|
||||
"has anyone seen nemo lately",
|
||||
"another prompt",
|
||||
];
|
||||
use std::sync::Arc;
|
||||
|
||||
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
|
||||
("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
|
||||
@@ -34,70 +28,6 @@ const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
|
||||
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
|
||||
];
|
||||
|
||||
const TINYLLAMA_TOKENIZER_URL: &str =
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
|
||||
const CACHE_DIR: &str = ".tokenizer_cache";
|
||||
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";
|
||||
|
||||
// Global mutex to prevent concurrent downloads
|
||||
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
// Pre-computed hashes for verification
|
||||
const EXPECTED_HASHES: [u64; 4] = [
|
||||
1209591529327510910,
|
||||
4181375434596349981,
|
||||
6245658446118930933,
|
||||
5097285695902185237,
|
||||
];
|
||||
|
||||
/// Downloads the tokenizer from HuggingFace if not already cached
|
||||
fn ensure_tokenizer_cached() -> PathBuf {
|
||||
// Get or initialize the mutex
|
||||
let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));
|
||||
|
||||
// Lock to ensure only one thread downloads at a time
|
||||
let _guard = mutex.lock().unwrap();
|
||||
|
||||
let cache_dir = PathBuf::from(CACHE_DIR);
|
||||
let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);
|
||||
|
||||
// Create cache directory if it doesn't exist
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
|
||||
}
|
||||
|
||||
// Download tokenizer if not already cached
|
||||
if !tokenizer_path.exists() {
|
||||
println!("Downloading TinyLlama tokenizer from HuggingFace...");
|
||||
|
||||
// Use blocking reqwest client since we're in tests
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(TINYLLAMA_TOKENIZER_URL)
|
||||
.send()
|
||||
.expect("Failed to download tokenizer");
|
||||
|
||||
if !response.status().is_success() {
|
||||
panic!("Failed to download tokenizer: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let content = response.bytes().expect("Failed to read tokenizer content");
|
||||
|
||||
// Verify we got actual JSON content
|
||||
if content.len() < 100 {
|
||||
panic!("Downloaded content too small: {} bytes", content.len());
|
||||
}
|
||||
|
||||
fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
|
||||
println!(
|
||||
"Tokenizer downloaded and cached successfully ({} bytes)",
|
||||
tokenizer_path.metadata().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
tokenizer_path
|
||||
}
|
||||
|
||||
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
|
||||
prompts
|
||||
.iter()
|
||||
|
||||
Reference in New Issue
Block a user