router-grpc: Add tools processing and other parameters for apply_chat_template (#10877)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use sglang_router_rs::protocols::spec;
|
||||
use sglang_router_rs::tokenizer::chat_template::{
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||
ChatTemplateProcessor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -14,11 +15,7 @@ fn test_simple_chat_template() {
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
@@ -32,8 +29,12 @@ fn test_simple_chat_template() {
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, true)
|
||||
.apply_chat_template(&message_values, params)
|
||||
.unwrap();
|
||||
assert!(result.contains("<|user|>Test<|end|>"));
|
||||
assert!(result.contains("<|assistant|>"));
|
||||
@@ -41,19 +42,15 @@ fn test_simple_chat_template() {
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_tokens() {
|
||||
// Template that uses special tokens
|
||||
// Template that uses template kwargs for tokens
|
||||
let template = r#"
|
||||
{{ bos_token }}
|
||||
{%- if bos_token -%}{{ bos_token }}{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
||||
{{ message.role }}: {{ message.content }}{%- if eos_token -%}{{ eos_token }}{%- endif -%}
|
||||
{% endfor -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
@@ -67,8 +64,24 @@ fn test_chat_template_with_tokens() {
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
// Use template_kwargs to pass tokens
|
||||
let mut template_kwargs = std::collections::HashMap::new();
|
||||
template_kwargs.insert(
|
||||
"bos_token".to_string(),
|
||||
serde_json::Value::String("<s>".to_string()),
|
||||
);
|
||||
template_kwargs.insert(
|
||||
"eos_token".to_string(),
|
||||
serde_json::Value::String("</s>".to_string()),
|
||||
);
|
||||
|
||||
let params = ChatTemplateParams {
|
||||
template_kwargs: Some(&template_kwargs),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&message_values, false)
|
||||
.apply_chat_template(&message_values, params)
|
||||
.unwrap();
|
||||
assert!(result.contains("<s>"));
|
||||
assert!(result.contains("</s>"));
|
||||
@@ -85,7 +98,7 @@ fn test_llama_style_template() {
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{{- bos_token if bos_token else '<|begin_of_text|>' }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
@@ -99,11 +112,7 @@ fn test_llama_style_template() {
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::System {
|
||||
@@ -124,7 +133,21 @@ fn test_llama_style_template() {
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
// Use template_kwargs to pass the token
|
||||
let mut template_kwargs = std::collections::HashMap::new();
|
||||
template_kwargs.insert(
|
||||
"bos_token".to_string(),
|
||||
serde_json::Value::String("<|begin_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
template_kwargs: Some(&template_kwargs),
|
||||
..Default::default()
|
||||
};
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, params)
|
||||
.unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
@@ -147,7 +170,7 @@ fn test_chatml_template() {
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
@@ -176,7 +199,15 @@ fn test_chatml_template() {
|
||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||
.collect();
|
||||
|
||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
let result = processor
|
||||
.apply_chat_template(
|
||||
&json_messages,
|
||||
ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
@@ -196,7 +227,7 @@ assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
@@ -212,12 +243,20 @@ assistant:
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||
.unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap();
|
||||
let result_with_prompt = processor
|
||||
.apply_chat_template(
|
||||
&json_messages,
|
||||
ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
@@ -225,10 +264,12 @@ assistant:
|
||||
fn test_empty_messages_template() {
|
||||
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages: Vec<serde_json::Value> = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
let result = processor
|
||||
.apply_chat_template(&messages, ChatTemplateParams::default())
|
||||
.unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
@@ -279,7 +320,7 @@ fn test_template_with_multimodal_content() {
|
||||
{% endfor %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
@@ -304,7 +345,7 @@ fn test_template_with_multimodal_content() {
|
||||
.collect();
|
||||
|
||||
let result = processor
|
||||
.apply_chat_template(&json_messages, false)
|
||||
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||
.unwrap();
|
||||
|
||||
// Should contain both text and image parts
|
||||
|
||||
Reference in New Issue
Block a user