[router] add tokenizer chat template support (#9370)

Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
Simo Lin
2025-08-19 20:14:02 -07:00
committed by GitHub
parent 7638f5e44e
commit 5fbad308cd
12 changed files with 748 additions and 85 deletions

View File

@@ -0,0 +1,156 @@
#[cfg(test)]
mod tests {
    use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_chat_message_helpers() {
        // Each role-specific constructor must tag the message with its role
        // and pass the content through unchanged.
        let cases = [
            (
                ChatMessage::system("You are a helpful assistant"),
                "system",
                "You are a helpful assistant",
            ),
            (ChatMessage::user("Hello!"), "user", "Hello!"),
            (ChatMessage::assistant("Hi there!"), "assistant", "Hi there!"),
        ];
        for (msg, role, content) in cases {
            assert_eq!(msg.role, role);
            assert_eq!(msg.content, content);
        }
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_llama_style_template() {
        // Llama-style template: optional leading system turn, header/eot
        // markers around every message, trailing assistant header on request.
        let template = r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;
        let processor = ChatTemplateProcessor::new(
            template.to_string(),
            Some(String::from("<|begin_of_text|>")),
            Some(String::from("<|end_of_text|>")),
        );
        let conversation = vec![
            ChatMessage::system("You are a helpful assistant"),
            ChatMessage::user("What is 2+2?"),
        ];
        let rendered = processor.apply_chat_template(&conversation, true).unwrap();

        // Every expected marker and both message bodies appear in the output.
        for needle in [
            "<|begin_of_text|>",
            "<|start_header_id|>system<|end_header_id|>",
            "You are a helpful assistant",
            "<|start_header_id|>user<|end_header_id|>",
            "What is 2+2?",
            "<|start_header_id|>assistant<|end_header_id|>",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_chatml_template() {
        // ChatML wraps each turn in <|im_start|>role ... <|im_end|> and, when
        // a generation prompt is requested, ends with an open assistant turn.
        let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;
        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
        let conversation = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];
        let rendered = processor.apply_chat_template(&conversation, true).unwrap();

        for turn in [
            "<|im_start|>user\nHello<|im_end|>",
            "<|im_start|>assistant\nHi there!<|im_end|>",
            "<|im_start|>user\nHow are you?<|im_end|>",
        ] {
            assert!(rendered.contains(turn));
        }
        assert!(rendered.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_template_without_generation_prompt() {
        // The add_generation_prompt flag alone decides whether the trailing
        // "assistant:" cue is rendered.
        let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
        let conversation = vec![ChatMessage::user("Test")];

        // Flag off: only the message lines are produced.
        let without_prompt = processor.apply_chat_template(&conversation, false).unwrap();
        assert_eq!(without_prompt.trim(), "user: Test");

        // Flag on: the assistant cue is appended.
        let with_prompt = processor.apply_chat_template(&conversation, true).unwrap();
        assert!(with_prompt.contains("assistant:"));
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_template_with_special_tokens() {
        // bos_token / eos_token handed to the processor must be substituted
        // into the rendered text.
        let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
        let processor = ChatTemplateProcessor::new(
            template.to_string(),
            Some(String::from("<s>")),
            Some(String::from("</s>")),
        );
        let conversation = vec![ChatMessage::user("Hello")];
        let rendered = processor.apply_chat_template(&conversation, false).unwrap();
        assert_eq!(rendered, "<s>Hello</s>");
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_empty_messages() {
        // An empty conversation renders to an empty string.
        let template =
            r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
        let conversation: Vec<ChatMessage> = vec![];
        let rendered = processor.apply_chat_template(&conversation, false).unwrap();
        assert_eq!(rendered, "");
    }

    // An integration test loading a real tokenizer_config.json would belong
    // here, but it needs an actual tokenizer file on disk.
}

View File

@@ -0,0 +1,186 @@
#[cfg(test)]
mod tests {
    use std::fs;
    use tempfile::TempDir;

    /// Minimal BPE tokenizer.json shared by the tests below that only need a
    /// loadable tokenizer file and don't care about the vocabulary contents.
    const MINIMAL_TOKENIZER_JSON: &str = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_load_chat_template_from_file() {
        use sglang_router_rs::tokenizer::chat_template::ChatMessage;
        use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

        let dir = TempDir::new().unwrap();

        // A standalone .jinja template on disk.
        let template_file = dir.path().join("template.jinja");
        let template_body = r#"
{%- for message in messages %}
{{- '<|' + message['role'] + '|>' + message['content'] }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>' }}
{%- endif %}
"#;
        fs::write(&template_file, template_body).unwrap();

        // A tokenizer.json with its own tiny vocabulary.
        let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"hello": 0,
"world": 1,
"<s>": 2,
"</s>": 3
},
"merges": []
}
}"#;
        let tokenizer_file = dir.path().join("tokenizer.json");
        fs::write(&tokenizer_file, tokenizer_json).unwrap();

        // Loading with an explicit template path wires that template in.
        let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
            tokenizer_file.to_str().unwrap(),
            Some(template_file.to_str().unwrap()),
        )
        .unwrap();

        let conversation = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there"),
        ];
        let rendered = tokenizer.apply_chat_template(&conversation, true).unwrap();

        // The output must follow the on-disk template's <|role|> format.
        assert!(rendered.contains("<|user|>Hello"));
        assert!(rendered.contains("<|assistant|>Hi there"));
        assert!(rendered.ends_with("<|assistant|>"));
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_override_existing_template() {
        use sglang_router_rs::tokenizer::chat_template::ChatMessage;
        use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

        let dir = TempDir::new().unwrap();

        // tokenizer_config.json carries a built-in chat template.
        let config_file = dir.path().join("tokenizer_config.json");
        fs::write(
            &config_file,
            r#"{
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
}"#,
        )
        .unwrap();

        let tokenizer_file = dir.path().join("tokenizer.json");
        fs::write(&tokenizer_file, MINIMAL_TOKENIZER_JSON).unwrap();

        // A custom .jinja passed explicitly should win over the built-in one.
        let custom_file = dir.path().join("custom.jinja");
        fs::write(
            &custom_file,
            r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#,
        )
        .unwrap();

        let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
            tokenizer_file.to_str().unwrap(),
            Some(custom_file.to_str().unwrap()),
        )
        .unwrap();

        let conversation = vec![ChatMessage::user("Test")];
        let rendered = tokenizer.apply_chat_template(&conversation, false).unwrap();

        // The custom template formats the output; the built-in never runs.
        assert!(rendered.starts_with("CUSTOM:"));
        assert!(rendered.contains("[user]: Test"));
        assert!(!rendered.contains("built-in:"));
    }

    #[test]
    #[cfg(feature = "huggingface")]
    fn test_set_chat_template_after_creation() {
        use sglang_router_rs::tokenizer::chat_template::ChatMessage;
        use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

        let dir = TempDir::new().unwrap();
        let tokenizer_file = dir.path().join("tokenizer.json");
        fs::write(&tokenizer_file, MINIMAL_TOKENIZER_JSON).unwrap();

        // Load without any custom template, then install one afterwards
        // (mimics Python's behavior of assigning chat_template post-load).
        let mut tokenizer =
            HuggingFaceTokenizer::from_file(tokenizer_file.to_str().unwrap()).unwrap();
        tokenizer.set_chat_template(
            "NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"
                .to_string(),
        );

        let conversation = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
        let rendered = tokenizer.apply_chat_template(&conversation, false).unwrap();
        assert!(rendered.starts_with("NEW:"));
        assert!(rendered.contains("user: Hello;"));
        assert!(rendered.contains("assistant: World;"));
    }
}