From 8b2474898b6c5edd98e8b46504d01772a6819e6b Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 5 Jun 2025 13:37:01 -0700 Subject: [PATCH] bugfix(OAI): Fix image_data processing for jinja chat templates (#6877) --- python/sglang/srt/openai_api/adapter.py | 60 +++++-- python/sglang/srt/openai_api/utils.py | 172 ++++++++++++++++++ test/srt/run_suite.py | 1 + test/srt/test_openai_adapter.py | 225 ++++++++++++++++++++++++ 4 files changed, 442 insertions(+), 16 deletions(-) create mode 100644 python/sglang/srt/openai_api/utils.py create mode 100644 test/srt/test_openai_adapter.py diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 0c8a9a972..5158febdc 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -75,6 +75,10 @@ from sglang.srt.openai_api.protocol import ( TopLogprob, UsageInfo, ) +from sglang.srt.openai_api.utils import ( + detect_template_content_format, + process_content_for_template_format, +) from sglang.srt.reasoning_parser import ReasoningParser from sglang.utils import convert_json_schema_to_str, get_exception_traceback @@ -82,6 +86,11 @@ logger = logging.getLogger(__name__) chat_template_name = None +# Global cache for template content format detection (one model/template per instance) +# NOTE: A better approach would be to initialize the chat template format when the endpoint is created +_cached_chat_template = None +_cached_template_format = None + class FileMetadata: def __init__(self, filename: str, purpose: str): @@ -1000,23 +1009,42 @@ def v1_chat_generate_request( if chat_template_name is None: openai_compatible_messages = [] + image_data = [] + audio_data = [] + modalities = [] + + # Detect template content format by analyzing the jinja template (cached globally) + global _cached_chat_template, _cached_template_format + current_template = tokenizer_manager.tokenizer.chat_template + + if current_template != _cached_chat_template: + # Template changed or first time - analyze it + _cached_chat_template = current_template + _cached_template_format = detect_template_content_format( + current_template + ) + logger.info( + f"Detected chat template content format: {_cached_template_format}" + ) + + template_content_format = _cached_template_format for message in request.messages: if message.content is None: message.content = "" - msg_dict = message.dict() - if isinstance(msg_dict.get("content"), list): - for chunk in msg_dict["content"]: - if isinstance(chunk, dict) and chunk.get("type") == "text": - new_msg = msg_dict.copy() - new_msg["content"] = chunk["text"] - new_msg = { - k: v for k, v in new_msg.items() if v is not None - } - openai_compatible_messages.append(new_msg) - else: - msg_dict = {k: v for k, v in msg_dict.items() if v is not None} - openai_compatible_messages.append(msg_dict) + msg_dict = message.model_dump() + + # Process content based on detected template format + processed_msg = process_content_for_template_format( + msg_dict, + template_content_format, + image_data, + audio_data, + modalities, + ) + openai_compatible_messages.append(processed_msg) + + # Handle assistant prefix for continue_final_message if ( openai_compatible_messages and openai_compatible_messages[-1]["role"] == "assistant" @@ -1070,9 +1098,9 @@ def v1_chat_generate_request( if is_multimodal: prompt = tokenizer_manager.tokenizer.decode(prompt_ids) stop = request.stop - image_data = None - audio_data = None - modalities = [] + image_data = image_data if image_data else None + audio_data = audio_data if audio_data else None + modalities = modalities if modalities else [] else: conv = generate_chat_conv(request, chat_template_name) # If we should continue the final assistant message, adjust the conversation. diff --git a/python/sglang/srt/openai_api/utils.py b/python/sglang/srt/openai_api/utils.py new file mode 100644 index 000000000..610251aff --- /dev/null +++ b/python/sglang/srt/openai_api/utils.py @@ -0,0 +1,172 @@ +""" +Utility functions for OpenAI API adapter. +""" + +import logging +from typing import Dict, List + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils + +logger = logging.getLogger(__name__) + +# ============================================================================ +# JINJA TEMPLATE CONTENT FORMAT DETECTION +# ============================================================================ +# +# This adapts vLLM's approach for detecting chat template content format: +# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313 +# - Analyzes Jinja template AST to detect content iteration patterns +# - 'openai' format: templates with {%- for content in message['content'] -%} loops +# - 'string' format: templates that expect simple string content +# - Processes content accordingly to match template expectations + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + """Check if node is a variable access like {{ varname }}""" + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}""" + if isinstance(node, jinja2.nodes.Getitem): + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: str = None, +) -> bool: + """Check if node accesses varname or varname[key] with filters/tests""" + if isinstance(node, jinja2.nodes.Filter): + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): + return _is_var_or_elems_access(node.node, varname, key) + + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) + + +def _try_extract_ast(chat_template: str): + """Try to parse the Jinja template into an AST""" + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception as e: + logger.debug(f"Error when compiling Jinja template: {e}") + return None + + +def detect_template_content_format(chat_template: str) -> str: + """ + Detect whether a chat template expects 'string' or 'openai' content format. + + - 'string': content is a simple string (like DeepSeek templates) + - 'openai': content is a list of structured dicts (like Llama4 templates) + + Detection logic: + - If template has loops like {%- for content in message['content'] -%} → 'openai' + - Otherwise → 'string' + """ + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return "string" + + try: + # Look for patterns like: {%- for content in message['content'] -%} + for loop_ast in jinja_ast.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + + # Check if iterating over message['content'] or similar + if _is_var_or_elems_access(loop_iter, "message", "content"): + return "openai" # Found content iteration → openai format + + return "string" # No content loops found → string format + except Exception as e: + logger.debug(f"Error when parsing AST of Jinja template: {e}") + return "string" + + +def process_content_for_template_format( + msg_dict: dict, + content_format: str, + image_data: list, + audio_data: list, + modalities: list, +) -> dict: + """ + Process message content based on detected template format. + + Args: + msg_dict: Message dictionary with content + content_format: 'string' or 'openai' (detected via AST analysis) + image_data: List to append extracted image URLs + audio_data: List to append extracted audio URLs + modalities: List to append modalities + + Returns: + Processed message dictionary + """ + if not isinstance(msg_dict.get("content"), list): + # Already a string or None, no processing needed + return {k: v for k, v in msg_dict.items() if v is not None} + + if content_format == "openai": + # OpenAI format: preserve structured content list, normalize types + processed_content_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict): + chunk_type = chunk.get("type") + + if chunk_type == "image_url": + image_data.append(chunk["image_url"]["url"]) + if chunk.get("modalities"): + modalities.append(chunk.get("modalities")) + # Normalize to simple 'image' type for template compatibility + processed_content_parts.append({"type": "image"}) + elif chunk_type == "audio_url": + audio_data.append(chunk["audio_url"]["url"]) + # Normalize to simple 'audio' type + processed_content_parts.append({"type": "audio"}) + else: + # Keep other content as-is (text, etc.) + processed_content_parts.append(chunk) + + new_msg = { + k: v for k, v in msg_dict.items() if v is not None and k != "content" + } + new_msg["content"] = processed_content_parts + return new_msg + + else: # content_format == "string" + # String format: flatten to text only (for templates like DeepSeek) + text_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict) and chunk.get("type") == "text": + text_parts.append(chunk["text"]) + # Note: For string format, we ignore images/audio since the template + # doesn't expect structured content - multimodal placeholders would + # need to be inserted differently + + new_msg = msg_dict.copy() + new_msg["content"] = " ".join(text_parts) if text_parts else "" + new_msg = {k: v for k, v in new_msg.items() if v is not None} + return new_msg diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 197cf3349..83fde313f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -56,6 +56,7 @@ suites = { TestFile("test_mla_fp8.py", 93), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 234), + TestFile("test_openai_adapter.py", 1), TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_server.py", 149), TestFile("test_penalty.py", 41), diff --git a/test/srt/test_openai_adapter.py b/test/srt/test_openai_adapter.py new file mode 100644 index 000000000..598ddfd49 --- /dev/null +++ b/test/srt/test_openai_adapter.py @@ -0,0 +1,225 @@ +""" +Unit tests for OpenAI adapter utils. +""" + +import unittest +from unittest.mock import patch + +from sglang.srt.openai_api.utils import ( + detect_template_content_format, + process_content_for_template_format, +) +from sglang.test.test_utils import CustomTestCase + + +class TestTemplateContentFormatDetection(CustomTestCase): + """Test template content format detection functionality.""" + + def test_detect_llama4_openai_format(self): + """Test detection of llama4-style template (should be 'openai' format).""" + llama4_pattern = """ +{%- for message in messages %} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} +{%- endfor %} + """ + + result = detect_template_content_format(llama4_pattern) + self.assertEqual(result, "openai") + + def test_detect_deepseek_string_format(self): + """Test detection of deepseek-style template (should be 'string' format).""" + deepseek_pattern = """ +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {{- '<|User|>' + message['content'] + '<|Assistant|>' }} + {%- endif %} +{%- endfor %} + """ + + result = detect_template_content_format(deepseek_pattern) + self.assertEqual(result, "string") + + def test_detect_invalid_template(self): + """Test handling of invalid template (should default to 'string').""" + invalid_pattern = "{{{{ invalid jinja syntax }}}}" + + result = detect_template_content_format(invalid_pattern) + self.assertEqual(result, "string") + + def test_detect_empty_template(self): + """Test handling of empty template (should default to 'string').""" + result = detect_template_content_format("") + self.assertEqual(result, "string") + + def test_process_content_openai_format(self): + """Test content processing for openai format.""" + msg_dict = { + "role": "user", + "content": [ + {"type": "text", "text": "Look at this image:"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + }, + {"type": "text", "text": "What do you see?"}, + ], + } + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "openai", image_data, audio_data, modalities + ) + + # Check that image_data was extracted + self.assertEqual(len(image_data), 1) + self.assertEqual(image_data[0], "http://example.com/image.jpg") + + # Check that content was normalized + expected_content = [ + {"type": "text", "text": "Look at this image:"}, + {"type": "image"}, # normalized from image_url + {"type": "text", "text": "What do you see?"}, + ] + self.assertEqual(result["content"], expected_content) + self.assertEqual(result["role"], "user") + + def test_process_content_string_format(self): + """Test content processing for string format.""" + msg_dict = { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + }, + {"type": "text", "text": "world"}, + ], + } + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "string", image_data, audio_data, modalities + ) + + # For string format, should flatten to text only + self.assertEqual(result["content"], "Hello world") + self.assertEqual(result["role"], "user") + + # Image data should not be extracted for string format + self.assertEqual(len(image_data), 0) + + def test_process_content_with_audio(self): + """Test content processing with audio content.""" + msg_dict = { + "role": "user", + "content": [ + {"type": "text", "text": "Listen to this:"}, + { + "type": "audio_url", + "audio_url": {"url": "http://example.com/audio.mp3"}, + }, + ], + } + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "openai", image_data, audio_data, modalities + ) + + # Check that audio_data was extracted + self.assertEqual(len(audio_data), 1) + self.assertEqual(audio_data[0], "http://example.com/audio.mp3") + + # Check that content was normalized + expected_content = [ + {"type": "text", "text": "Listen to this:"}, + {"type": "audio"}, # normalized from audio_url + ] + self.assertEqual(result["content"], expected_content) + + def test_process_content_already_string(self): + """Test processing content that's already a string.""" + msg_dict = {"role": "user", "content": "Hello world"} + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "openai", image_data, audio_data, modalities + ) + + # Should pass through unchanged + self.assertEqual(result["content"], "Hello world") + self.assertEqual(result["role"], "user") + self.assertEqual(len(image_data), 0) + + def test_process_content_with_modalities(self): + """Test content processing with modalities field.""" + msg_dict = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + "modalities": ["vision"], + } + ], + } + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "openai", image_data, audio_data, modalities + ) + + # Check that modalities was extracted + self.assertEqual(len(modalities), 1) + self.assertEqual(modalities[0], ["vision"]) + + def test_process_content_filter_none_values(self): + """Test that None values are filtered out of processed messages.""" + msg_dict = { + "role": "user", + "content": "Hello", + "name": None, + "tool_call_id": None, + } + + image_data = [] + audio_data = [] + modalities = [] + + result = process_content_for_template_format( + msg_dict, "string", image_data, audio_data, modalities + ) + + # None values should be filtered out + expected_keys = {"role", "content"} + self.assertEqual(set(result.keys()), expected_keys) + + +if __name__ == "__main__": + unittest.main()