router-grpc: Support jinja chat template content format detection (#10832)
This commit is contained in:
238
sgl-router/tests/chat_template_format_detection.rs
Normal file
238
sgl-router/tests/chat_template_format_detection.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_detect_string_format_deepseek() {
|
||||
// DeepSeek style template - expects string content
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
User: {{ message['content'] }}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
Assistant: {{ message['content'] }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_llama4() {
|
||||
// Llama4 style template - expects structured content
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message['content'] is iterable %}
|
||||
{%- for content in message['content'] %}
|
||||
{%- if content['type'] == 'text' %}
|
||||
{{ content['text'] }}
|
||||
{%- elif content['type'] == 'image' %}
|
||||
<image>
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- else %}
|
||||
{{ message['content'] }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_dot_notation() {
|
||||
// Template using dot notation
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- for part in message.content %}
|
||||
{%- if part.type == 'text' %}
|
||||
{{ part.text }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_variable_assignment() {
|
||||
// Template that assigns content to variable then iterates
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- set content = message['content'] %}
|
||||
{%- if content is sequence %}
|
||||
{%- for item in content %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_glm4v_style() {
|
||||
// GLM4V uses 'msg' instead of 'message'
|
||||
let template = r#"
|
||||
{%- for msg in messages %}
|
||||
{%- for part in msg.content %}
|
||||
{%- if part.type == 'text' %}{{ part.text }}{%- endif %}
|
||||
{%- if part.type == 'image' %}<image>{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_with_length_check() {
|
||||
// Template that checks content length
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message.content|length > 0 %}
|
||||
{%- for item in message.content %}
|
||||
{{ item.text }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openai_format_with_index_access() {
|
||||
// Template that accesses content by index
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{%- if message.content[0] %}
|
||||
First item: {{ message.content[0].text }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
"#;
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_template_defaults_to_string() {
|
||||
let template = "Not a valid {% jinja template";
|
||||
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_template_defaults_to_string() {
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(""),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template_unit_test() {
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt %}
|
||||
assistant:
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are helpful".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, true)
|
||||
.unwrap();
|
||||
assert!(result.contains("system: You are helpful"));
|
||||
assert!(result.contains("user: Hello"));
|
||||
assert!(result.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens_unit_test() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, false)
|
||||
.unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
314
sgl-router/tests/chat_template_integration.rs
Normal file
314
sgl-router/tests/chat_template_integration.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_simple_chat_template() {
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
<|{{ message.role }}|>{{ message.content }}<|end|>
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt %}
|
||||
<|assistant|>
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, true)
|
||||
.unwrap();
|
||||
assert!(result.contains("<|user|>Test<|end|>"));
|
||||
assert!(result.contains("<|assistant|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values like the router does
|
||||
let message_values: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, false)
|
||||
.unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_style_template() {
|
||||
// Test a Llama-style chat template
|
||||
let template = r#"
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("What is 2+2?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||
assert!(result.contains("You are a helpful assistant"));
|
||||
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||
assert!(result.contains("What is 2+2?"));
|
||||
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chatml_template() {
|
||||
// Test a ChatML-style template
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("How are you?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_without_generation_prompt() {
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_messages_template() {
|
||||
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages: Vec<serde_json::Value> = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_format_detection() {
|
||||
// Test string format detection
|
||||
let string_template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{%- endfor -%}
|
||||
"#;
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(string_template),
|
||||
ChatTemplateContentFormat::String
|
||||
);
|
||||
|
||||
// Test OpenAI format detection
|
||||
let openai_template = r#"
|
||||
{%- for message in messages -%}
|
||||
{%- for content in message.content -%}
|
||||
{{ content.type }}: {{ content.text }}
|
||||
{%- endfor -%}
|
||||
{%- endfor -%}
|
||||
"#;
|
||||
assert_eq!(
|
||||
detect_chat_template_content_format(openai_template),
|
||||
ChatTemplateContentFormat::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_with_multimodal_content() {
|
||||
// Test that multimodal messages work correctly when serialized to JSON
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{ message.role }}:
|
||||
{%- if message.content is string %}
|
||||
{{ message.content }}
|
||||
{%- else %}
|
||||
{%- for part in message.content %}
|
||||
{%- if part.type == "text" %}
|
||||
{{ part.text }}
|
||||
{%- elif part.type == "image_url" %}
|
||||
[IMAGE]
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% endfor %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Parts(vec![
|
||||
spec::ContentPart::Text {
|
||||
text: "Look at this:".to_string(),
|
||||
},
|
||||
spec::ContentPart::ImageUrl {
|
||||
image_url: spec::ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
]),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
// Should contain both text and image parts
|
||||
assert!(result.contains("user:"));
|
||||
assert!(result.contains("Look at this:"));
|
||||
assert!(result.contains("[IMAGE]"));
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_load_chat_template_from_file() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let template_path = temp_dir.path().join("template.jinja");
|
||||
@@ -59,11 +58,28 @@ mod tests {
|
||||
|
||||
// Test that the custom template is used
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there"),
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
];
|
||||
|
||||
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
|
||||
// Convert to JSON values like the router does
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer.apply_chat_template(&json_messages, true).unwrap();
|
||||
|
||||
// Verify the custom template format
|
||||
assert!(result.contains("<|user|>Hello"));
|
||||
@@ -73,9 +89,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_override_existing_template() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
@@ -124,8 +137,21 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
// Should use CUSTOM template, not built-in
|
||||
assert!(result.starts_with("CUSTOM:"));
|
||||
@@ -135,9 +161,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_set_chat_template_after_creation() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory and tokenizer file
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tokenizer_json = r#"{
|
||||
@@ -173,8 +196,31 @@ mod tests {
|
||||
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
|
||||
tokenizer.set_chat_template(new_template.to_string());
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("World".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Convert to JSON values
|
||||
let json_messages: Vec<serde_json::Value> = messages
|
||||
.iter()
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = tokenizer
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("NEW:"));
|
||||
assert!(result.contains("user: Hello;"));
|
||||
@@ -1,150 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_helpers() {
|
||||
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(system_msg.role, "system");
|
||||
assert_eq!(system_msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
assert_eq!(user_msg.content, "Hello!");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
assert_eq!(assistant_msg.content, "Hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_style_template() {
|
||||
// Test a Llama-style chat template
|
||||
let template = r#"
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant"),
|
||||
ChatMessage::user("What is 2+2?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||
assert!(result.contains("You are a helpful assistant"));
|
||||
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||
assert!(result.contains("What is 2+2?"));
|
||||
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chatml_template() {
|
||||
// Test a ChatML-style template
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
ChatMessage::user("How are you?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_without_generation_prompt() {
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_with_special_tokens() {
|
||||
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "<s>Hello</s>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_messages() {
|
||||
let template =
|
||||
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
// Integration test with actual tokenizer file loading would go here
|
||||
// but requires a real tokenizer_config.json file
|
||||
}
|
||||
Reference in New Issue
Block a user