[router] add tokenizer benchmark (#9427)
This commit is contained in:
@@ -73,12 +73,18 @@ tower = { version = "0.5", features = ["util"] }
|
||||
http-body-util = "0.1"
|
||||
portpicker = "0.1"
|
||||
tempfile = "3.8"
|
||||
lazy_static = "1.4"
|
||||
|
||||
[[bench]]
|
||||
name = "request_processing"
|
||||
harness = false
|
||||
path = "benches/request_processing.rs"
|
||||
|
||||
[[bench]]
|
||||
name = "tokenizer_benchmark"
|
||||
harness = false
|
||||
path = "benches/tokenizer_benchmark.rs"
|
||||
|
||||
[profile.release]
|
||||
lto = "thin"
|
||||
codegen-units = 1
|
||||
|
||||
1400
sgl-router/benches/tokenizer_benchmark.rs
Normal file
1400
sgl-router/benches/tokenizer_benchmark.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,6 @@
|
||||
// Mock worker for testing - these functions are used by integration tests
|
||||
#![allow(dead_code)]
|
||||
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
@@ -25,7 +28,6 @@ pub struct MockWorkerConfig {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum WorkerType {
|
||||
Regular,
|
||||
Prefill,
|
||||
@@ -33,7 +35,6 @@ pub enum WorkerType {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum HealthStatus {
|
||||
Healthy,
|
||||
Unhealthy,
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
// These modules are used by tests and benchmarks
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod mock_worker;
|
||||
pub mod test_app;
|
||||
|
||||
use sglang_router_rs::config::RouterConfig;
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
/// Helper function to create AppContext for tests
|
||||
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
@@ -13,3 +18,80 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
config.max_concurrent_requests,
|
||||
))
|
||||
}
|
||||
|
||||
// Tokenizer download configuration
|
||||
const TINYLLAMA_TOKENIZER_URL: &str =
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
|
||||
const CACHE_DIR: &str = ".tokenizer_cache";
|
||||
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";
|
||||
|
||||
// Global mutex to prevent concurrent downloads
|
||||
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
|
||||
/// Returns the path to the cached tokenizer file.
|
||||
///
|
||||
/// This function is thread-safe and will only download the tokenizer once
|
||||
/// even if called from multiple threads concurrently.
|
||||
pub fn ensure_tokenizer_cached() -> PathBuf {
|
||||
// Get or initialize the mutex
|
||||
let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));
|
||||
|
||||
// Lock to ensure only one thread downloads at a time
|
||||
let _guard = mutex.lock().unwrap();
|
||||
|
||||
let cache_dir = PathBuf::from(CACHE_DIR);
|
||||
let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);
|
||||
|
||||
// Create cache directory if it doesn't exist
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
|
||||
}
|
||||
|
||||
// Download tokenizer if not already cached
|
||||
if !tokenizer_path.exists() {
|
||||
println!("Downloading TinyLlama tokenizer from HuggingFace...");
|
||||
|
||||
// Use blocking reqwest client since we're in tests/benchmarks
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(TINYLLAMA_TOKENIZER_URL)
|
||||
.send()
|
||||
.expect("Failed to download tokenizer");
|
||||
|
||||
if !response.status().is_success() {
|
||||
panic!("Failed to download tokenizer: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let content = response.bytes().expect("Failed to read tokenizer content");
|
||||
|
||||
// Verify we got actual JSON content
|
||||
if content.len() < 100 {
|
||||
panic!("Downloaded content too small: {} bytes", content.len());
|
||||
}
|
||||
|
||||
fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
|
||||
println!(
|
||||
"Tokenizer downloaded and cached successfully ({} bytes)",
|
||||
tokenizer_path.metadata().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
tokenizer_path
|
||||
}
|
||||
|
||||
/// Prompts shared by every tokenizer test so results stay comparable.
pub const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];

/// Known-good hashes of the encodings of `TEST_PROMPTS`, in the same order,
/// used to detect regressions in the tokenizer implementation.
pub const EXPECTED_HASHES: [u64; 4] = [
    1_209_591_529_327_510_910,
    4_181_375_434_596_349_981,
    6_245_658_446_118_930_933,
    5_097_285_695_902_185_237,
];
|
||||
|
||||
@@ -3,20 +3,14 @@
|
||||
//! These tests download the TinyLlama tokenizer from HuggingFace to verify our tokenizer
|
||||
//! implementation works correctly with real-world tokenizer files.
|
||||
|
||||
mod common;
|
||||
use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS};
|
||||
|
||||
use sglang_router_rs::tokenizer::{
|
||||
factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream,
|
||||
traits::*,
|
||||
};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
// Fixed prompt set exercised by the tokenizer tests below.
const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];
|
||||
use std::sync::Arc;
|
||||
|
||||
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
|
||||
("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
|
||||
@@ -34,70 +28,6 @@ const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
|
||||
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
|
||||
];
|
||||
|
||||
const TINYLLAMA_TOKENIZER_URL: &str =
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
|
||||
const CACHE_DIR: &str = ".tokenizer_cache";
|
||||
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";
|
||||
|
||||
// Global mutex to prevent concurrent downloads
|
||||
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
// Pre-computed hashes for verification
|
||||
const EXPECTED_HASHES: [u64; 4] = [
|
||||
1209591529327510910,
|
||||
4181375434596349981,
|
||||
6245658446118930933,
|
||||
5097285695902185237,
|
||||
];
|
||||
|
||||
/// Downloads the tokenizer from HuggingFace if not already cached
|
||||
fn ensure_tokenizer_cached() -> PathBuf {
|
||||
// Get or initialize the mutex
|
||||
let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));
|
||||
|
||||
// Lock to ensure only one thread downloads at a time
|
||||
let _guard = mutex.lock().unwrap();
|
||||
|
||||
let cache_dir = PathBuf::from(CACHE_DIR);
|
||||
let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);
|
||||
|
||||
// Create cache directory if it doesn't exist
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
|
||||
}
|
||||
|
||||
// Download tokenizer if not already cached
|
||||
if !tokenizer_path.exists() {
|
||||
println!("Downloading TinyLlama tokenizer from HuggingFace...");
|
||||
|
||||
// Use blocking reqwest client since we're in tests
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(TINYLLAMA_TOKENIZER_URL)
|
||||
.send()
|
||||
.expect("Failed to download tokenizer");
|
||||
|
||||
if !response.status().is_success() {
|
||||
panic!("Failed to download tokenizer: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let content = response.bytes().expect("Failed to read tokenizer content");
|
||||
|
||||
// Verify we got actual JSON content
|
||||
if content.len() < 100 {
|
||||
panic!("Downloaded content too small: {} bytes", content.len());
|
||||
}
|
||||
|
||||
fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
|
||||
println!(
|
||||
"Tokenizer downloaded and cached successfully ({} bytes)",
|
||||
tokenizer_path.metadata().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
tokenizer_path
|
||||
}
|
||||
|
||||
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
|
||||
prompts
|
||||
.iter()
|
||||
|
||||
Reference in New Issue
Block a user