bugfix(OAI): Fix image_data processing for jinja chat templates (#6877)
This commit is contained in:
@@ -75,6 +75,10 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
TopLogprob,
|
TopLogprob,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.openai_api.utils import (
|
||||||
|
detect_template_content_format,
|
||||||
|
process_content_for_template_format,
|
||||||
|
)
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.utils import convert_json_schema_to_str, get_exception_traceback
|
from sglang.utils import convert_json_schema_to_str, get_exception_traceback
|
||||||
|
|
||||||
@@ -82,6 +86,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
chat_template_name = None
|
chat_template_name = None
|
||||||
|
|
||||||
|
# Global cache for template content format detection (one model/template per instance)
|
||||||
|
# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
|
||||||
|
_cached_chat_template = None
|
||||||
|
_cached_template_format = None
|
||||||
|
|
||||||
|
|
||||||
class FileMetadata:
|
class FileMetadata:
|
||||||
def __init__(self, filename: str, purpose: str):
|
def __init__(self, filename: str, purpose: str):
|
||||||
@@ -1000,23 +1009,42 @@ def v1_chat_generate_request(
|
|||||||
|
|
||||||
if chat_template_name is None:
|
if chat_template_name is None:
|
||||||
openai_compatible_messages = []
|
openai_compatible_messages = []
|
||||||
|
image_data = []
|
||||||
|
audio_data = []
|
||||||
|
modalities = []
|
||||||
|
|
||||||
|
# Detect template content format by analyzing the jinja template (cached globally)
|
||||||
|
global _cached_chat_template, _cached_template_format
|
||||||
|
current_template = tokenizer_manager.tokenizer.chat_template
|
||||||
|
|
||||||
|
if current_template != _cached_chat_template:
|
||||||
|
# Template changed or first time - analyze it
|
||||||
|
_cached_chat_template = current_template
|
||||||
|
_cached_template_format = detect_template_content_format(
|
||||||
|
current_template
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Detected chat template content format: {_cached_template_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
template_content_format = _cached_template_format
|
||||||
|
|
||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
if message.content is None:
|
if message.content is None:
|
||||||
message.content = ""
|
message.content = ""
|
||||||
msg_dict = message.dict()
|
msg_dict = message.model_dump()
|
||||||
if isinstance(msg_dict.get("content"), list):
|
|
||||||
for chunk in msg_dict["content"]:
|
# Process content based on detected template format
|
||||||
if isinstance(chunk, dict) and chunk.get("type") == "text":
|
processed_msg = process_content_for_template_format(
|
||||||
new_msg = msg_dict.copy()
|
msg_dict,
|
||||||
new_msg["content"] = chunk["text"]
|
template_content_format,
|
||||||
new_msg = {
|
image_data,
|
||||||
k: v for k, v in new_msg.items() if v is not None
|
audio_data,
|
||||||
}
|
modalities,
|
||||||
openai_compatible_messages.append(new_msg)
|
)
|
||||||
else:
|
openai_compatible_messages.append(processed_msg)
|
||||||
msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
|
|
||||||
openai_compatible_messages.append(msg_dict)
|
# Handle assistant prefix for continue_final_message
|
||||||
if (
|
if (
|
||||||
openai_compatible_messages
|
openai_compatible_messages
|
||||||
and openai_compatible_messages[-1]["role"] == "assistant"
|
and openai_compatible_messages[-1]["role"] == "assistant"
|
||||||
@@ -1070,9 +1098,9 @@ def v1_chat_generate_request(
|
|||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
|
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
image_data = None
|
image_data = image_data if image_data else None
|
||||||
audio_data = None
|
audio_data = audio_data if audio_data else None
|
||||||
modalities = []
|
modalities = modalities if modalities else []
|
||||||
else:
|
else:
|
||||||
conv = generate_chat_conv(request, chat_template_name)
|
conv = generate_chat_conv(request, chat_template_name)
|
||||||
# If we should continue the final assistant message, adjust the conversation.
|
# If we should continue the final assistant message, adjust the conversation.
|
||||||
|
|||||||
172
python/sglang/srt/openai_api/utils.py
Normal file
172
python/sglang/srt/openai_api/utils.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for OpenAI API adapter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import jinja2.nodes
|
||||||
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# JINJA TEMPLATE CONTENT FORMAT DETECTION
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# This adapts vLLM's approach for detecting chat template content format:
|
||||||
|
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
|
||||||
|
# - Analyzes Jinja template AST to detect content iteration patterns
|
||||||
|
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
|
||||||
|
# - 'string' format: templates that expect simple string content
|
||||||
|
# - Processes content accordingly to match template expectations
|
||||||
|
|
||||||
|
|
||||||
|
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Tell whether ``node`` is a plain read of the variable ``varname``.

    Only a ``jinja2.nodes.Name`` node in ``load`` context whose name equals
    ``varname`` qualifies (the AST form of ``{{ varname }}``); every other
    node type yields ``False``.
    """
    if not isinstance(node, jinja2.nodes.Name):
        return False
    return node.name == varname and node.ctx == "load"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Tell whether ``node`` reads ``key`` from the variable ``varname``.

    Both spellings are recognised: dotted access (``varname.key``, a
    ``Getattr`` node) and subscript access with a constant key
    (``varname['key']``, a ``Getitem`` node whose argument is a ``Const``).
    """
    if isinstance(node, jinja2.nodes.Getattr):
        # Dotted form: varname.key
        return _is_var_access(node.node, varname) and node.attr == key

    if isinstance(node, jinja2.nodes.Getitem):
        # Subscript form: varname['key'] - the subscript must be a constant.
        subscript = node.arg
        return (
            _is_var_access(node.node, varname)
            and isinstance(subscript, jinja2.nodes.Const)
            and subscript.value == key
        )

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str = None,
) -> bool:
    """Tell whether ``node`` ultimately reads ``varname`` (or ``varname[key]``).

    Wrapping filters (``x | f``), tests (``x is t``) and slices (``x[a:b]``)
    are peeled off first; the innermost expression is then matched against
    ``varname[key]`` / ``varname.key`` when ``key`` is truthy, or against the
    bare variable ``varname`` otherwise.
    """
    inner = node
    while True:
        if isinstance(inner, jinja2.nodes.Filter):
            # A filter without an operand cannot reference the variable.
            if inner.node is None:
                return False
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Test):
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Getitem) and isinstance(
            inner.arg, jinja2.nodes.Slice
        ):
            inner = inner.node
        else:
            break

    if key:
        return _is_attr_access(inner, varname, key)
    return _is_var_access(inner, varname)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_extract_ast(chat_template: str):
    """Parse ``chat_template`` into a Jinja AST, or return ``None`` on failure.

    The template is compiled through the Hugging Face helper
    ``_compile_jinja_template`` and then parsed with the environment of the
    compiled template. Any exception raised along the way is logged at
    debug level and reported to the caller as ``None``.
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception as e:
        logger.debug(f"Error when compiling Jinja template: {e}")
        return None
    return template_ast
|
||||||
|
|
||||||
|
|
||||||
|
def detect_template_content_format(chat_template: str) -> str:
    """
    Detect whether a chat template expects 'string' or 'openai' content format.

    - 'string': content is a simple string (like DeepSeek templates)
    - 'openai': content is a list of structured dicts (like Llama4 templates)

    Detection logic:
    - Collect the names used for a single message: ``message`` is always
      considered, plus the loop variable of every ``for`` loop that iterates
      over ``messages`` (e.g. ``msg`` in ``{%- for msg in messages -%}``).
    - If any ``for`` loop iterates over ``<name>['content']`` or
      ``<name>.content`` for one of those names → 'openai'
    - Otherwise → 'string'

    Args:
        chat_template: Source text of the Jinja chat template.

    Returns:
        ``"openai"`` or ``"string"``. Templates that cannot be parsed or
        analysed fall back to ``"string"``.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return "string"

    try:
        # Materialise the loops once; they are scanned twice below.
        loops = list(jinja_ast.find_all(jinja2.nodes.For))

        # Names under which a single message is visible inside the template.
        # "message" is always included so that every template detected as
        # 'openai' before is still detected as 'openai'.
        message_vars = {"message"}
        for loop_ast in loops:
            target = loop_ast.target
            # Tuple-unpacking targets are skipped: they do not name a message.
            if isinstance(target, jinja2.nodes.Name) and _is_var_or_elems_access(
                loop_ast.iter, "messages"
            ):
                message_vars.add(target.name)

        # Look for patterns like: {%- for content in message['content'] -%}
        for loop_ast in loops:
            loop_iter = loop_ast.iter
            if any(
                _is_var_or_elems_access(loop_iter, varname, "content")
                for varname in message_vars
            ):
                return "openai"  # Found content iteration → openai format

        return "string"  # No content loops found → string format
    except Exception as e:
        # Best effort: never let template analysis break request handling.
        logger.debug("Error when parsing AST of Jinja template: %s", e)
        return "string"
|
||||||
|
|
||||||
|
|
||||||
|
def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
    """
    Shape one message's content the way the chat template expects it.

    Args:
        msg_dict: Message dictionary; its ``content`` may be a string or a
            list of content parts.
        content_format: ``'openai'`` keeps a structured part list; any other
            value is handled as ``'string'`` and flattens to text.
        image_data: Output list; image URLs found in ``'openai'`` mode are
            appended to it.
        audio_data: Output list; audio URLs found in ``'openai'`` mode are
            appended to it.
        modalities: Output list; the truthy ``modalities`` value of an image
            part is appended to it in ``'openai'`` mode.

    Returns:
        A new message dictionary without ``None``-valued entries. The input
        dictionary itself is not modified.
    """
    content = msg_dict.get("content")
    if not isinstance(content, list):
        # Plain string (or missing) content: only strip the None entries.
        return {k: v for k, v in msg_dict.items() if v is not None}

    if content_format != "openai":
        # String format: keep only the text parts, joined by single spaces.
        # Image/audio parts are ignored here because such templates have no
        # slot for structured content.
        texts = [
            part["text"]
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        flattened = dict(msg_dict)
        flattened["content"] = " ".join(texts)
        return {k: v for k, v in flattened.items() if v is not None}

    # OpenAI format: keep the list structure, harvest media URLs and replace
    # each media part by a bare type marker the template can match on.
    parts = []
    for part in content:
        if not isinstance(part, dict):
            continue  # non-dict parts are not carried over

        part_type = part.get("type")
        if part_type == "image_url":
            image_data.append(part["image_url"]["url"])
            part_modalities = part.get("modalities")
            if part_modalities:
                modalities.append(part_modalities)
            parts.append({"type": "image"})
        elif part_type == "audio_url":
            audio_data.append(part["audio_url"]["url"])
            parts.append({"type": "audio"})
        else:
            # Text and any other part types pass through untouched.
            parts.append(part)

    result = {k: v for k, v in msg_dict.items() if v is not None and k != "content"}
    result["content"] = parts
    return result
|
||||||
@@ -56,6 +56,7 @@ suites = {
|
|||||||
TestFile("test_mla_fp8.py", 93),
|
TestFile("test_mla_fp8.py", 93),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
TestFile("test_no_overlap_scheduler.py", 234),
|
TestFile("test_no_overlap_scheduler.py", 234),
|
||||||
|
TestFile("test_openai_adapter.py", 1),
|
||||||
TestFile("test_openai_function_calling.py", 60),
|
TestFile("test_openai_function_calling.py", 60),
|
||||||
TestFile("test_openai_server.py", 149),
|
TestFile("test_openai_server.py", 149),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
|
|||||||
225
test/srt/test_openai_adapter.py
Normal file
225
test/srt/test_openai_adapter.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for OpenAI adapter utils.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from sglang.srt.openai_api.utils import (
|
||||||
|
detect_template_content_format,
|
||||||
|
process_content_for_template_format,
|
||||||
|
)
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestTemplateContentFormatDetection(CustomTestCase):
    """Checks for chat-template format detection and message content processing."""

    @staticmethod
    def _process(msg_dict, content_format):
        """Run the processor with fresh collector lists.

        Returns a ``(result, image_data, audio_data, modalities)`` tuple.
        """
        image_data, audio_data, modalities = [], [], []
        result = process_content_for_template_format(
            msg_dict, content_format, image_data, audio_data, modalities
        )
        return result, image_data, audio_data, modalities

    def test_detect_llama4_openai_format(self):
        """A llama4-style template that loops over content parts is 'openai'."""
        llama4_pattern = """
{%- for message in messages %}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<|image|>' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] | trim }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
"""

        self.assertEqual(detect_template_content_format(llama4_pattern), "openai")

    def test_detect_deepseek_string_format(self):
        """A deepseek-style template that concatenates content is 'string'."""
        deepseek_pattern = """
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{{- '<|User|>' + message['content'] + '<|Assistant|>' }}
{%- endif %}
{%- endfor %}
"""

        self.assertEqual(detect_template_content_format(deepseek_pattern), "string")

    def test_detect_invalid_template(self):
        """A template with invalid syntax falls back to 'string'."""
        invalid_pattern = "{{{{ invalid jinja syntax }}}}"

        self.assertEqual(detect_template_content_format(invalid_pattern), "string")

    def test_detect_empty_template(self):
        """An empty template falls back to 'string'."""
        self.assertEqual(detect_template_content_format(""), "string")

    def test_process_content_openai_format(self):
        """'openai' mode extracts image URLs and normalizes image parts."""
        first_text = {"type": "text", "text": "Look at this image:"}
        second_text = {"type": "text", "text": "What do you see?"}
        msg_dict = {
            "role": "user",
            "content": [
                dict(first_text),
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
                dict(second_text),
            ],
        }

        result, image_data, _, _ = self._process(msg_dict, "openai")

        # The image URL was collected.
        self.assertEqual(len(image_data), 1)
        self.assertEqual(image_data[0], "http://example.com/image.jpg")

        # The image part was reduced to a bare marker; text parts are untouched.
        expected_content = [first_text, {"type": "image"}, second_text]
        self.assertEqual(result["content"], expected_content)
        self.assertEqual(result["role"], "user")

    def test_process_content_string_format(self):
        """'string' mode flattens text parts and ignores images."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
                {"type": "text", "text": "world"},
            ],
        }

        result, image_data, _, _ = self._process(msg_dict, "string")

        # Text parts are joined into one string.
        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")

        # No image data is collected in string mode.
        self.assertEqual(len(image_data), 0)

    def test_process_content_with_audio(self):
        """'openai' mode extracts audio URLs and normalizes audio parts."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Listen to this:"},
                {
                    "type": "audio_url",
                    "audio_url": {"url": "http://example.com/audio.mp3"},
                },
            ],
        }

        result, _, audio_data, _ = self._process(msg_dict, "openai")

        # The audio URL was collected.
        self.assertEqual(len(audio_data), 1)
        self.assertEqual(audio_data[0], "http://example.com/audio.mp3")

        # The audio part was reduced to a bare marker.
        expected_content = [
            {"type": "text", "text": "Listen to this:"},
            {"type": "audio"},
        ]
        self.assertEqual(result["content"], expected_content)

    def test_process_content_already_string(self):
        """String content passes through unchanged."""
        msg_dict = {"role": "user", "content": "Hello world"}

        result, image_data, _, _ = self._process(msg_dict, "openai")

        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")
        self.assertEqual(len(image_data), 0)

    def test_process_content_with_modalities(self):
        """The 'modalities' field of an image part is collected."""
        msg_dict = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                    "modalities": ["vision"],
                }
            ],
        }

        _, _, _, modalities = self._process(msg_dict, "openai")

        self.assertEqual(len(modalities), 1)
        self.assertEqual(modalities[0], ["vision"])

    def test_process_content_filter_none_values(self):
        """Entries whose value is None are dropped from the result."""
        msg_dict = {
            "role": "user",
            "content": "Hello",
            "name": None,
            "tool_call_id": None,
        }

        result, _, _, _ = self._process(msg_dict, "string")

        expected_keys = {"role", "content"}
        self.assertEqual(set(result.keys()), expected_keys)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user