router-grpc: Add tools processing and other paramters for apply_chat_template (#10877)
This commit is contained in:
@@ -27,7 +27,7 @@ use crate::tokenizer::traits::Tokenizer;
|
|||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
|
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
// Data structures for processing
|
// Data structures for processing
|
||||||
@@ -300,12 +300,87 @@ impl GrpcRouter {
|
|||||||
{
|
{
|
||||||
// Get content format and transform messages accordingly
|
// Get content format and transform messages accordingly
|
||||||
let content_format = hf_tokenizer.chat_template_content_format();
|
let content_format = hf_tokenizer.chat_template_content_format();
|
||||||
let transformed_messages =
|
let mut transformed_messages =
|
||||||
Self::transform_messages_for_content_format(&request.messages, content_format)?;
|
Self::process_content_format(&request.messages, content_format)?;
|
||||||
|
|
||||||
hf_tokenizer
|
// Process tool call arguments in assistant messages
|
||||||
.apply_chat_template(&transformed_messages, true)
|
Self::process_tool_call_arguments(&mut transformed_messages)?;
|
||||||
.map_err(|e| format!("Failed to apply chat template: {}", e))?
|
|
||||||
|
// Convert tools to JSON values for template processing
|
||||||
|
let tools_json: Option<Vec<serde_json::Value>> = request
|
||||||
|
.tools
|
||||||
|
.as_ref()
|
||||||
|
.map(|tools| {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(serde_json::to_value)
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
})
|
||||||
|
.transpose()
|
||||||
|
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
|
||||||
|
|
||||||
|
// Build template kwargs, merging reasoning_effort if present
|
||||||
|
let mut combined_template_kwargs = std::collections::HashMap::new();
|
||||||
|
|
||||||
|
// Add reasoning_effort if present (like Python does)
|
||||||
|
if let Some(reasoning_effort) = &request.reasoning_effort {
|
||||||
|
combined_template_kwargs.insert(
|
||||||
|
"reasoning_effort".to_string(),
|
||||||
|
serde_json::Value::String(reasoning_effort.clone()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add any additional template kwargs from request
|
||||||
|
if let Some(template_kwargs) = &request.chat_template_kwargs {
|
||||||
|
for (key, value) in template_kwargs {
|
||||||
|
combined_template_kwargs.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_template_kwargs = if combined_template_kwargs.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(&combined_template_kwargs)
|
||||||
|
};
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
continue_final_message: request.continue_final_message,
|
||||||
|
tools: tools_json.as_deref(),
|
||||||
|
template_kwargs: final_template_kwargs,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle assistant prefix for continue_final_message
|
||||||
|
let assistant_prefix = if request.continue_final_message
|
||||||
|
&& !transformed_messages.is_empty()
|
||||||
|
&& transformed_messages
|
||||||
|
.last()
|
||||||
|
.and_then(|msg| msg.get("role"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
== Some("assistant")
|
||||||
|
{
|
||||||
|
// Pop the last message to handle it separately
|
||||||
|
let last_msg = transformed_messages.pop().unwrap();
|
||||||
|
last_msg
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply chat template with the (now possibly shorter) list of messages
|
||||||
|
let rendered = hf_tokenizer
|
||||||
|
.apply_chat_template(&transformed_messages, params)
|
||||||
|
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
|
||||||
|
|
||||||
|
// Append assistant prefix if we have one
|
||||||
|
if let Some(prefix) = assistant_prefix {
|
||||||
|
format!("{}{}", rendered, prefix)
|
||||||
|
} else {
|
||||||
|
rendered
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(
|
return Err(
|
||||||
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
|
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
|
||||||
@@ -322,8 +397,8 @@ impl GrpcRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transform messages based on content format for ANY message type
|
/// Process messages based on content format for ANY message type
|
||||||
fn transform_messages_for_content_format(
|
fn process_content_format(
|
||||||
messages: &[crate::protocols::spec::ChatMessage],
|
messages: &[crate::protocols::spec::ChatMessage],
|
||||||
content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat,
|
content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat,
|
||||||
) -> Result<Vec<serde_json::Value>, String> {
|
) -> Result<Vec<serde_json::Value>, String> {
|
||||||
@@ -394,6 +469,49 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Process tool call arguments in messages
|
||||||
|
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
|
||||||
|
fn process_tool_call_arguments(messages: &mut [serde_json::Value]) -> Result<(), String> {
|
||||||
|
for msg in messages {
|
||||||
|
// Early return if not assistant message
|
||||||
|
let role = msg.get("role").and_then(|v| v.as_str());
|
||||||
|
if role != Some("assistant") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Early return if no tool_calls
|
||||||
|
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut())
|
||||||
|
else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Process each tool call's arguments
|
||||||
|
for call in tool_calls {
|
||||||
|
let Some(function) = call.get_mut("function") else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let Some(args) = function.get_mut("arguments") else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let Some(args_str) = args.as_str() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse JSON string to object (like Python json.loads)
|
||||||
|
match serde_json::from_str::<serde_json::Value>(args_str) {
|
||||||
|
Ok(parsed) => *args = parsed,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(format!(
|
||||||
|
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
|
||||||
|
args_str, e
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Build gRPC SamplingParams from OpenAI request
|
/// Build gRPC SamplingParams from OpenAI request
|
||||||
fn build_grpc_sampling_params(
|
fn build_grpc_sampling_params(
|
||||||
&self,
|
&self,
|
||||||
@@ -410,6 +528,19 @@ impl GrpcRouter {
|
|||||||
.or(request.max_tokens)
|
.or(request.max_tokens)
|
||||||
.map(|v| v as i32);
|
.map(|v| v as i32);
|
||||||
|
|
||||||
|
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
|
||||||
|
let skip_special_tokens = if request.tools.is_some() {
|
||||||
|
match &request.tool_choice {
|
||||||
|
Some(crate::protocols::spec::ToolChoice::Value(
|
||||||
|
crate::protocols::spec::ToolChoiceValue::None,
|
||||||
|
)) => request.skip_special_tokens,
|
||||||
|
Some(_) => false, // tool_choice is not "none"
|
||||||
|
None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
request.skip_special_tokens
|
||||||
|
};
|
||||||
|
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
Ok(proto::SamplingParams {
|
Ok(proto::SamplingParams {
|
||||||
temperature: request.temperature.unwrap_or(1.0),
|
temperature: request.temperature.unwrap_or(1.0),
|
||||||
@@ -422,7 +553,7 @@ impl GrpcRouter {
|
|||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
stop: stop_sequences,
|
stop: stop_sequences,
|
||||||
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
||||||
skip_special_tokens: request.skip_special_tokens,
|
skip_special_tokens,
|
||||||
n: request.n.unwrap_or(1) as i32,
|
n: request.n.unwrap_or(1) as i32,
|
||||||
structural_tag: structural_tag.unwrap_or_default(),
|
structural_tag: structural_tag.unwrap_or_default(),
|
||||||
constraint: self.build_constraint(request)?,
|
constraint: self.build_constraint(request)?,
|
||||||
@@ -700,11 +831,9 @@ mod tests {
|
|||||||
name: None,
|
name: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
let transformed_message = &result[0];
|
let transformed_message = &result[0];
|
||||||
@@ -735,11 +864,9 @@ mod tests {
|
|||||||
name: None,
|
name: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||||
ChatTemplateContentFormat::OpenAI,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
let transformed_message = &result[0];
|
let transformed_message = &result[0];
|
||||||
@@ -764,11 +891,9 @@ mod tests {
|
|||||||
name: None,
|
name: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
let transformed_message = &result[0];
|
let transformed_message = &result[0];
|
||||||
@@ -791,11 +916,9 @@ mod tests {
|
|||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
let transformed_message = &result[0];
|
let transformed_message = &result[0];
|
||||||
@@ -832,11 +955,9 @@ mod tests {
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
|
|
||||||
@@ -862,11 +983,9 @@ mod tests {
|
|||||||
name: None,
|
name: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = GrpcRouter::transform_messages_for_content_format(
|
let result =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
let transformed_message = &result[0];
|
let transformed_message = &result[0];
|
||||||
@@ -902,22 +1021,18 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
// Test String format
|
// Test String format
|
||||||
let result_string = GrpcRouter::transform_messages_for_content_format(
|
let result_string =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
|
||||||
ChatTemplateContentFormat::String,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result_string.len(), 2);
|
assert_eq!(result_string.len(), 2);
|
||||||
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
|
||||||
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
|
||||||
|
|
||||||
// Test OpenAI format
|
// Test OpenAI format
|
||||||
let result_openai = GrpcRouter::transform_messages_for_content_format(
|
let result_openai =
|
||||||
&messages,
|
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
|
||||||
ChatTemplateContentFormat::OpenAI,
|
.unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result_openai.len(), 2);
|
assert_eq!(result_openai.len(), 2);
|
||||||
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
|
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use minijinja::{context, machinery, Environment, Value};
|
use minijinja::{context, machinery, Environment, Value};
|
||||||
use serde_json;
|
use serde_json;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// Chat template content format
|
/// Chat template content format
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@@ -288,21 +289,25 @@ fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
|
|||||||
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
|
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parameters for chat template application
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ChatTemplateParams<'a> {
|
||||||
|
pub add_generation_prompt: bool,
|
||||||
|
pub continue_final_message: bool,
|
||||||
|
pub tools: Option<&'a [serde_json::Value]>,
|
||||||
|
pub documents: Option<&'a [serde_json::Value]>,
|
||||||
|
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
|
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
|
||||||
pub struct ChatTemplateProcessor {
|
pub struct ChatTemplateProcessor {
|
||||||
template: String,
|
template: String,
|
||||||
bos_token: Option<String>,
|
|
||||||
eos_token: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatTemplateProcessor {
|
impl ChatTemplateProcessor {
|
||||||
/// Create a new chat template processor
|
/// Create a new chat template processor
|
||||||
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
pub fn new(template: String) -> Self {
|
||||||
ChatTemplateProcessor {
|
ChatTemplateProcessor { template }
|
||||||
template,
|
|
||||||
bos_token,
|
|
||||||
eos_token,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply the chat template to a list of messages
|
/// Apply the chat template to a list of messages
|
||||||
@@ -313,8 +318,12 @@ impl ChatTemplateProcessor {
|
|||||||
pub fn apply_chat_template(
|
pub fn apply_chat_template(
|
||||||
&self,
|
&self,
|
||||||
messages: &[serde_json::Value],
|
messages: &[serde_json::Value],
|
||||||
add_generation_prompt: bool,
|
params: ChatTemplateParams,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
|
// Validate incompatible options
|
||||||
|
if params.continue_final_message && params.add_generation_prompt {
|
||||||
|
return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."));
|
||||||
|
}
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
|
|
||||||
// Register the template
|
// Register the template
|
||||||
@@ -326,17 +335,29 @@ impl ChatTemplateProcessor {
|
|||||||
.get_template("chat")
|
.get_template("chat")
|
||||||
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
|
||||||
|
|
||||||
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
|
// Convert messages to minijinja::Value (messages already processed by router)
|
||||||
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
|
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
|
||||||
|
|
||||||
// Render the template directly with the provided values
|
let base_context = context! {
|
||||||
|
messages => &minijinja_messages,
|
||||||
|
add_generation_prompt => params.add_generation_prompt,
|
||||||
|
tools => params.tools,
|
||||||
|
documents => params.documents,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Merge with template_kwargs if provided
|
||||||
|
let ctx = if let Some(kwargs) = params.template_kwargs {
|
||||||
|
context! {
|
||||||
|
..base_context,
|
||||||
|
..Value::from_serialize(kwargs)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
base_context
|
||||||
|
};
|
||||||
|
|
||||||
|
// Render the template
|
||||||
let rendered = tmpl
|
let rendered = tmpl
|
||||||
.render(context! {
|
.render(&ctx)
|
||||||
messages => minijinja_messages,
|
|
||||||
add_generation_prompt => add_generation_prompt,
|
|
||||||
bos_token => self.bos_token.clone().unwrap_or_default(),
|
|
||||||
eos_token => self.eos_token.clone().unwrap_or_default()
|
|
||||||
})
|
|
||||||
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
|
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
|
||||||
|
|
||||||
Ok(rendered)
|
Ok(rendered)
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ use anyhow::{Error, Result};
|
|||||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||||
|
|
||||||
use super::chat_template::{
|
use super::chat_template::{
|
||||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||||
|
ChatTemplateProcessor,
|
||||||
};
|
};
|
||||||
use super::traits::{
|
use super::traits::{
|
||||||
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
|
||||||
@@ -165,16 +166,11 @@ impl HuggingFaceTokenizer {
|
|||||||
pub fn apply_chat_template(
|
pub fn apply_chat_template(
|
||||||
&self,
|
&self,
|
||||||
messages: &[serde_json::Value],
|
messages: &[serde_json::Value],
|
||||||
add_generation_prompt: bool,
|
params: ChatTemplateParams,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
if let Some(ref template) = self.chat_template {
|
if let Some(ref template) = self.chat_template {
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.clone());
|
||||||
template.clone(),
|
processor.apply_chat_template(messages, params)
|
||||||
self.special_tokens.bos_token.clone(),
|
|
||||||
self.special_tokens.eos_token.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
processor.apply_chat_template(messages, add_generation_prompt)
|
|
||||||
} else {
|
} else {
|
||||||
Err(Error::msg(
|
Err(Error::msg(
|
||||||
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
|
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use sglang_router_rs::protocols::spec;
|
use sglang_router_rs::protocols::spec;
|
||||||
use sglang_router_rs::tokenizer::chat_template::{
|
use sglang_router_rs::tokenizer::chat_template::{
|
||||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||||
|
ChatTemplateProcessor,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -169,11 +170,7 @@ assistant:
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
template.to_string(),
|
|
||||||
Some("<s>".to_string()),
|
|
||||||
Some("</s>".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
spec::ChatMessage::System {
|
spec::ChatMessage::System {
|
||||||
@@ -194,8 +191,12 @@ assistant:
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&message_values, true)
|
.apply_chat_template(&message_values, params)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(result.contains("system: You are helpful"));
|
assert!(result.contains("system: You are helpful"));
|
||||||
assert!(result.contains("user: Hello"));
|
assert!(result.contains("user: Hello"));
|
||||||
@@ -204,19 +205,15 @@ assistant:
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_with_tokens_unit_test() {
|
fn test_chat_template_with_tokens_unit_test() {
|
||||||
// Template that uses special tokens
|
// Template that uses template kwargs for tokens (more realistic)
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{{ bos_token }}
|
{%- if start_token -%}{{ start_token }}{%- endif -%}
|
||||||
{%- for message in messages -%}
|
{%- for message in messages -%}
|
||||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
{{ message.role }}: {{ message.content }}{%- if end_token -%}{{ end_token }}{%- endif -%}
|
||||||
{% endfor -%}
|
{% endfor -%}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
template.to_string(),
|
|
||||||
Some("<s>".to_string()),
|
|
||||||
Some("</s>".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let messages = [spec::ChatMessage::User {
|
let messages = [spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -230,8 +227,24 @@ fn test_chat_template_with_tokens_unit_test() {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Use template_kwargs to pass tokens
|
||||||
|
let mut template_kwargs = std::collections::HashMap::new();
|
||||||
|
template_kwargs.insert(
|
||||||
|
"start_token".to_string(),
|
||||||
|
serde_json::Value::String("<s>".to_string()),
|
||||||
|
);
|
||||||
|
template_kwargs.insert(
|
||||||
|
"end_token".to_string(),
|
||||||
|
serde_json::Value::String("</s>".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
template_kwargs: Some(&template_kwargs),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&message_values, false)
|
.apply_chat_template(&message_values, params)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(result.contains("<s>"));
|
assert!(result.contains("<s>"));
|
||||||
assert!(result.contains("</s>"));
|
assert!(result.contains("</s>"));
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use sglang_router_rs::protocols::spec;
|
use sglang_router_rs::protocols::spec;
|
||||||
use sglang_router_rs::tokenizer::chat_template::{
|
use sglang_router_rs::tokenizer::chat_template::{
|
||||||
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
|
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
|
||||||
|
ChatTemplateProcessor,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -14,11 +15,7 @@ fn test_simple_chat_template() {
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
template.to_string(),
|
|
||||||
Some("<s>".to_string()),
|
|
||||||
Some("</s>".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let messages = [spec::ChatMessage::User {
|
let messages = [spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -32,8 +29,12 @@ fn test_simple_chat_template() {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&message_values, true)
|
.apply_chat_template(&message_values, params)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(result.contains("<|user|>Test<|end|>"));
|
assert!(result.contains("<|user|>Test<|end|>"));
|
||||||
assert!(result.contains("<|assistant|>"));
|
assert!(result.contains("<|assistant|>"));
|
||||||
@@ -41,19 +42,15 @@ fn test_simple_chat_template() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_with_tokens() {
|
fn test_chat_template_with_tokens() {
|
||||||
// Template that uses special tokens
|
// Template that uses template kwargs for tokens
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{{ bos_token }}
|
{%- if bos_token -%}{{ bos_token }}{%- endif -%}
|
||||||
{%- for message in messages -%}
|
{%- for message in messages -%}
|
||||||
{{ message.role }}: {{ message.content }}{{ eos_token }}
|
{{ message.role }}: {{ message.content }}{%- if eos_token -%}{{ eos_token }}{%- endif -%}
|
||||||
{% endfor -%}
|
{% endfor -%}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
template.to_string(),
|
|
||||||
Some("<s>".to_string()),
|
|
||||||
Some("</s>".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let messages = [spec::ChatMessage::User {
|
let messages = [spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -67,8 +64,24 @@ fn test_chat_template_with_tokens() {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Use template_kwargs to pass tokens
|
||||||
|
let mut template_kwargs = std::collections::HashMap::new();
|
||||||
|
template_kwargs.insert(
|
||||||
|
"bos_token".to_string(),
|
||||||
|
serde_json::Value::String("<s>".to_string()),
|
||||||
|
);
|
||||||
|
template_kwargs.insert(
|
||||||
|
"eos_token".to_string(),
|
||||||
|
serde_json::Value::String("</s>".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
template_kwargs: Some(&template_kwargs),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&message_values, false)
|
.apply_chat_template(&message_values, params)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(result.contains("<s>"));
|
assert!(result.contains("<s>"));
|
||||||
assert!(result.contains("</s>"));
|
assert!(result.contains("</s>"));
|
||||||
@@ -85,7 +98,7 @@ fn test_llama_style_template() {
|
|||||||
{%- set system_message = '' -%}
|
{%- set system_message = '' -%}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
|
|
||||||
{{- bos_token }}
|
{{- bos_token if bos_token else '<|begin_of_text|>' }}
|
||||||
{%- if system_message %}
|
{%- if system_message %}
|
||||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
@@ -99,11 +112,7 @@ fn test_llama_style_template() {
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
template.to_string(),
|
|
||||||
Some("<|begin_of_text|>".to_string()),
|
|
||||||
Some("<|end_of_text|>".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
spec::ChatMessage::System {
|
spec::ChatMessage::System {
|
||||||
@@ -124,7 +133,21 @@ fn test_llama_style_template() {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
// Use template_kwargs to pass the token
|
||||||
|
let mut template_kwargs = std::collections::HashMap::new();
|
||||||
|
template_kwargs.insert(
|
||||||
|
"bos_token".to_string(),
|
||||||
|
serde_json::Value::String("<|begin_of_text|>".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
template_kwargs: Some(&template_kwargs),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let result = processor
|
||||||
|
.apply_chat_template(&json_messages, params)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Check that the result contains expected markers
|
// Check that the result contains expected markers
|
||||||
assert!(result.contains("<|begin_of_text|>"));
|
assert!(result.contains("<|begin_of_text|>"));
|
||||||
@@ -147,7 +170,7 @@ fn test_chatml_template() {
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
spec::ChatMessage::User {
|
spec::ChatMessage::User {
|
||||||
@@ -176,7 +199,15 @@ fn test_chatml_template() {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = processor.apply_chat_template(&json_messages, true).unwrap();
|
let result = processor
|
||||||
|
.apply_chat_template(
|
||||||
|
&json_messages,
|
||||||
|
ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Check ChatML format
|
// Check ChatML format
|
||||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||||
@@ -196,7 +227,7 @@ assistant:
|
|||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
|
|
||||||
let messages = [spec::ChatMessage::User {
|
let messages = [spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -212,12 +243,20 @@ assistant:
|
|||||||
|
|
||||||
// Test without generation prompt
|
// Test without generation prompt
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&json_messages, false)
|
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.trim(), "user: Test");
|
assert_eq!(result.trim(), "user: Test");
|
||||||
|
|
||||||
// Test with generation prompt
|
// Test with generation prompt
|
||||||
let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap();
|
let result_with_prompt = processor
|
||||||
|
.apply_chat_template(
|
||||||
|
&json_messages,
|
||||||
|
ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
assert!(result_with_prompt.contains("assistant:"));
|
assert!(result_with_prompt.contains("assistant:"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,10 +264,12 @@ assistant:
|
|||||||
fn test_empty_messages_template() {
|
fn test_empty_messages_template() {
|
||||||
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
|
|
||||||
let messages: Vec<serde_json::Value> = vec![];
|
let messages: Vec<serde_json::Value> = vec![];
|
||||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
let result = processor
|
||||||
|
.apply_chat_template(&messages, ChatTemplateParams::default())
|
||||||
|
.unwrap();
|
||||||
assert_eq!(result, "");
|
assert_eq!(result, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,7 +320,7 @@ fn test_template_with_multimodal_content() {
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||||
|
|
||||||
let messages = [spec::ChatMessage::User {
|
let messages = [spec::ChatMessage::User {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
@@ -304,7 +345,7 @@ fn test_template_with_multimodal_content() {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = processor
|
let result = processor
|
||||||
.apply_chat_template(&json_messages, false)
|
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Should contain both text and image parts
|
// Should contain both text and image parts
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use sglang_router_rs::protocols::spec;
|
use sglang_router_rs::protocols::spec;
|
||||||
|
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
|
||||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
@@ -79,7 +80,14 @@ mod tests {
|
|||||||
.map(|msg| serde_json::to_value(msg).unwrap())
|
.map(|msg| serde_json::to_value(msg).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = tokenizer.apply_chat_template(&json_messages, true).unwrap();
|
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
|
||||||
|
let params = ChatTemplateParams {
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let result = tokenizer
|
||||||
|
.apply_chat_template(&json_messages, params)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Verify the custom template format
|
// Verify the custom template format
|
||||||
assert!(result.contains("<|user|>Hello"));
|
assert!(result.contains("<|user|>Hello"));
|
||||||
@@ -150,7 +158,7 @@ mod tests {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = tokenizer
|
let result = tokenizer
|
||||||
.apply_chat_template(&json_messages, false)
|
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Should use CUSTOM template, not built-in
|
// Should use CUSTOM template, not built-in
|
||||||
@@ -219,7 +227,7 @@ mod tests {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = tokenizer
|
let result = tokenizer
|
||||||
.apply_chat_template(&json_messages, false)
|
.apply_chat_template(&json_messages, ChatTemplateParams::default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.starts_with("NEW:"));
|
assert!(result.starts_with("NEW:"));
|
||||||
|
|||||||
Reference in New Issue
Block a user