Files
sglang/python/sglang/srt/entrypoints/openai/utils.py

208 lines
7.6 KiB
Python

import logging
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.entrypoints.openai.protocol import LogProbs
logger = logging.getLogger(__name__)
# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True when ``node`` is a plain read of the variable ``varname``.

    Matches template expressions such as ``{{ varname }}``: the node must be
    a ``jinja2.nodes.Name`` whose context is ``"load"`` and whose name equals
    ``varname``. Any other node type yields False.
    """
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True when ``node`` reads ``key`` from the variable ``varname``.

    Two spellings are recognised:
    - subscript access, ``{{ varname['key'] }}`` (a ``Getitem`` node whose
      argument is a constant equal to ``key``)
    - dotted access, ``{{ varname.key }}`` (a ``Getattr`` node whose attribute
      equals ``key``)

    In both cases the object being accessed must itself be a plain read of
    ``varname``. Every other node type yields False.
    """
    is_subscript = isinstance(node, jinja2.nodes.Getitem)
    is_dotted = isinstance(node, jinja2.nodes.Getattr)

    if not (is_subscript or is_dotted):
        return False

    # Both access forms require the target to be the bare variable.
    if not _is_var_access(node.node, varname):
        return False

    if is_subscript:
        subscript = node.arg
        return isinstance(subscript, jinja2.nodes.Const) and subscript.value == key

    return node.attr == key
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str = None,
) -> bool:
    """Return True when ``node`` ultimately reads ``varname`` or ``varname[key]``.

    Filters, tests and slice subscripts wrapped around the expression are
    peeled off first, so ``message['content'] | reverse`` or
    ``message['content'][1:]`` are treated the same as ``message['content']``.

    When ``key`` is truthy the innermost expression must be an attribute or
    item access of ``key`` on ``varname``; otherwise it must be a plain read
    of ``varname``.
    """
    inner = node
    while True:
        if isinstance(inner, jinja2.nodes.Filter):
            # A filter node may carry no operand; that can never match.
            if inner.node is None:
                return False
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Test):
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Getitem) and isinstance(
            inner.arg, jinja2.nodes.Slice
        ):
            # Slicing keeps the same underlying sequence, so look through it.
            inner = inner.node
        else:
            break

    if key:
        return _is_attr_access(inner, varname, key)
    return _is_var_access(inner, varname)
def _try_extract_ast(chat_template: str):
    """Parse ``chat_template`` into a Jinja AST, or return None on failure.

    The template is first compiled through the ``transformers`` chat-template
    helper, and the resulting template's environment is then used to parse
    the source. Any exception raised along the way is logged at debug level
    and reported to the caller as None.
    """
    try:
        compiled_template = hf_chat_utils._compile_jinja_template(chat_template)
        parsed = compiled_template.environment.parse(chat_template)
    except Exception as e:
        logger.debug(f"Error when compiling Jinja template: {e}")
        return None
    return parsed
def detect_template_content_format(chat_template: str) -> str:
    """
    Classify a chat template as expecting 'string' or 'openai' content.

    - 'openai': the template contains a ``for`` loop that iterates over
      ``message['content']`` / ``message.content`` (optionally wrapped in
      filters, tests or slices), i.e. content is a list of structured parts.
    - 'string': everything else, including templates that cannot be parsed
      or whose AST inspection raises.

    Args:
        chat_template: Jinja source of the chat template.

    Returns:
        Either "openai" or "string".
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return "string"

    try:
        # e.g. {%- for content in message['content'] -%}
        iterates_over_content = any(
            _is_var_or_elems_access(for_node.iter, "message", "content")
            for for_node in jinja_ast.find_all(jinja2.nodes.For)
        )
    except Exception as e:
        logger.debug(f"Error when parsing AST of Jinja template: {e}")
        return "string"

    return "openai" if iterates_over_content else "string"
def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
    """
    Reshape a message's content to suit the detected template format.

    Args:
        msg_dict: Message dictionary; its "content" may be a string, None or
            a list of content parts.
        content_format: "openai" keeps a structured list of parts; any other
            value is handled as "string" and flattens the parts to text.
        image_data: Output list; image URLs found in "openai" mode are
            appended to it.
        audio_data: Output list; audio URLs found in "openai" mode are
            appended to it.
        modalities: Output list; truthy "modalities" values of image parts
            found in "openai" mode are appended to it.

    Returns:
        A new message dictionary with None-valued entries removed and
        "content" converted as described above. ``msg_dict`` itself is not
        modified.
    """
    raw_content = msg_dict.get("content")

    # Content that is not a list (string or None) only needs None-pruning.
    if not isinstance(raw_content, list):
        return {key: value for key, value in msg_dict.items() if value is not None}

    if content_format != "openai":
        # String format: keep only the text parts, joined by single spaces.
        # Image/audio parts are ignored here because the template does not
        # expect structured content.
        texts = [
            part["text"]
            for part in raw_content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        flattened = dict(msg_dict)
        flattened["content"] = " ".join(texts)
        return {key: value for key, value in flattened.items() if value is not None}

    # OpenAI format: keep a list of parts, replacing media parts with bare
    # placeholders and collecting their URLs into the output lists.
    normalized_parts = []
    for part in raw_content:
        if not isinstance(part, dict):
            # Entries that are not dicts are left out of the result.
            continue

        part_type = part.get("type")
        if part_type == "image_url":
            image_data.append(part["image_url"]["url"])
            part_modalities = part.get("modalities")
            if part_modalities:
                modalities.append(part_modalities)
            normalized_parts.append({"type": "image"})
        elif part_type == "audio_url":
            audio_data.append(part["audio_url"]["url"])
            normalized_parts.append({"type": "audio"})
        else:
            # Text and any other part types pass through untouched.
            normalized_parts.append(part)

    result = {
        key: value
        for key, value in msg_dict.items()
        if key != "content" and value is not None
    }
    result["content"] = normalized_parts
    return result
def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
    input_top_logprobs=None,
    output_top_logprobs=None,
):
    """Collect token and top-k logprob data into a single ``LogProbs`` object.

    Each ``*_token_logprobs`` argument, when given, is an iterable of
    3-item entries read as ``(logprob, <ignored>, token_text)``. Each
    ``*_top_logprobs`` argument, when given, is an iterable whose items are
    either None or an iterable of entries indexed the same way (index 0 is
    the logprob, index 2 the token text).

    Input entries are always added before output entries. Arguments left as
    None contribute nothing.

    Returns:
        The populated ``LogProbs`` instance.
    """
    ret_logprobs = LogProbs()

    for per_token in (input_token_logprobs, output_token_logprobs):
        if per_token is None:
            continue
        for logprob, _, token_text in per_token:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)
            # Text offsets are not supported yet; -1 is stored as a placeholder.
            ret_logprobs.text_offset.append(-1)

    for per_position in (input_top_logprobs, output_top_logprobs):
        if per_position is None:
            continue
        for candidates in per_position:
            if candidates is None:
                ret_logprobs.top_logprobs.append(None)
            else:
                # Map token text -> logprob for this position.
                ret_logprobs.top_logprobs.append(
                    {candidate[2]: candidate[0] for candidate in candidates}
                )

    return ret_logprobs