From 728af88781f8964db943ee3e6f14a5bec6145284 Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Mon, 13 Oct 2025 13:47:57 -0400
Subject: [PATCH] [router] allow user to specify chat template path (#11549)
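
Routers in gRPC mode construct a tokenizer at startup, and until now a chat
template was only picked up when it sat next to the tokenizer files. This
patch threads an optional chat template path through the Python RouterArgs,
the Rust RouterConfig and CLI, and the tokenizer factory; an explicitly
provided path takes precedence over a template auto-discovered in the
tokenizer directory. Illustrative invocation (launcher module and paths are
examples, not mandated by this patch):

    python -m sglang_router.launch_router \
        --model-path /models/llama-3-8b \
        --chat-template /configs/llama3_chat.jinja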
---
 .../py_src/sglang_router/router_args.py  |  9 +-
 sgl-router/src/config/types.rs           |  6 ++
 sgl-router/src/lib.rs                    |  5 +
 sgl-router/src/main.rs                   |  4 +
 sgl-router/src/server.rs                 | 50 ++++++----
 sgl-router/src/tokenizer/factory.rs      | 92 +++++++++++++++++--
 sgl-router/src/tokenizer/mod.rs          |  5 +-
 sgl-router/tests/api_endpoints_test.rs   |  4 +
 sgl-router/tests/request_formats_test.rs |  1 +
 sgl-router/tests/responses_api_test.rs   | 10 ++
 sgl-router/tests/streaming_tests.rs      |  1 +
 sgl-router/tests/test_openai_routing.rs  |  3 +
 sgl-router/tests/test_pd_routing.rs      |  1 +
 13 files changed, 159 insertions(+), 32 deletions(-)

diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py
index f3f3b8391..6d68535b5 100644
--- a/sgl-router/py_src/sglang_router/router_args.py
+++ b/sgl-router/py_src/sglang_router/router_args.py
@@ -83,10 +83,9 @@ class RouterArgs:
     cb_timeout_duration_secs: int = 60
     cb_window_duration_secs: int = 120
     disable_circuit_breaker: bool = False
-    # Tokenizer configuration
     model_path: Optional[str] = None
     tokenizer_path: Optional[str] = None
-    # Parser configuration
+    chat_template: Optional[str] = None
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None

@@ -449,6 +448,12 @@ class RouterArgs:
             default=None,
             help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
         )
+        parser.add_argument(
+            f"--{prefix}chat-template",
+            type=str,
+            default=None,
+            help="Chat template path (optional)",
+        )
         parser.add_argument(
             f"--{prefix}reasoning-parser",
             type=str,

diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs
index f55f14b79..cdb972092 100644
--- a/sgl-router/src/config/types.rs
+++ b/sgl-router/src/config/types.rs
@@ -67,6 +67,8 @@ pub struct RouterConfig {
     pub model_path: Option<String>,
     /// Explicit tokenizer path (overrides model_path tokenizer if provided)
     pub tokenizer_path: Option<String>,
+    /// Chat template path (optional)
+    pub chat_template: Option<String>,
     /// History backend configuration (memory or none, default: memory)
     #[serde(default = "default_history_backend")]
     pub history_backend: HistoryBackend,
@@ -450,6 +452,7 @@ impl Default for RouterConfig {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -994,6 +997,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -1061,6 +1065,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -1124,6 +1129,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,

diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs
index 01b037ed1..92dd1950f 100644
--- a/sgl-router/src/lib.rs
+++ b/sgl-router/src/lib.rs
@@ -90,6 +90,7 @@ struct Router {
     connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
+   chat_template: Option<String>,
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
 }
@@ -216,6 +217,7 @@ impl Router {
             enable_igw: self.enable_igw,
             model_path: self.model_path.clone(),
             tokenizer_path: self.tokenizer_path.clone(),
+            chat_template: self.chat_template.clone(),
             history_backend: config::HistoryBackend::Memory,
             oracle: None,
             reasoning_parser: self.reasoning_parser.clone(),
@@ -284,6 +286,7 @@ impl Router {
         rate_limit_tokens_per_second = None,
         model_path = None,
         tokenizer_path = None,
+        chat_template = None,
         reasoning_parser = None,
         tool_call_parser = None,
     ))]
@@ -345,6 +348,7 @@ impl Router {
         rate_limit_tokens_per_second: Option<i64>,
         model_path: Option<String>,
         tokenizer_path: Option<String>,
+        chat_template: Option<String>,
         reasoning_parser: Option<String>,
         tool_call_parser: Option<String>,
     ) -> PyResult<Self> {
@@ -420,6 +424,7 @@ impl Router {
             connection_mode,
             model_path,
             tokenizer_path,
+            chat_template,
             reasoning_parser,
             tool_call_parser,
         })

diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs
index 60f422f5d..95910046d 100644
--- a/sgl-router/src/main.rs
+++ b/sgl-router/src/main.rs
@@ -255,6 +255,9 @@ struct CliArgs {
     #[arg(long)]
     tokenizer_path: Option<String>,

+    #[arg(long)]
+    chat_template: Option<String>,
+
     #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
     history_backend: String,
@@ -561,6 +564,7 @@ impl CliArgs {
             rate_limit_tokens_per_second: None,
             model_path: self.model_path.clone(),
             tokenizer_path: self.tokenizer_path.clone(),
+            chat_template: self.chat_template.clone(),
             history_backend,
             oracle,
             reasoning_parser: self.reasoning_parser.clone(),
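
With the plumbing above in place, the template is just another RouterConfig
field. A minimal sketch, assuming the crate is imported as sglang_router_rs
(the crate name and template path are assumptions, not part of this patch):

    use sglang_router_rs::config::RouterConfig;

    fn example_config() -> RouterConfig {
        RouterConfig {
            // Explicit template path; None falls back to auto-discovery.
            chat_template: Some("/configs/llama3_chat.jinja".to_string()),
            // All other fields keep the Default impl extended by this patch.
            ..RouterConfig::default()
        }
    }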
diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs
index dacd88a5d..d1d634e6b 100644
--- a/sgl-router/src/server.rs
+++ b/sgl-router/src/server.rs
@@ -82,28 +82,40 @@ impl AppContext {
             }
         };

-        let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
-            if router_config.connection_mode == ConnectionMode::Grpc {
-                let tokenizer_path = router_config
-                    .tokenizer_path
-                    .clone()
-                    .or_else(|| router_config.model_path.clone())
-                    .ok_or_else(|| {
-                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
-                            .to_string()
-                    })?;
+        let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if router_config
+            .connection_mode
+            == ConnectionMode::Grpc
+        {
+            let tokenizer_path = router_config
+                .tokenizer_path
+                .clone()
+                .or_else(|| router_config.model_path.clone())
+                .ok_or_else(|| {
+                    "gRPC mode requires either --tokenizer-path or --model-path to be specified"
+                        .to_string()
+                })?;

-            let tokenizer = Some(
-                tokenizer_factory::create_tokenizer(&tokenizer_path)
-                    .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
+            let tokenizer = Some(
+                tokenizer_factory::create_tokenizer_with_chat_template_blocking(
+                    &tokenizer_path,
+                    router_config.chat_template.as_deref(),
+                )
+                .map_err(|e| {
+                    format!(
+                        "Failed to create tokenizer from '{}': {}. \
+                        Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
+                        or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
+                        tokenizer_path, e
+                    )
+                })?,
             );

-            let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
-            let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());
+            let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
+            let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());

-            (tokenizer, reasoning_parser_factory, tool_parser_factory)
-        } else {
-            (None, None, None)
-        };
+            (tokenizer, reasoning_parser_factory, tool_parser_factory)
+        } else {
+            (None, None, None)
+        };

         let worker_registry = Arc::new(WorkerRegistry::new());
         let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
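
server.rs now resolves the tokenizer and template in a single call. The same
entry point (added in the factory changes below) can be used directly; a
minimal sketch, with an assumed import path and illustrative paths:

    // Assumption: the factory module is reachable as in server.rs's
    // `tokenizer_factory` alias; crate name is illustrative.
    use sglang_router_rs::tokenizer::factory as tokenizer_factory;

    let tokenizer = tokenizer_factory::create_tokenizer_with_chat_template_blocking(
        "/models/llama-3-8b",               // tokenizer file, model dir, or HF model ID
        Some("/configs/llama3_chat.jinja"), // explicit template; None => auto-discovery
    )
    .expect("failed to create tokenizer");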
diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs
index c46ed7282..e3bebba3b 100644
--- a/sgl-router/src/tokenizer/factory.rs
+++ b/sgl-router/src/tokenizer/factory.rs
@@ -4,6 +4,7 @@ use std::fs::File;
 use std::io::Read;
 use std::path::Path;
 use std::sync::Arc;
+use tracing::{debug, info};

 use super::huggingface::HuggingFaceTokenizer;
 use super::tiktoken::TiktokenTokenizer;
@@ -189,14 +190,57 @@ pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
     None
 }

+/// Helper function to resolve and log chat template selection
+///
+/// Resolves the final chat template to use by prioritizing the provided path
+/// over auto-discovery, and logs the source for debugging purposes.
+fn resolve_and_log_chat_template(
+    provided_path: Option<&str>,
+    discovery_dir: &Path,
+    model_name: &str,
+) -> Option<String> {
+    let final_chat_template = provided_path
+        .map(|s| s.to_string())
+        .or_else(|| discover_chat_template_in_dir(discovery_dir));
+
+    match (&provided_path, &final_chat_template) {
+        (Some(provided), _) => {
+            info!("Using provided chat template: {}", provided);
+        }
+        (None, Some(discovered)) => {
+            info!(
+                "Auto-discovered chat template in '{}': {}",
+                discovery_dir.display(),
+                discovered
+            );
+        }
+        (None, None) => {
+            debug!(
+                "No chat template provided or discovered for model: {}",
+                model_name
+            );
+        }
+    }
+
+    final_chat_template
+}
+
 /// Factory function to create tokenizer from a model name or path (async version)
 pub async fn create_tokenizer_async(
     model_name_or_path: &str,
+) -> Result<Arc<dyn Tokenizer>> {
+    create_tokenizer_async_with_chat_template(model_name_or_path, None).await
+}
+
+/// Factory function to create tokenizer with optional chat template (async version)
+pub async fn create_tokenizer_async_with_chat_template(
+    model_name_or_path: &str,
+    chat_template_path: Option<&str>,
 ) -> Result<Arc<dyn Tokenizer>> {
     // Check if it's a file path
     let path = Path::new(model_name_or_path);
     if path.exists() {
-        return create_tokenizer_from_file(model_name_or_path);
+        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
     }

     // Check if it's a GPT model name that should use Tiktoken
@@ -216,8 +260,13 @@
         // Look for tokenizer.json in the cache directory
         let tokenizer_path = cache_dir.join("tokenizer.json");
         if tokenizer_path.exists() {
-            // Try to find a chat template file in the cache directory
-            let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+            // Resolve chat template: provided path takes precedence over auto-discovery
+            let final_chat_template = resolve_and_log_chat_template(
+                chat_template_path,
+                &cache_dir,
+                model_name_or_path,
+            );
+
             let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
                 Error::msg(format!(
                     "Tokenizer path is not valid UTF-8: {:?}",
@@ -226,7 +275,7 @@ pub async fn create_tokenizer_async(
             })?;
             create_tokenizer_with_chat_template(
                 tokenizer_path_str,
-                chat_template_path.as_deref(),
+                final_chat_template.as_deref(),
             )
         } else {
             // Try other common tokenizer file names
@@ -234,13 +283,19 @@
             for file_name in &possible_files {
                 let file_path = cache_dir.join(file_name);
                 if file_path.exists() {
-                    let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+                    // Resolve chat template: provided path takes precedence over auto-discovery
+                    let final_chat_template = resolve_and_log_chat_template(
+                        chat_template_path,
+                        &cache_dir,
+                        model_name_or_path,
+                    );
+
                     let file_path_str = file_path.to_str().ok_or_else(|| {
                         Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
                     })?;
                     return create_tokenizer_with_chat_template(
                         file_path_str,
-                        chat_template_path.as_deref(),
+                        final_chat_template.as_deref(),
                     );
                 }
             }
@@ -258,11 +313,22 @@
 }

 /// Factory function to create tokenizer from a model name or path (blocking version)
+///
+/// This delegates to `create_tokenizer_with_chat_template_blocking` with no chat template,
+/// which handles both local files and HuggingFace Hub downloads uniformly.
 pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn Tokenizer>> {
+    create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
+}
+
+/// Factory function to create tokenizer with optional chat template (blocking version)
+pub fn create_tokenizer_with_chat_template_blocking(
+    model_name_or_path: &str,
+    chat_template_path: Option<&str>,
+) -> Result<Arc<dyn Tokenizer>> {
     // Check if it's a file path
     let path = Path::new(model_name_or_path);
     if path.exists() {
-        return create_tokenizer_from_file(model_name_or_path);
+        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
     }

     // Check if it's a GPT model name that should use Tiktoken
@@ -280,11 +346,19 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn Tokenizer>>

diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         // Create default router config
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
@@ -1365,6 +1366,7 @@ mod error_tests {
     async fn test_payload_too_large() {
         // Create context with small payload limit
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
@@ -1723,6 +1725,7 @@ mod pd_mode_tests {
         .unwrap_or(9000);

         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::PrefillDecode {
                 prefill_urls: vec![(prefill_url, Some(prefill_port))],
                 decode_urls: vec![decode_url],
@@ -1888,6 +1891,7 @@ mod request_id_tests {
     async fn test_request_id_with_custom_headers() {
         // Create config with custom request ID headers
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },

diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs
index c2eb6a9bd..589be6171 100644
--- a/sgl-router/tests/request_formats_test.rs
+++ b/sgl-router/tests/request_formats_test.rs
@@ -18,6 +18,7 @@ struct TestContext {

 impl TestContext {
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         let mut config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
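
The async path mirrors the blocking one. A sketch of the async entry point
under the same assumptions (tokio runtime; model ID and template path are
illustrative):

    // Hypothetical call site inside an async fn:
    let tokenizer = tokenizer_factory::create_tokenizer_async_with_chat_template(
        "Qwen/Qwen2.5-7B-Instruct",        // resolved via the HF cache directory
        Some("/configs/qwen_chat.jinja"),  // takes precedence over any discovered template
    )
    .await
    .expect("failed to create tokenizer");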
diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs
index c0239af46..9d641d795 100644
--- a/sgl-router/tests/responses_api_test.rs
+++ b/sgl-router/tests/responses_api_test.rs
@@ -44,6 +44,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     // Build router config (HTTP OpenAI mode)
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -245,6 +246,7 @@ async fn test_conversations_crud_basic() {
     // Router in OpenAI mode (no actual upstream calls in these tests)
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -576,6 +578,7 @@ async fn test_multi_turn_loop_with_mcp() {
     // Build router config
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -753,6 +756,7 @@ async fn test_max_tool_calls_limit() {
     let worker_url = worker.start().await.expect("start worker");

     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -896,6 +900,7 @@ async fn setup_streaming_mcp_test() -> (
     let worker_url = worker.start().await.expect("start worker");

     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -1338,6 +1343,7 @@ async fn test_streaming_multi_turn_with_mcp() {
 async fn test_conversation_items_create_and_get() {
     // Test creating items and getting a specific item
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1440,6 +1446,7 @@ async fn test_conversation_items_create_and_get() {
 async fn test_conversation_items_delete() {
     // Test deleting an item from a conversation
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1548,6 +1555,7 @@ async fn test_conversation_items_delete() {
 async fn test_conversation_items_max_limit() {
     // Test that creating > 20 items returns error
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1626,6 +1634,7 @@ async fn test_conversation_items_max_limit() {
 async fn test_conversation_items_unsupported_type() {
     // Test that unsupported item types return error
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1703,6 +1712,7 @@ async fn test_conversation_items_unsupported_type() {
 async fn test_conversation_items_multi_conversation_sharing() {
     // Test that items can be shared across conversations via soft delete
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },

diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs
index b658f001a..5e1bcf876 100644
--- a/sgl-router/tests/streaming_tests.rs
+++ b/sgl-router/tests/streaming_tests.rs
@@ -19,6 +19,7 @@ struct TestContext {

 impl TestContext {
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         let mut config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
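
These tests pin `chat_template: None` so existing behavior is unchanged; the
discovery half of the precedence rule can be exercised on its own through the
public `discover_chat_template_in_dir`. A sketch (module path and fixture
directory are assumptions):

    use std::path::Path;
    // Assumed re-export path for the factory module.
    use sglang_router_rs::tokenizer::factory::discover_chat_template_in_dir;

    #[test]
    fn reports_template_when_present() {
        // None if the directory holds no template file, Some(path) otherwise.
        let found = discover_chat_template_in_dir(Path::new("tests/fixtures"));
        println!("discovered template: {:?}", found);
    }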
diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs
index b68a3f9bb..56f6f64f1 100644
--- a/sgl-router/tests/test_openai_routing.rs
+++ b/sgl-router/tests/test_openai_routing.rs
@@ -867,6 +867,7 @@ async fn test_openai_router_models_auth_forwarding() {
 #[test]
 fn oracle_config_validation_requires_config_when_enabled() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },
@@ -891,6 +892,7 @@ fn oracle_config_validation_requires_config_when_enabled() {
 #[test]
 fn oracle_config_validation_accepts_dsn_only() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },
@@ -913,6 +915,7 @@ fn oracle_config_validation_accepts_dsn_only() {
 #[test]
 fn oracle_config_validation_accepts_wallet_alias() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },

diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs
index 9d99f100f..5ca3a6c1e 100644
--- a/sgl-router/tests/test_pd_routing.rs
+++ b/sgl-router/tests/test_pd_routing.rs
@@ -164,6 +164,7 @@ mod test_pd_routing {
     for (mode, policy) in test_cases {
         let config = RouterConfig {
+            chat_template: None,
             mode,
             policy,
             host: "127.0.0.1".to_string(),