diff --git a/sgl-router/src/tokenizer/chat_template.rs b/sgl-router/src/tokenizer/chat_template.rs index dec38cf59..e82544ca4 100644 --- a/sgl-router/src/tokenizer/chat_template.rs +++ b/sgl-router/src/tokenizer/chat_template.rs @@ -4,7 +4,8 @@ //! similar to HuggingFace transformers' apply_chat_template method. use anyhow::{anyhow, Result}; -use minijinja::{context, machinery, Environment, Value}; +use minijinja::machinery::ast::{Expr, Stmt}; +use minijinja::{context, Environment, Value}; use serde_json; use std::collections::HashMap; @@ -50,13 +51,277 @@ pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateConten ChatTemplateContentFormat::String } +/// Flags tracking which OpenAI-style patterns we've seen +#[derive(Default, Debug, Clone, Copy)] +struct Flags { + saw_iteration: bool, + saw_structure: bool, + saw_assignment: bool, + saw_macro: bool, +} + +impl Flags { + fn any(self) -> bool { + self.saw_iteration || self.saw_structure || self.saw_assignment || self.saw_macro + } +} + +/// Single-pass AST detector with scope tracking +struct Detector<'a> { + ast: &'a Stmt<'a>, + /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`) + scope: std::collections::VecDeque, + scope_set: std::collections::HashSet, + flags: Flags, +} + +impl<'a> Detector<'a> { + fn new(ast: &'a Stmt<'a>) -> Self { + Self { + ast, + scope: std::collections::VecDeque::new(), + scope_set: std::collections::HashSet::new(), + flags: Flags::default(), + } + } + + fn run(mut self) -> Flags { + self.walk_stmt(self.ast); + self.flags + } + + fn push_scope(&mut self, var: String) { + self.scope.push_back(var.clone()); + self.scope_set.insert(var); + } + + fn pop_scope(&mut self) { + if let Some(v) = self.scope.pop_back() { + self.scope_set.remove(&v); + } + } + + fn is_var_access(expr: &Expr, varname: &str) -> bool { + matches!(expr, Expr::Var(v) if v.id == varname) + } + + fn is_const_str(expr: &Expr, value: &str) -> bool { + matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value)) + } + + fn is_numeric_const(expr: &Expr) -> bool { + matches!(expr, Expr::Const(c) if c.value.is_number()) + } + + /// Check if expr is varname.content or varname["content"] + fn is_var_dot_content(expr: &Expr, varname: &str) -> bool { + match expr { + Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content", + Expr::GetItem(g) => { + Self::is_var_access(&g.expr, varname) + && Self::is_const_str(&g.subscript_expr, "content") + } + // Unwrap filters/tests that just wrap the same expr + Expr::Filter(f) => f + .expr + .as_ref() + .is_some_and(|e| Self::is_var_dot_content(e, varname)), + Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname), + _ => false, + } + } + + /// Check if expr accesses .content on any variable in our scope, or any descendant of it. + fn is_any_scope_var_content(&self, expr: &Expr) -> bool { + let mut current_expr = expr; + loop { + // Check if current level matches .content + if self + .scope_set + .iter() + .any(|v| Self::is_var_dot_content(current_expr, v)) + { + return true; + } + // Walk up the expression tree + match current_expr { + Expr::GetAttr(g) => current_expr = &g.expr, + Expr::GetItem(g) => current_expr = &g.expr, + _ => return false, + } + } + } + + fn walk_stmt(&mut self, stmt: &Stmt) { + // Early exit if we've already detected an OpenAI pattern + if self.flags.any() { + return; + } + + match stmt { + Stmt::Template(t) => { + for ch in &t.children { + self.walk_stmt(ch); + } + } + // {% for message in messages %} + Stmt::ForLoop(fl) => { + // Detect "for X in messages" → push X into scope + if let Expr::Var(iter) = &fl.iter { + if iter.id == "messages" { + if let Expr::Var(target) = &fl.target { + self.push_scope(target.id.to_string()); + } + } + } + + // Also detect "for ... in message.content" or "for ... in content" + // - Iterating directly over .content => OpenAI style + if self.is_any_scope_var_content(&fl.iter) { + self.flags.saw_iteration = true; + } + // - Iterating over a local var named "content" + if matches!(&fl.iter, Expr::Var(v) if v.id == "content") { + self.flags.saw_iteration = true; + } + + for b in &fl.body { + self.walk_stmt(b); + } + + // Pop scope if we pushed it + if let Expr::Var(iter) = &fl.iter { + if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) { + self.pop_scope(); + } + } + } + Stmt::IfCond(ic) => { + self.inspect_expr_for_structure(&ic.expr); + for b in &ic.true_body { + self.walk_stmt(b); + } + for b in &ic.false_body { + self.walk_stmt(b); + } + } + Stmt::EmitExpr(e) => { + self.inspect_expr_for_structure(&e.expr); + } + // {% set content = message.content %} + Stmt::Set(s) => { + if Self::is_var_access(&s.target, "content") + && self.is_any_scope_var_content(&s.expr) + { + self.flags.saw_assignment = true; + } + } + Stmt::Macro(m) => { + // Heuristic: macro that checks type (via `is` test) and also has any loop + let mut has_type_check = false; + let mut has_loop = false; + Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop); + if has_type_check && has_loop { + self.flags.saw_macro = true; + } + } + _ => {} + } + } + + fn inspect_expr_for_structure(&mut self, expr: &Expr) { + if self.flags.saw_structure { + return; + } + + match expr { + // content[0] or message.content[0] + Expr::GetItem(gi) => { + if (matches!(&gi.expr, Expr::Var(v) if v.id == "content") + || self.is_any_scope_var_content(&gi.expr)) + && Self::is_numeric_const(&gi.subscript_expr) + { + self.flags.saw_structure = true; + } + } + // content|length or message.content|length + Expr::Filter(f) => { + if f.name == "length" { + if let Some(inner) = &f.expr { + // Box derefs automatically, so `&**inner` is `&Expr` + let inner_ref: &Expr = inner; + let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content"); + if is_content_var || self.is_any_scope_var_content(inner_ref) { + self.flags.saw_structure = true; + } + } + } else if let Some(inner) = &f.expr { + let inner_ref: &Expr = inner; + self.inspect_expr_for_structure(inner_ref); + } + } + // content is sequence/iterable OR message.content is sequence/iterable + Expr::Test(t) => { + if t.name == "sequence" || t.name == "iterable" || t.name == "string" { + if matches!(&t.expr, Expr::Var(v) if v.id == "content") + || self.is_any_scope_var_content(&t.expr) + { + self.flags.saw_structure = true; + } + } else { + self.inspect_expr_for_structure(&t.expr); + } + } + Expr::GetAttr(g) => { + // Keep walking; nested expressions can hide structure checks + self.inspect_expr_for_structure(&g.expr); + } + // Handle binary operations like: if (message.content is string) and other_cond + Expr::BinOp(op) => { + self.inspect_expr_for_structure(&op.left); + self.inspect_expr_for_structure(&op.right); + } + // Handle unary operations like: if not (message.content is string) + Expr::UnaryOp(op) => { + self.inspect_expr_for_structure(&op.expr); + } + _ => {} + } + } + + fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) { + for s in body { + if *has_type_check && *has_loop { + return; + } + + match s { + Stmt::IfCond(ic) => { + if matches!(&ic.expr, Expr::Test(_)) { + *has_type_check = true; + } + Self::scan_macro_body(&ic.true_body, has_type_check, has_loop); + Self::scan_macro_body(&ic.false_body, has_type_check, has_loop); + } + Stmt::ForLoop(fl) => { + *has_loop = true; + Self::scan_macro_body(&fl.body, has_type_check, has_loop); + } + Stmt::Template(t) => { + Self::scan_macro_body(&t.children, has_type_check, has_loop); + } + _ => {} + } + } + } +} + /// AST-based detection using minijinja's unstable machinery -/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions +/// Single-pass detector with scope tracking fn detect_format_with_ast(template: &str) -> Option { use minijinja::machinery::{parse, WhitespaceConfig}; use minijinja::syntax::SyntaxConfig; - // Parse the template into AST let ast = match parse( template, "template", @@ -67,226 +332,12 @@ fn detect_format_with_ast(template: &str) -> Option { Err(_) => return Some(ChatTemplateContentFormat::String), }; - // Traverse AST looking for patterns that indicate OpenAI format - let has_iteration = find_content_iteration_in_ast(&ast); - let has_structure_checks = find_content_structure_checks_in_ast(&ast); - let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast); - - if has_iteration || has_structure_checks || has_assignment_patterns { - Some(ChatTemplateContentFormat::OpenAI) + let flags = Detector::new(&ast).run(); + Some(if flags.any() { + ChatTemplateContentFormat::OpenAI } else { - Some(ChatTemplateContentFormat::String) - } -} - -/// Find content iteration patterns in AST -/// Implements the same logic as SGLang's AST traversal -fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool { - use machinery::ast::Stmt; - - match ast { - Stmt::Template(template) => { - // Recursively check all children - template - .children - .iter() - .any(|child| find_content_iteration_in_ast(child)) - } - Stmt::ForLoop(for_loop) => { - // Check if this for-loop iterates over message content - is_var_or_elems_access(&for_loop.iter, "message", "content") || - is_var_or_elems_access(&for_loop.iter, "msg", "content") || - is_var_or_elems_access(&for_loop.iter, "m", "content") || - // Also check the body for nested loops - for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt)) - } - Stmt::IfCond(if_cond) => { - // Check true and false branches - if_cond - .true_body - .iter() - .any(|stmt| find_content_iteration_in_ast(stmt)) - || if_cond - .false_body - .iter() - .any(|stmt| find_content_iteration_in_ast(stmt)) - } - _ => false, // Other statement types don't contain loops - } -} - -/// Check if expression accesses varname['key'] or varname.key -/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes -fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool { - use machinery::ast::Expr; - - match expr { - // Check for attribute access: varname.key - Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key, - // Check for item access: varname['key'] or varname["key"] - Expr::GetItem(getitem) => { - is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key) - } - // Handle filters and tests that might wrap the access - Expr::Filter(filter) => { - if let Some(ref expr) = filter.expr { - is_var_or_elems_access(expr, varname, key) - } else { - false - } - } - Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key), - _ => false, - } -} - -/// Check if expression is a variable access (like {{ varname }}) -/// Implements SGLang's _is_var_access logic -fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool { - matches!(expr, machinery::ast::Expr::Var(var) if var.id == varname) -} - -/// Check if expression is a constant string with the given value -fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool { - matches!(expr, machinery::ast::Expr::Const(const_expr) - if const_expr.value.as_str() == Some(value)) -} - -/// Find content structure checks in AST (like content[0], content|length) -fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool { - use machinery::ast::Stmt; - - match ast { - Stmt::Template(template) => template - .children - .iter() - .any(|child| find_content_structure_checks_in_ast(child)), - Stmt::ForLoop(for_loop) => for_loop - .body - .iter() - .any(|stmt| find_content_structure_checks_in_ast(stmt)), - Stmt::IfCond(if_cond) => { - // Check if condition has content structure checks - has_content_structure_check_expr(&if_cond.expr) - || if_cond - .true_body - .iter() - .any(|stmt| find_content_structure_checks_in_ast(stmt)) - || if_cond - .false_body - .iter() - .any(|stmt| find_content_structure_checks_in_ast(stmt)) - } - Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr), - _ => false, - } -} - -/// Find variable assignment patterns like set content = message['content'] -fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool { - use machinery::ast::Stmt; - - match ast { - Stmt::Template(template) => template - .children - .iter() - .any(|child| find_variable_assignment_patterns_in_ast(child)), - Stmt::ForLoop(for_loop) => { - // Check if this for-loop body contains both assignment and iteration - let has_assignment = for_loop - .body - .iter() - .any(|stmt| is_content_assignment_stmt(stmt)); - let has_iteration = for_loop.body.iter().any(|stmt| { - is_content_variable_iteration(stmt) - || matches!(stmt, Stmt::IfCond(if_cond) if - if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) || - if_cond.false_body.iter().any(|s| is_content_variable_iteration(s)) - ) - }); - - (has_assignment && has_iteration) - || for_loop - .body - .iter() - .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) - } - Stmt::IfCond(if_cond) => { - if_cond - .true_body - .iter() - .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) - || if_cond - .false_body - .iter() - .any(|stmt| find_variable_assignment_patterns_in_ast(stmt)) - } - _ => false, - } -} - -/// Check if expression has content structure checks (index access, length, etc.) -fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool { - use machinery::ast::Expr; - - match expr { - // Check for content[0] - index access - Expr::GetItem(getitem) => { - is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr) - } - // Check for content|length - filter with length - Expr::Filter(filter) => { - if let Some(ref filter_expr) = filter.expr { - is_content_access(filter_expr) && filter.name == "length" - } else { - false - } - } - // Check for content is sequence/iterable - Expr::Test(test) => { - is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable") - } - _ => false, - } -} - -/// Check if statement assigns message content to a variable -fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool { - use machinery::ast::Stmt; - - match stmt { - Stmt::Set(set_stmt) => { - // Check if this is setting content = message['content'] - is_var_access(&set_stmt.target, "content") - && is_var_or_elems_access(&set_stmt.expr, "message", "content") - } - _ => false, - } -} - -/// Check if statement iterates over content variable -fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool { - use machinery::ast::{Expr, Stmt}; - - match stmt { - Stmt::ForLoop(for_loop) => { - // Check if iterating over a variable named "content" - matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content") - } - _ => false, - } -} - -/// Check if expression accesses content (message.content, message['content'], etc.) -fn is_content_access(expr: &machinery::ast::Expr) -> bool { - is_var_or_elems_access(expr, "message", "content") - || is_var_or_elems_access(expr, "msg", "content") - || is_var_or_elems_access(expr, "m", "content") -} - -/// Check if expression is a numeric constant (for index access) -fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool { - matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number()) + ChatTemplateContentFormat::String + }) } /// Parameters for chat template application diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs index 7a1ffa0fa..145cb8227 100644 --- a/sgl-router/tests/chat_template_format_detection.rs +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -249,3 +249,65 @@ fn test_chat_template_with_tokens_unit_test() { assert!(result.contains("")); assert!(result.contains("")); } + +#[test] +fn test_detect_openai_format_qwen3vl_macro_style() { + // Qwen3-VL style template using macros to handle multimodal content + // This tests the macro-based detection pattern + let template = r#"{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count) %} + {%- if content is string %} + {{- content }} + {%- else %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%} + <|vision_start|><|image_pad|><|vision_end|> + {%- elif 'video' in item or item.type == 'video' %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%} + <|vision_start|><|video_pad|><|vision_end|> + {%- elif 'text' in item %} + {{- item.text }} + {%- endif %} + {%- endfor %} + {%- endif %} +{%- endmacro %} +{%- for message in messages %} + {%- set content = render_content(message.content, True) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %}"#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +} + +#[test] +fn test_detect_openai_format_arbitrary_variable_names() { + // Test that detection works with any variable name, not just "message", "msg", "m" + // Uses "chat_msg" and "x" as loop variables + let template = r#" + {%- for chat_msg in messages %} + {%- for x in chat_msg.content %} + {%- if x.type == 'text' %}{{ x.text }}{%- endif %} + {%- if x.type == 'image' %}{%- endif %} + {%- endfor %} + {%- endfor %} + "#; + + assert_eq!( + detect_chat_template_content_format(template), + ChatTemplateContentFormat::OpenAI + ); +}