From 5519766a4da33da9b1817e181120b715c63a9864 Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Sat, 27 Sep 2025 23:54:12 -0400
Subject: [PATCH] [router] fix chat template loading and tokenizer path
 (#10999)

---
 python/sglang/srt/entrypoints/grpc_server.py |   6 +-
 sgl-router/src/tokenizer/factory.rs          |  52 +++-
 sgl-router/src/tokenizer/hub.rs              |  96 +++++++-
 sgl-router/src/tokenizer/huggingface.rs      |  31 ++-
 sgl-router/tests/tokenizer_integration.rs    | 242 +++++++++++++++++++
 5 files changed, 418 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py
index 3c5bf5436..b360c5068 100644
--- a/python/sglang/srt/entrypoints/grpc_server.py
+++ b/python/sglang/srt/entrypoints/grpc_server.py
@@ -404,7 +404,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             logprob_start_len=grpc_req.logprob_start_len or -1,
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
             stream=grpc_req.stream or False,
-            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
+            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
             ),
@@ -458,9 +458,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             repetition_penalty=grpc_params.repetition_penalty or 1.0,
             max_new_tokens=grpc_params.max_new_tokens or 128,
             min_new_tokens=grpc_params.min_new_tokens or 0,
-            stop=list(grpc_params.stop) if grpc_params.stop else None,
+            stop=list(grpc_params.stop) if grpc_params.stop else [],
             stop_token_ids=(
-                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
+                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
             ),
             skip_special_tokens=grpc_params.skip_special_tokens,
             spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs
index 9c1544fe7..6544f12b0 100644
--- a/sgl-router/src/tokenizer/factory.rs
+++ b/sgl-router/src/tokenizer/factory.rs
@@ -134,6 +134,36 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
         || buffer.windows(4).any(|w| w == b"</s>"))
 }
 
+/// Helper function to discover chat template files in a directory
+pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
+    use std::fs;
+
+    // Priority 1: Look for chat_template.json (contains Jinja in JSON format)
+    let json_template_path = dir.join("chat_template.json");
+    if json_template_path.exists() {
+        return json_template_path.to_str().map(|s| s.to_string());
+    }
+
+    // Priority 2: Look for chat_template.jinja (standard Jinja file)
+    let jinja_path = dir.join("chat_template.jinja");
+    if jinja_path.exists() {
+        return jinja_path.to_str().map(|s| s.to_string());
+    }
+
+    // Priority 3: Look for any .jinja file (for models with non-standard naming)
+    if let Ok(entries) = fs::read_dir(dir) {
+        for entry in entries.flatten() {
+            if let Some(name) = entry.file_name().to_str() {
+                if name.ends_with(".jinja") && name != "chat_template.jinja" {
+                    return entry.path().to_str().map(|s| s.to_string());
+                }
+            }
+        }
+    }
+
+    None
+}
+
 /// Factory function to create tokenizer from a model name or path (async version)
 pub async fn create_tokenizer_async(
     model_name_or_path: &str,
@@ -161,14 +191,32 @@ pub async fn create_tokenizer_async(
         // Look for tokenizer.json in the cache directory
         let tokenizer_path = cache_dir.join("tokenizer.json");
        if tokenizer_path.exists() {
-            create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
+            // Try to find a chat template file in the cache directory
+            let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+            let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
+                Error::msg(format!(
+                    "Tokenizer path is not valid UTF-8: {:?}",
+                    tokenizer_path
+                ))
+            })?;
+            create_tokenizer_with_chat_template(
+                tokenizer_path_str,
+                chat_template_path.as_deref(),
+            )
         } else {
             // Try other common tokenizer file names
             let possible_files = ["tokenizer_config.json", "vocab.json"];
             for file_name in &possible_files {
                 let file_path = cache_dir.join(file_name);
                 if file_path.exists() {
-                    return create_tokenizer_from_file(file_path.to_str().unwrap());
+                    let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+                    let file_path_str = file_path.to_str().ok_or_else(|| {
+                        Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
+                    })?;
+                    return create_tokenizer_with_chat_template(
+                        file_path_str,
+                        chat_template_path.as_deref(),
+                    );
                 }
             }
             Err(Error::msg(format!(
diff --git a/sgl-router/src/tokenizer/hub.rs b/sgl-router/src/tokenizer/hub.rs
index c9d2cd1a4..f9c344f57 100644
--- a/sgl-router/src/tokenizer/hub.rs
+++ b/sgl-router/src/tokenizer/hub.rs
@@ -40,6 +40,13 @@ fn is_tokenizer_file(filename: &str) -> bool {
         || filename.ends_with("merges.txt")
         || filename.ends_with(".model") // SentencePiece models
         || filename.ends_with(".tiktoken")
+        || is_chat_template_file(filename) // Include chat template files
+}
+
+/// Checks if a file is a chat template file
+fn is_chat_template_file(filename: &str) -> bool {
+    filename.ends_with(".jinja") // Direct Jinja files
+        || filename == "chat_template.json" // JSON file containing Jinja template
 }
 
 /// Attempt to download tokenizer files from Hugging Face
@@ -123,7 +130,13 @@ pub async fn download_tokenizer_from_hf(model_id: impl AsRef<str>) -> anyhow::Result<PathBuf> {
     }
 
     match cache_dir {
-        Some(dir) => Ok(dir),
+        Some(dir) => {
+            // Ensure we return the correct model directory, not a subfolder
+            // Some models have an "original" subfolder for PyTorch weights
+            // We want the main model directory that contains tokenizer files
+            let final_dir = resolve_model_cache_dir(&dir, &model_name);
+            Ok(final_dir)
+        }
         None => Err(anyhow::anyhow!(
             "Invalid HF cache path for model '{}'",
             model_name
@@ -206,11 +219,76 @@ pub async fn from_hf(name: impl AsRef<str>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
     }
 
     match p.parent() {
-        Some(p) => Ok(p.to_path_buf()),
+        Some(p) => {
+            let final_dir = resolve_model_cache_dir(p, &model_name);
+            Ok(final_dir)
+        }
         None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
     }
 }
 
+/// Resolve the correct model cache directory
+/// Handles cases where files might be in subfolders (e.g., "original" folder)
+fn resolve_model_cache_dir(path: &Path, model_name: &str) -> PathBuf {
+    // Check if we're in a subfolder like "original"
+    if let Some(parent) = path.parent() {
+        if let Some(folder_name) = path.file_name() {
+            if folder_name == "original" {
+                // We're in the "original" subfolder, go up one level
+                return parent.to_path_buf();
+            }
+        }
+    }
+
+    // Check if the current path contains the model name components
+    // This helps ensure we're at the right directory level
+    let model_parts: Vec<&str> = model_name.split('/').collect();
+    if model_parts.len() >= 2 {
+        let expected_pattern = format!(
+            "models--{}--{}",
+            model_parts[0], // HF cache naming only maps '/' to "--"; hyphens within names are kept
+            model_parts[1]
+        );
+
+        if path.to_string_lossy().contains(&expected_pattern) {
+            // We're already at the correct level
+            return path.to_path_buf();
+        }
+
+        let mut current = path.to_path_buf();
+
+        // First check if current path already contains tokenizer files
+        if current.join("tokenizer.json").exists() || current.join("tokenizer_config.json").exists()
+        {
+            return current;
+        }
+
+        // If not, traverse up to find the model root, then look in snapshots
+        while let Some(parent) = current.parent() {
+            if parent.to_string_lossy().contains(&expected_pattern) {
+                let snapshots_dir = parent.join("snapshots");
+                if snapshots_dir.exists() && snapshots_dir.is_dir() {
+                    if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
+                        for entry in entries.flatten() {
+                            let snapshot_path = entry.path();
+                            if snapshot_path.is_dir()
+                                && (snapshot_path.join("tokenizer.json").exists()
+                                    || snapshot_path.join("tokenizer_config.json").exists())
+                            {
+                                return snapshot_path;
+                            }
+                        }
+                    }
+                }
+                return parent.to_path_buf();
+            }
+            current = parent.to_path_buf();
+        }
+    }
+
+    path.to_path_buf()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -223,10 +301,24 @@ mod tests {
         assert!(is_tokenizer_file("vocab.json"));
         assert!(is_tokenizer_file("merges.txt"));
         assert!(is_tokenizer_file("spiece.model"));
+        assert!(is_tokenizer_file("chat_template.jinja"));
+        assert!(is_tokenizer_file("template.jinja"));
         assert!(!is_tokenizer_file("model.bin"));
         assert!(!is_tokenizer_file("README.md"));
     }
 
+    #[test]
+    fn test_is_chat_template_file() {
+        assert!(is_chat_template_file("chat_template.jinja"));
+        assert!(is_chat_template_file("template.jinja"));
+        assert!(is_chat_template_file("any_file.jinja"));
+        assert!(is_chat_template_file("chat_template.json"));
+        assert!(!is_chat_template_file("tokenizer.json"));
+        assert!(!is_chat_template_file("other_file.json"));
+        assert!(!is_chat_template_file("chat_template"));
+        assert!(!is_chat_template_file("README.md"));
+    }
+
     #[test]
     fn test_is_weight_file() {
         assert!(is_weight_file("model.bin"));
diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs
index 396ccdf60..beaf98eb7 100644
--- a/sgl-router/src/tokenizer/huggingface.rs
+++ b/sgl-router/src/tokenizer/huggingface.rs
@@ -25,7 +25,12 @@ pub struct HuggingFaceTokenizer {
 impl HuggingFaceTokenizer {
     /// Create a tokenizer from a HuggingFace tokenizer JSON file
     pub fn from_file(file_path: &str) -> Result<Self> {
-        Self::from_file_with_chat_template(file_path, None)
+        // Try to auto-discover chat template if not explicitly provided
+        let path = std::path::Path::new(file_path);
+        let chat_template_path = path
+            .parent()
+            .and_then(crate::tokenizer::factory::discover_chat_template_in_dir);
+        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
     }
 
     /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
@@ -135,13 +140,35 @@ impl HuggingFaceTokenizer {
         None
     }
 
-    /// Load chat template from a .jinja file
+    /// Load chat template from a file (.jinja or .json containing Jinja)
     fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
         use std::fs;
 
         let content = fs::read_to_string(template_path)
             .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
 
+        // Check if it's a JSON file containing a Jinja template
+        if template_path.ends_with(".json") {
+            // Parse JSON and extract the template string
+            let json_value: serde_json::Value = serde_json::from_str(&content)
+                .map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;
+
+            if let Some(template_str) = json_value.as_str() {
+                return Ok(Some(template_str.to_string()));
+            } else if let Some(obj) = json_value.as_object() {
+                if let Some(template_value) = obj.get("chat_template") {
+                    if let Some(template_str) = template_value.as_str() {
+                        return Ok(Some(template_str.to_string()));
+                    }
+                }
+            }
+
+            return Err(Error::msg(
+                "chat_template.json does not contain a valid template",
+            ));
+        }
+
+        // Otherwise it's a plain .jinja file
         // Clean up the template (similar to Python implementation)
         let template = content.trim().replace("\\n", "\n");
diff --git a/sgl-router/tests/tokenizer_integration.rs b/sgl-router/tests/tokenizer_integration.rs
index 5bd6d56c2..6e4a87ea9 100644
--- a/sgl-router/tests/tokenizer_integration.rs
+++ b/sgl-router/tests/tokenizer_integration.rs
@@ -314,3 +314,245 @@ fn test_thread_safety() {
         handle.join().expect("Thread panicked");
     }
 }
+
+#[test]
+fn test_chat_template_discovery() {
+    use std::fs;
+    use tempfile::TempDir;
+
+    // Create a temporary directory with test files
+    let temp_dir = TempDir::new().expect("Failed to create temp dir");
+    let dir_path = temp_dir.path();
+
+    // Copy a real tokenizer.json file for testing
+    // We'll use the TinyLlama tokenizer that's already cached
+    let cached_tokenizer = ensure_tokenizer_cached();
+    let tokenizer_path = dir_path.join("tokenizer.json");
+    fs::copy(&cached_tokenizer, &tokenizer_path).expect("Failed to copy tokenizer file");
+
+    // Test 1: With chat_template.jinja file
+    let jinja_path = dir_path.join("chat_template.jinja");
+    fs::write(&jinja_path, "{{ messages }}").expect("Failed to write chat template");
+
+    let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap());
+    assert!(
+        tokenizer.is_ok(),
+        "Should load tokenizer with chat template"
+    );
+
+    // Clean up for next test
+    fs::remove_file(&jinja_path).ok();
+
+    // Test 2: With tokenizer_config.json containing chat_template
+    let config_path = dir_path.join("tokenizer_config.json");
+    fs::write(&config_path, r#"{"chat_template": "{{ messages }}"}"#)
+        .expect("Failed to write config");
+
+    let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap());
+    assert!(
+        tokenizer.is_ok(),
+        "Should load tokenizer with embedded template"
+    );
+
+    // Test 3: No chat template
+    fs::remove_file(&config_path).ok();
+    let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap());
+    assert!(
+        tokenizer.is_ok(),
+        "Should load tokenizer without chat template"
+    );
+}
+
+#[test]
+fn test_load_chat_template_from_local_file() {
+    use std::fs;
+    use tempfile::TempDir;
+
+    // Test 1: Load tokenizer with explicit chat template path
+    let temp_dir = TempDir::new().expect("Failed to create temp dir");
+    let dir_path = temp_dir.path();
+
+    // Copy a real tokenizer for testing
+    let cached_tokenizer = ensure_tokenizer_cached();
+    let tokenizer_path = dir_path.join("tokenizer.json");
+    fs::copy(&cached_tokenizer, &tokenizer_path).expect("Failed to copy tokenizer");
+
+    // Create a chat template file
+    let template_path = dir_path.join("my_template.jinja");
+    let template_content = r#"{% for message in messages %}{{ message.role }}: {{ message.content }}
+{% endfor %}"#;
+    fs::write(&template_path, template_content).expect("Failed to write template");
+
+    // Load tokenizer with explicit template path
+    let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
+        tokenizer_path.to_str().unwrap(),
+        Some(template_path.to_str().unwrap()),
+    );
+    assert!(
+        tokenizer.is_ok(),
+        "Should load tokenizer with explicit template path"
+    );
+}
+
+#[tokio::test]
+async fn test_tinyllama_embedded_template() {
+    use sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf;
+
+    // Network-dependent: the Err arm below skips this test if the download fails
+    // (e.g. in CI without an HF token)
+
+    // Test 2: TinyLlama has chat template embedded in tokenizer_config.json
+    match download_tokenizer_from_hf("TinyLlama/TinyLlama-1.1B-Chat-v1.0").await {
+        Ok(cache_dir) => {
+            // Verify tokenizer_config.json exists
+            let config_path = cache_dir.join("tokenizer_config.json");
+            assert!(config_path.exists(), "tokenizer_config.json should exist");
+
+            // Load the config and check for chat_template
+            let config_content =
+                std::fs::read_to_string(&config_path).expect("Failed to read config");
+            assert!(
+                config_content.contains("\"chat_template\""),
+                "TinyLlama should have embedded chat_template in config"
+            );
+
+            // Load tokenizer and verify it has chat template
+            let tokenizer_path = cache_dir.join("tokenizer.json");
+            let _tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
+                .expect("Failed to load tokenizer");
+
+            println!(
+                "✓ TinyLlama: Loaded tokenizer with embedded template from tokenizer_config.json"
+            );
+        }
+        Err(e) => {
+            println!("Download test skipped due to error: {}", e);
+        }
+    }
+}
+
+#[tokio::test]
+async fn test_qwen3_next_embedded_template() {
+    use sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf;
+
+    // Test 3: Qwen3-Next has chat template in tokenizer_config.json
+    match download_tokenizer_from_hf("Qwen/Qwen3-Next-80B-A3B-Instruct").await {
+        Ok(cache_dir) => {
+            let config_path = cache_dir.join("tokenizer_config.json");
+            assert!(config_path.exists(), "tokenizer_config.json should exist");
+
+            // Verify chat_template in config
+            let config_content =
+                std::fs::read_to_string(&config_path).expect("Failed to read config");
+            assert!(
+                config_content.contains("\"chat_template\""),
+                "Qwen3-Next should have chat_template in tokenizer_config.json"
+            );
+
+            // Load tokenizer
+            let tokenizer_path = cache_dir.join("tokenizer.json");
+            if tokenizer_path.exists() {
+                let _tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
+                    .expect("Failed to load tokenizer");
+                println!("✓ Qwen3-Next: Loaded tokenizer with embedded template");
+            }
+        }
+        Err(e) => {
+            println!("Download test skipped due to error: {}", e);
+        }
+    }
+}
+
+#[tokio::test]
+async fn test_qwen3_vl_json_template_priority() {
+    use sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf;
+
+    // Test 4: Qwen3-VL has both tokenizer_config.json template and chat_template.json
+    // Should prioritize chat_template.json
+    match download_tokenizer_from_hf("Qwen/Qwen3-VL-235B-A22B-Instruct").await {
+        Ok(cache_dir) => {
+            // Check for chat_template.json
+            let json_template_path = cache_dir.join("chat_template.json");
+            let has_json_template = json_template_path.exists();
+
+            // Also check tokenizer_config.json
+            let config_path = cache_dir.join("tokenizer_config.json");
+            assert!(config_path.exists(), "tokenizer_config.json should exist");
+
+            if has_json_template {
+                let json_content = std::fs::read_to_string(&json_template_path)
+                    .expect("Failed to read chat_template.json");
+                println!("✓ Qwen3-VL: Found chat_template.json (should be prioritized)");
+
+                // Verify it contains jinja template
+                assert!(
+                    !json_content.is_empty(),
+                    "chat_template.json should contain template"
+                );
+            }
+
+            // Load tokenizer - it should use the appropriate template
+            let tokenizer_path = cache_dir.join("tokenizer.json");
+            if tokenizer_path.exists() {
+                let _tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
+                    .expect("Failed to load tokenizer");
+ println!("✓ Qwen3-VL: Loaded tokenizer with template priority handling"); + } + } + Err(e) => { + println!("Download test skipped due to error: {}", e); + } + } +} + +#[tokio::test] +async fn test_llava_separate_jinja_template() { + use sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf; + + // Test 5: llava has chat_template.jinja as a separate file, not in tokenizer_config.json + match download_tokenizer_from_hf("llava-hf/llava-1.5-7b-hf").await { + Ok(cache_dir) => { + // Check for .jinja file + let jinja_path = cache_dir.join("chat_template.jinja"); + let has_jinja = jinja_path.exists() + || std::fs::read_dir(&cache_dir) + .map(|entries| { + entries.filter_map(|e| e.ok()).any(|e| { + e.file_name() + .to_str() + .is_some_and(|name| name.ends_with(".jinja")) + }) + }) + .unwrap_or(false); + + if has_jinja { + println!("✓ llava: Found separate .jinja chat template file"); + } + + // Check tokenizer_config.json - should NOT have embedded template + let config_path = cache_dir.join("tokenizer_config.json"); + if config_path.exists() { + let config_content = + std::fs::read_to_string(&config_path).expect("Failed to read config"); + + // llava might not have chat_template in config + if !config_content.contains("\"chat_template\"") { + println!("✓ llava: No embedded template in config (as expected)"); + } + } + + // Load tokenizer - should auto-discover the .jinja file + let tokenizer_path = cache_dir.join("tokenizer.json"); + if tokenizer_path.exists() { + let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()); + if tokenizer.is_ok() { + println!("✓ llava: Loaded tokenizer with auto-discovered .jinja template"); + } else { + println!("Note: llava tokenizer loading failed - might need specific handling"); + } + } + } + Err(e) => { + println!("Download test skipped due to error: {}", e); + } + } +}