diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 985596292..636bb4f1b 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -148,7 +148,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
-| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
+| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. | None |
## Data parallelism
diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py
index 49d59b6f7..9c3008572 100644
--- a/python/sglang/srt/configs/__init__.py
+++ b/python/sglang/srt/configs/__init__.py
@@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.step3_vl import (
+ Step3TextConfig,
+ Step3VisionEncoderConfig,
+ Step3VLConfig,
+)
__all__ = [
"ExaoneConfig",
@@ -14,4 +19,7 @@ __all__ = [
"MultiModalityConfig",
"KimiVLConfig",
"MoonViTConfig",
+ "Step3VLConfig",
+ "Step3TextConfig",
+ "Step3VisionEncoderConfig",
]
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 37722c492..37fbf07c7 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -335,6 +335,8 @@ class ModelConfig:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
+ # For Step3
+ "num_attention_groups",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
@@ -644,6 +646,7 @@ multimodal_model_archs = [
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
+ "Step3VLForConditionalGeneration",
]
diff --git a/python/sglang/srt/configs/step3_vl.py b/python/sglang/srt/configs/step3_vl.py
new file mode 100644
index 000000000..5519605c6
--- /dev/null
+++ b/python/sglang/srt/configs/step3_vl.py
@@ -0,0 +1,172 @@
+from typing import Any, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Step3VisionEncoderConfig(PretrainedConfig):
+ model_type = "step3_vision_encoder"
+
+ def __init__(
+ self,
+ hidden_size=1792,
+ intermediate_size=3072,
+ output_hidden_size=4096,
+ num_hidden_layers=63,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=728,
+ patch_size=14,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.output_hidden_size = output_hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ super().__init__(**kwargs)
+
+
+class Step3TextConfig(PretrainedConfig):
+ model_type = "step3_text"
+ architectures = ["Step3TextForCausalLM"]
+
+ def __init__(
+ self,
+ hidden_size: int = 7168,
+ intermediate_size: int = 18432,
+ num_attention_heads: int = 64,
+ num_attention_groups: int = 1,
+ num_hidden_layers: int = 61,
+ max_seq_len: int = 65536,
+ vocab_size: int = 128815,
+ rms_norm_eps: float = 1e-5,
+ moe_intermediate_size: int = 5120,
+ moe_num_experts: int = 48,
+ moe_top_k: int = 3,
+ rope_theta: float = 500000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embedding: int = 65536,
+ share_expert_dim: int = 5120,
+ share_q_dim: int = 2048,
+ head_dim: int = 256,
+ norm_expert_weight: bool = False,
+ moe_layers_enum: tuple[int] = (
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ ),
+ **kwargs,
+ ) -> None:
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_attention_groups = num_attention_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.max_seq_len = max_seq_len
+ self.vocab_size = vocab_size
+ self.rms_norm_eps = rms_norm_eps
+ self.moe_intermediate_size = moe_intermediate_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_top_k = moe_top_k
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.max_position_embedding = max_position_embedding
+ self.share_expert_dim = share_expert_dim
+ self.share_q_dim = share_q_dim
+ self.head_dim = head_dim
+ self.norm_expert_weight = norm_expert_weight
+ self.moe_layers_enum = moe_layers_enum
+
+ super().__init__(**kwargs)
+
+
+class Step3VLConfig(PretrainedConfig):
+ model_type = "step3_vl"
+
+ def __init__(
+ self,
+ vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
+ text_config: Optional[Union[dict, Step3TextConfig]] = None,
+ understand_projector_stride: int = 1,
+ projector_bias: bool = True,
+ image_token_id: int = 128001,
+ **kwargs,
+ ) -> None:
+ if vision_config is None:
+ vision_config = Step3VisionEncoderConfig()
+ elif isinstance(vision_config, dict):
+ vision_config = Step3VisionEncoderConfig(**vision_config)
+ self.vision_config = vision_config
+
+ if text_config is None:
+ text_config = Step3TextConfig()
+ elif isinstance(text_config, dict):
+ text_config = Step3TextConfig(**text_config)
+ self.text_config = text_config
+
+ self.understand_projector_stride = understand_projector_stride
+ self.projector_bias = projector_bias
+ self.hidden_size = text_config.hidden_size
+ self.image_token_id = image_token_id
+
+ super().__init__(**kwargs)
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 81e406eb7..c34527591 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -994,6 +994,23 @@ register_conv_template(
)
)
+register_conv_template(
+ Conversation(
+ name="step3-vl",
+ system_message="<|begin▁of▁sentence|>You are a helpful assistant",
+ system_template="{system_message}\n",
+ roles=(
+ "<|BOT|>user\n",
+ "<|BOT|>assistant\n\n",
+ ),
+ sep="<|EOT|>",
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
+ stop_str="<|EOT|>",
+        image_token="<im_patch>",
+ # add_bos=True,
+ )
+)
+
@register_conv_template_matching_function
def match_internvl(model_path: str):
@@ -1103,3 +1120,9 @@ def match_vila(model_path: str):
def match_mimo_vl(model_path: str):
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
return "mimo-vl"
+
+
+# @register_conv_template_matching_function
+# def match_step3(model_path: str):
+# if re.search(r"step3", model_path, re.IGNORECASE):
+# return "step3-vl"
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index bf6a3d959..6f6403de0 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
+from sglang.srt.function_call.step3_detector import Step3Detector
logger = logging.getLogger(__name__)
@@ -39,6 +40,7 @@ class FunctionCallParser:
"kimi_k2": KimiK2Detector,
"qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
+ "step3": Step3Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
diff --git a/python/sglang/srt/function_call/step3_detector.py b/python/sglang/srt/function_call/step3_detector.py
new file mode 100644
index 000000000..b46f4544f
--- /dev/null
+++ b/python/sglang/srt/function_call/step3_detector.py
@@ -0,0 +1,436 @@
+import ast
+import json
+import logging
+import re
+from typing import Any, Dict, List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> str:
+ """Get the expected type for a function argument from tool schema."""
+ name2tool = {tool.function.name: tool for tool in defined_tools}
+ if func_name not in name2tool:
+ return None
+ tool = name2tool[func_name]
+ parameters = tool.function.parameters or {}
+ properties = parameters.get("properties", {})
+ if arg_key not in properties:
+ return None
+ return properties[arg_key].get("type", None)
+
+
+def parse_arguments(value: str) -> tuple[Any, bool]:
+ """Parse a string value to appropriate type. Returns (parsed_value, success)."""
+ try:
+ try:
+ parsed_value = json.loads(value)
+ except:
+ parsed_value = ast.literal_eval(value)
+ return parsed_value, True
+ except:
+ return value, False
+
+
+class Step3Detector(BaseFormatDetector):
+ """
+ Detector for Step3 model function call format.
+
+ The Step3 format uses special Unicode tokens to delimit function calls
+ with steptml XML format for invocations.
+
+ Format Structure:
+ ```
+ <|tool_calls_begin|>
+    <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="func_name">
+    <steptml:parameter name="param1">value1</steptml:parameter>
+    <steptml:parameter name="param2">value2</steptml:parameter>
+    </steptml:invoke><|tool_call_end|>
+ <|tool_calls_end|>
+ ```
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "<|tool_calls_begin|>"
+ self.eot_token = "<|tool_calls_end|>"
+ self.tool_call_begin = "<|tool_call_begin|>"
+ self.tool_call_end = "<|tool_call_end|>"
+ self.tool_sep = "<|tool_sep|>"
+
+ # Regex for parsing steptml invocations
+        self.invoke_regex = re.compile(
+            r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL
+        )
+        self.param_regex = re.compile(
+            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
+            re.DOTALL,
+        )
+
+ # Streaming state variables
+ self._in_tool_block: bool = False
+ self._tool_block_finished: bool = False
+ self._current_function_name: str = ""
+ self._current_parameters: Dict[str, Any] = {}
+ self._in_tool_call: bool = False
+ self._function_name_sent: bool = False
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Step3 format tool call."""
+ return self.bot_token in text
+
+ def _parse_steptml_invoke(
+ self, text: str, tools: List[Tool] = None
+ ) -> tuple[str, dict]:
+ """Parse steptml invoke format to extract function name and parameters."""
+ invoke_match = self.invoke_regex.search(text)
+ if not invoke_match:
+ return None, {}
+
+ func_name = invoke_match.group(1)
+ params_text = invoke_match.group(2)
+
+ params = {}
+ for param_match in self.param_regex.finditer(params_text):
+ param_name = param_match.group(1)
+ param_value = param_match.group(2).strip()
+
+ # If tools provided, use schema-aware parsing
+ if tools:
+ arg_type = get_argument_type(func_name, param_name, tools)
+ if arg_type and arg_type != "string":
+ parsed_value, _ = parse_arguments(param_value)
+ params[param_name] = parsed_value
+ else:
+ params[param_name] = param_value
+ else:
+ # Fallback to generic parsing if no tools provided
+ parsed_value, _ = parse_arguments(param_value)
+ params[param_name] = parsed_value
+
+ return func_name, params
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+ """
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ try:
+ pre_text, rest = text.split(self.bot_token, 1)
+
+ # If no end token, return everything as normal text
+ if self.eot_token not in rest:
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ tool_section, post_text = rest.split(self.eot_token, 1)
+
+ # Find all individual tool calls using regex
+ calls = []
+ tool_call_pattern = (
+ f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
+ )
+
+ for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
+ call_content = match.group(1)
+
+ # Check if it's a function call
+ if self.tool_sep not in call_content:
+ continue
+
+ type_part, invoke_part = call_content.split(self.tool_sep, 1)
+ if type_part.strip() != "function":
+ continue
+
+ func_name, params = self._parse_steptml_invoke(invoke_part, tools)
+ if func_name:
+ # Use parse_base_json to create the ToolCallItem
+ action = {"name": func_name, "arguments": params}
+ calls.extend(self.parse_base_json(action, tools))
+
+ # Combine pre and post text
+ normal_text = pre_text + post_text
+
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # Return the original text if parsing fails
+ return StreamingParseResult(normal_text=text)
+
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for Step3 format.
+ """
+ self._buffer += new_text
+
+ # Build tool indices for validation
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ # If we've finished the tool block, everything is normal text
+ if self._tool_block_finished:
+ normal_text = self._buffer
+ self._buffer = ""
+ return StreamingParseResult(normal_text=normal_text)
+
+ # Check if tool block hasn't started yet
+ if not self._in_tool_block:
+ if self.bot_token in self._buffer:
+ idx = self._buffer.find(self.bot_token)
+ normal_text = self._buffer[:idx]
+ self._buffer = self._buffer[idx + len(self.bot_token) :]
+ self._in_tool_block = True
+ return StreamingParseResult(normal_text=normal_text)
+ else:
+ # Check if we might have a partial bot_token
+ partial_len = self._ends_with_partial_token(
+ self._buffer, self.bot_token
+ )
+ if partial_len:
+ return StreamingParseResult() # Wait for more text
+ else:
+ normal_text = self._buffer
+ self._buffer = ""
+ return StreamingParseResult(normal_text=normal_text)
+
+ # We're inside the tool block
+ calls: List[ToolCallItem] = []
+
+ # Check if tool block is ending
+ if self.eot_token in self._buffer:
+ idx = self._buffer.find(self.eot_token)
+
+ # If we're in the middle of a tool call, we need to handle it
+ if self._in_tool_call:
+ # The buffer before eot_token might contain the end of the current tool call
+ before_eot = self._buffer[:idx]
+ if self.tool_call_end in before_eot:
+ # Parse this final tool call
+ result = self._parse_partial_tool_call(tools)
+ calls.extend(result.calls)
+ else:
+ # Incomplete tool call - log warning
+ logger.warning("Tool block ended with incomplete tool call")
+
+ remaining = self._buffer[idx + len(self.eot_token) :]
+ self._buffer = ""
+ self._tool_block_finished = True
+
+ # Reset any partial tool call state
+ self._reset_streaming_state()
+
+ return StreamingParseResult(normal_text=remaining, calls=calls)
+
+ # Check if we're in a tool call or need to start one
+ if not self._in_tool_call:
+ if self.tool_call_begin in self._buffer:
+ idx = self._buffer.find(self.tool_call_begin)
+ # Remove any content before tool call begin (shouldn't happen but be safe)
+ self._buffer = self._buffer[idx + len(self.tool_call_begin) :]
+ self._in_tool_call = True
+ self._function_name_sent = False
+ self._current_function_name = ""
+ self._current_parameters = {}
+ # Fall through to parse the partial tool call
+ else:
+ # Wait for tool call to begin
+ return StreamingParseResult()
+
+ # Parse partial tool call
+ if self._in_tool_call:
+ return self._parse_partial_tool_call(tools)
+
+ return StreamingParseResult()
+
+ def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
+ """Parse partial tool call for streaming scenarios."""
+ calls = []
+
+ # Check if we have tool_sep (means we're past the type declaration)
+ if self.tool_sep not in self._buffer:
+ return StreamingParseResult(calls=calls) # Wait for more text
+
+ type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
+ if type_part.strip() != "function":
+ # Invalid tool type, skip this tool call
+ self._reset_streaming_state()
+ return StreamingParseResult(calls=calls)
+
+ # Try to extract function name if not sent yet
+ if not self._function_name_sent:
+            name_match = re.search(r'<steptml:invoke name="([^"]+)">', invoke_part)
+ if name_match:
+ func_name = name_match.group(1)
+
+ # Validate function name
+ if func_name in self._tool_indices:
+ self._current_function_name = func_name
+ self._function_name_sent = True
+
+ # Initialize tool tracking
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+
+ # Ensure tracking arrays are large enough
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Store tool call info
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+
+ # Send tool name with empty parameters
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ else:
+ # Invalid function name
+ logger.warning(f"Invalid function name: {func_name}")
+ self._reset_streaming_state()
+ return StreamingParseResult(calls=calls)
+ else:
+ # Function name not complete yet
+ return StreamingParseResult(calls=calls)
+
+ # Parse parameters incrementally
+ if self._function_name_sent:
+ # Extract all complete parameters
+ new_params = {}
+ for param_match in self.param_regex.finditer(invoke_part):
+ param_name = param_match.group(1)
+ param_value = param_match.group(2).strip()
+
+ # Use schema-aware parsing
+ arg_type = get_argument_type(
+ self._current_function_name, param_name, tools
+ )
+ if arg_type and arg_type != "string":
+ parsed_value, _ = parse_arguments(param_value)
+ new_params[param_name] = parsed_value
+ else:
+ new_params[param_name] = param_value
+
+ # Check if we have new parameters to stream
+ if new_params != self._current_parameters:
+ # Build the JSON content without the closing brace for streaming
+ if not self._current_parameters:
+ # First parameters - send opening brace and content
+ params_content = json.dumps(new_params, ensure_ascii=False)
+ if len(params_content) > 2: # More than just "{}"
+ # Send everything except the closing brace
+ diff = params_content[:-1]
+ else:
+ diff = "{"
+ else:
+ # Subsequent parameters - calculate the incremental diff
+ old_json = json.dumps(self._current_parameters, ensure_ascii=False)
+ new_json = json.dumps(new_params, ensure_ascii=False)
+
+ # Remove closing braces for comparison
+ old_without_brace = old_json[:-1]
+ new_without_brace = new_json[:-1]
+
+ # The new content should extend the old content
+ if new_without_brace.startswith(old_without_brace):
+ diff = new_without_brace[len(old_without_brace) :]
+ else:
+ # Parameters changed in unexpected way - shouldn't happen in normal streaming
+ diff = ""
+
+ if diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ parameters=diff,
+ )
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += diff
+
+ # Update current state
+ self._current_parameters = new_params
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
+
+ # Check if tool call is complete
+ if self.tool_call_end in self._buffer:
+ # Send closing brace if we've sent any parameters
+ if self.streamed_args_for_tool[self.current_tool_id]:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ parameters="}",
+ )
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += "}"
+
+ # Find the end position
+ end_idx = self._buffer.find(self.tool_call_end)
+ # Remove the processed tool call from buffer
+ self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
+
+ # Reset state for next tool call
+ self._reset_streaming_state()
+ self.current_tool_id += 1
+
+ return StreamingParseResult(calls=calls)
+
+ def _reset_streaming_state(self):
+ """Reset streaming state for the next tool call"""
+ self._in_tool_call = False
+ self._function_name_sent = False
+ self._current_function_name = ""
+ self._current_parameters = {}
+
+ def supports_structural_tag(self) -> bool:
+ """Return True if this detector supports structural tag format."""
+ return False
+
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError()
+
+ def build_ebnf(self, tools: List[Tool]) -> str:
+ """
+ Build EBNF grammar for Step3 tool call format.
+ """
+ # Custom call rule for steptml format
+        call_rule_fmt = (
+            '"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
+            '{arguments_rule} "</steptml:invoke>"'
+        )
+
+        # Custom key-value rule for steptml parameters
+        key_value_rule_fmt = (
+            '"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
+        )
+
+ return EBNFComposer.build_ebnf(
+ tools,
+ sequence_start_token=self.bot_token,
+ sequence_end_token=self.eot_token,
+ individual_call_start_token=self.tool_call_begin,
+ individual_call_end_token=self.tool_call_end,
+ tool_call_separator="",
+ function_format="xml",
+ call_rule_fmt=call_rule_fmt,
+ key_value_rule_fmt=key_value_rule_fmt,
+ key_value_separator="",
+ )
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 7c056acdd..bf16addc5 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -41,6 +41,7 @@ from sglang.srt.configs import (
ExaoneConfig,
KimiVLConfig,
MultiModalityConfig,
+ Step3VLConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
InternVLChatConfig.model_type: InternVLChatConfig,
+ Step3VLConfig.model_type: Step3VLConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
diff --git a/python/sglang/srt/jinja_template_utils.py b/python/sglang/srt/jinja_template_utils.py
index 9a944c994..ac55699dc 100644
--- a/python/sglang/srt/jinja_template_utils.py
+++ b/python/sglang/srt/jinja_template_utils.py
@@ -165,7 +165,7 @@ def process_content_for_template_format(
new_msg["content"] = processed_content_parts
return new_msg
- else: # content_format == "string"
+ elif content_format == "string":
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
@@ -179,3 +179,6 @@ def process_content_for_template_format(
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
+
+ else:
+ raise ValueError(f"Invalid content format: {content_format}")
diff --git a/python/sglang/srt/managers/template_manager.py b/python/sglang/srt/managers/template_manager.py
index 4684bf1a0..e340f65f0 100644
--- a/python/sglang/srt/managers/template_manager.py
+++ b/python/sglang/srt/managers/template_manager.py
@@ -53,7 +53,7 @@ class TemplateManager:
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
- self._jinja_template_content_format: Optional[str] = None
+ self._jinja_template_content_format: Optional[str] = "openai"
@property
def chat_template_name(self) -> Optional[str]:
@@ -71,31 +71,60 @@ class TemplateManager:
return self._jinja_template_content_format
def load_chat_template(
- self, tokenizer_manager, chat_template_arg: str, model_path: str
+ self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
- chat_template_arg: Template name or file path
+ chat_template_arg: Template name, file path, or None to auto-detect
model_path: Path to the model
"""
- logger.info(f"Loading chat template: {chat_template_arg}")
+ if chat_template_arg:
+ self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
+ else:
+ # Try HuggingFace template first
+ hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+ if hf_template:
+ self._jinja_template_content_format = (
+ detect_jinja_template_content_format(hf_template)
+ )
+ logger.info(
+ f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+ )
+ return
- if not chat_template_exists(chat_template_arg):
- if not os.path.exists(chat_template_arg):
- raise RuntimeError(
- f"Chat template {chat_template_arg} is not a built-in template name "
- "or a valid chat template file path."
+ # Fallback to SGLang template guessing
+ self.guess_chat_template_from_model_path(model_path)
+
+ # Set default format if no template was found
+ if self._chat_template_name is None:
+ self._jinja_template_content_format = "string"
+ logger.info(
+ "No chat template found, defaulting to 'string' content format"
)
- if chat_template_arg.endswith(".jinja"):
- self._load_jinja_template(tokenizer_manager, chat_template_arg)
- else:
- self._load_json_chat_template(chat_template_arg)
- else:
+ def _load_explicit_chat_template(
+ self, tokenizer_manager, chat_template_arg: str
+ ) -> None:
+ """Load explicitly specified chat template."""
+ logger.info(f"Loading chat template from argument: {chat_template_arg}")
+
+ if chat_template_exists(chat_template_arg):
self._chat_template_name = chat_template_arg
+ return
+
+ if not os.path.exists(chat_template_arg):
+ raise RuntimeError(
+ f"Chat template {chat_template_arg} is not a built-in template name "
+ "or a valid chat template file path."
+ )
+
+ if chat_template_arg.endswith(".jinja"):
+ self._load_jinja_template(tokenizer_manager, chat_template_arg)
+ else:
+ self._load_json_chat_template(chat_template_arg)
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
@@ -146,10 +175,7 @@ class TemplateManager:
completion_template: Optional completion template name/path
"""
# Load chat template
- if chat_template:
- self.load_chat_template(tokenizer_manager, chat_template, model_path)
- else:
- self.guess_chat_template_from_model_path(model_path)
+ self.load_chat_template(tokenizer_manager, chat_template, model_path)
# Load completion template
if completion_template:
@@ -166,7 +192,7 @@ class TemplateManager:
chat_template
)
logger.info(
- f"Detected chat template content format: {self._jinja_template_content_format}"
+ f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
@@ -224,3 +250,20 @@ class TemplateManager:
override=True,
)
self._completion_template_name = template["name"]
+
+ def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]:
+ """
+ Resolve HuggingFace chat template.
+
+ Returns the chat template string if found, None otherwise.
+ """
+ tokenizer = tokenizer_manager.tokenizer
+
+ # Try to get AutoTokenizer chat template
+ try:
+ return tokenizer.get_chat_template()
+ except Exception as e:
+ logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+
+ logger.debug("No HuggingFace chat template found")
+ return None
diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py
new file mode 100644
index 000000000..3ed0a153f
--- /dev/null
+++ b/python/sglang/srt/models/step3_vl.py
@@ -0,0 +1,994 @@
+import logging
+import math
+from collections.abc import Iterable
+from math import sqrt
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+from torch.nn import functional as F
+from transformers import PretrainedConfig
+from transformers.activations import ACT2FN
+
+from sglang.srt.configs.step3_vl import (
+ Step3TextConfig,
+ Step3VisionEncoderConfig,
+ Step3VLConfig,
+)
+from sglang.srt.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce,
+)
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from sglang.srt.managers.mm_utils import (
+ MultiModalityDataPaddingPatternMultimodalTokens,
+ general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+ Modality,
+ MultimodalDataItem,
+ MultimodalInputs,
+ global_server_args_dict,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+"""
+Text Model
+"""
+
+
+class Step3TextMLP(nn.Module):
+    # Dense feed-forward block: fused gate/up projection -> SiLU-and-mul
+    # activation -> down projection. Only "silu" is accepted as hidden_act.
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        # gate_proj and up_proj are merged into one column-parallel matmul;
+        # SiluAndMul later splits the doubled output back into the two halves.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        # x: (..., hidden_size) -> (..., hidden_size)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Step3TextMoEMLP(nn.Module):
+    # Native
+    # Sparse MoE feed-forward: a replicated router (gate) produces logits,
+    # TopK selects experts per token, and the fused experts kernel computes
+    # the weighted expert outputs.
+    def __init__(
+        self,
+        layer_id: int,
+        config: Step3TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
+        # Experts are partitioned across TP ranks, so TP size must not
+        # exceed the expert count.
+        if self.tp_size > config.moe_num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.moe_num_experts}."
+            )
+
+        self.topk = TopK(
+            top_k=config.moe_top_k,
+            renormalize=config.norm_expert_weight,
+            use_grouped_topk=False,
+        )
+
+        # Implementation class depends on runtime config (e.g. EP vs Triton).
+        self.experts = get_moe_impl_class()(
+            num_experts=config.moe_num_experts,
+            top_k=config.moe_top_k,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        # Router runs unquantized and replicated on every rank.
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            output_size=config.moe_num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+        )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # hidden_states: (num_tokens, hidden_size)
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+
+        # Partial expert outputs from each TP rank are summed here.
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Step3TextAttention(nn.Module):
+    # Attention with a shared low-rank query path: qkv_proj produces a
+    # shared query representation of size share_q_dim plus K/V, the query is
+    # RMS-normalized (inter_norm) and then expanded to all heads by wq.
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        share_q_dim: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        rms_norm_eps=None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        self.all_tp_rank = get_tensor_model_parallel_rank()
+        self.total_num_heads = num_heads
+        self.attn_tp_rank = attn_tp_rank
+        self.layer_id = layer_id
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+        self.head_dim = head_dim
+        # q_size is the shared (pre-expansion) query width, not heads*head_dim.
+        self.q_size = share_q_dim if share_q_dim else head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        # Replicated on every rank (tp_size=1): the shared Q and the single
+        # KV head are small enough to keep un-sharded.
+        self.qkv_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [self.q_size, self.kv_size, self.kv_size],
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=0,  # In fact, we need a MergedReplicatedLinear
+            tp_size=1,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        # RMSNorm applied to the shared query before head expansion.
+        self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps)
+
+        # Expands the shared query to per-head queries; sharded across TP.
+        self.wq = ColumnParallelLinear(
+            self.q_size,
+            self.head_dim * self.total_num_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("wq", prefix),
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Normalize the shared query, then expand it to all heads.
+        q = self.inter_norm(q.contiguous())
+        q, _ = self.wq(q)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Step3TextDecoderLayer(nn.Module):
+    # One transformer decoder layer: attention + either a dense MLP or a
+    # MoE MLP (with a parallel shared expert), wired through LayerCommunicator
+    # for DP/TP scatter-gather handling.
+    def __init__(
+        self,
+        config: Step3TextConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        # TODO: support shared experts fusion
+        # self.n_shared_experts = 1
+        # self.num_fused_shared_experts = (
+        #     0
+        #     if global_server_args_dict["disable_shared_experts_fusion"]
+        #     else self.n_shared_experts
+        # )
+        self.num_fused_shared_experts = 0
+        rms_norm_eps = config.rms_norm_eps
+        self.self_attn = Step3TextAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=1,
+            head_dim=head_dim,
+            share_q_dim=config.share_q_dim,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            rms_norm_eps=rms_norm_eps,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        # moe_layers_enum is a comma-separated list of sparse layer indices;
+        # when absent, every layer except layer 0 is sparse.
+        moe_layers_enum = getattr(config, "moe_layers_enum", None)
+        if moe_layers_enum is not None:
+            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
+        else:
+            # Default to 1dense.
+            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+
+        self.use_moe = False
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.layer_id = layer_id
+        self.is_layer_sparse = True if layer_id in moe_layers_idx else False
+        self.is_previous_layer_sparse = (
+            True if layer_id - 1 in moe_layers_idx else False
+        )
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=self.is_previous_layer_sparse,
+        )
+
+        if not self.is_layer_sparse:
+            self.mlp = Step3TextMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act="silu",
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            self.use_moe = True
+            # When shared-expert fusion is off, the shared expert is a
+            # separate dense MLP run alongside the routed experts.
+            if self.num_fused_shared_experts == 0:
+                self.moe = Step3TextMoEMLP(
+                    layer_id=layer_id,
+                    config=config,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp", prefix),
+                )
+                self.share_expert = Step3TextMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.share_expert_dim,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=add_prefix("share_expert", prefix),
+                )
+            else:
+                self.moe = Step3TextMoEMLP(
+                    layer_id=layer_id,
+                    config=config,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp", prefix),
+                )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def moe_mlp_forward(self, hidden_states):
+        # Routed experts plus (unfused) shared expert on the same input.
+        if not self.num_fused_shared_experts:
+            h = hidden_states.clone()
+            hidden_states = self.moe(hidden_states)
+            hidden_states += self.share_expert(h)
+        else:
+            hidden_states = self.moe(hidden_states)
+        return hidden_states
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        # Skip attention on ranks with an empty local batch (DP attention).
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        if self.use_moe:
+            hidden_states = self.moe_mlp_forward(hidden_states)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+class Step3TextModel(nn.Module):
+    # Text backbone: token embedding -> stacked decoder layers -> final norm.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        # With DP attention the embedding is replicated rather than
+        # vocab-sharded across TP ranks.
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Step3TextDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # input_embeds (e.g. with multimodal embeddings spliced in) takes
+        # precedence over input_ids when provided.
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, forward_batch, residual
+            )
+
+        # Empty local batches (DP attention) skip the final norm.
+        if hidden_states.shape[0] != 0:
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+"""
+Vision Model
+"""
+
+
+def get_abs_pos(abs_pos, tgt_size):
+    # Resize a learned absolute position embedding (CLS token + square patch
+    # grid) to a target grid size via bicubic interpolation.
+    # abs_pos: (1, P + 1, dim) — row 0 is the CLS position embedding.
+    # tgt_size: target patch count; assumed to be a perfect square — TODO confirm.
+    dim = abs_pos.size(-1)
+    abs_pos_new = abs_pos.squeeze(0)
+    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
+
+    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
+    tgt_size = int(math.sqrt(tgt_size))
+    dtype = abs_pos.dtype
+
+    if src_size != tgt_size:
+        # (P, dim) -> (1, dim, src, src) for F.interpolate.
+        old_pos_embed = (
+            old_pos_embed.view(1, src_size, src_size, dim)
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
+        # Interpolate in float32, then cast back to the original dtype.
+        old_pos_embed = old_pos_embed.to(torch.float32)
+        new_pos_embed = F.interpolate(
+            old_pos_embed,
+            size=(tgt_size, tgt_size),
+            mode="bicubic",
+            antialias=True,
+            align_corners=False,
+        ).to(dtype)
+        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
+        new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
+        # CLS embedding is kept as-is and re-attached at position 0.
+        vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
+        vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
+        return vision_pos_embed
+    else:
+        return abs_pos
+
+
+class Step3VisionMLP(nn.Module):
+    # Vision-tower FFN: fc1 -> activation (quick_gelu by default) -> fc2.
+    def __init__(
+        self,
+        dim: int,
+        intermediate_size: int,
+        bias: bool = True,
+        hidden_act="quick_gelu",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        # NOTE(review): prefixes say "gate_proj"/"down_proj" while the module
+        # attributes are fc1/fc2 — presumably matching checkpoint names; verify
+        # against the published Step3 weights.
+        self.fc1 = ColumnParallelLinear(
+            dim,
+            intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_proj", prefix),
+        )
+        self.act = ACT2FN[hidden_act]  # quick_gelu
+        self.fc2 = RowParallelLinear(
+            intermediate_size,
+            dim,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+        return hidden_states
+
+
+class Step3VisionAttention(nn.Module):
+    # Thin wrapper delegating attention (QKV, softmax, output projection)
+    # to the shared VisionAttention implementation.
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 16,
+        qkv_backend="fa3",
+        quant_config=None,
+        prefix: str = "",
+    ) -> None:
+
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        # NOTE(review): out_proj and scale below are never referenced in
+        # forward() — VisionAttention owns its own projection. They look like
+        # dead state kept for weight-name compatibility; confirm before removal.
+        self.out_proj = RowParallelLinear(
+            dim,
+            dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=add_prefix("out_proj", prefix),
+        )
+        self.scale = self.head_dim**-0.5
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend=qkv_backend,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        attn_output = self.attn(hidden_states)
+        return attn_output
+
+
+class Step3VisionEmbeddings(nn.Module):
+    # Patchify + CLS token + (interpolated) absolute position embeddings,
+    # then prepend extra CLS copies so the prefix length is pad_tp_size.
+
+    def __init__(self, config: Step3VisionEncoderConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
+
+        # Non-overlapping conv patchifier: kernel == stride == patch_size.
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=True,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.pad_tp_size = 4  # hard code for padding
+        # To load the pretrained weights, we still use P+1 as the seqlen
+        self.position_embedding = torch.nn.Embedding(
+            self.num_patches + 1, self.embed_dim
+        )
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_patches + 1).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        patch_embeds = self.patch_embedding(
+            pixel_values
+        )  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        # pad
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        # Position embeddings are interpolated when the runtime patch grid
+        # differs from the trained grid (see get_abs_pos).
+        embeddings = embeddings + get_abs_pos(
+            self.position_embedding(self.position_ids), patch_embeds.size(1)
+        )
+        # Prepend pad_tp_size - 1 copies of the (position-embedded) CLS token;
+        # the consumer strips the first pad_tp_size tokens (see
+        # _get_vision_model_output's [:, 4:] slice).
+        embeddings = torch.cat(
+            [
+                embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
+                embeddings,
+            ],
+            dim=1,
+        )
+        return embeddings
+
+
+class Step3VisionEncoderLayer(nn.Module):
+    # Single vision encoder block. NOTE: unlike pre-norm ViT blocks, the
+    # LayerNorm here is applied to the sublayer OUTPUT before the residual
+    # add (x + LN(attn(x))), as written below.
+    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6)
+        self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6)
+
+        self.self_attn = Step3VisionAttention(
+            self.embed_dim, num_heads=config.num_attention_heads
+        )
+        self.mlp = Step3VisionMLP(
+            dim=self.embed_dim,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
+        hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
+        return hidden_states
+
+
+class Step3VisionTransformer(nn.Module):
+    # Full vision tower: patch/CLS embeddings followed by the encoder stack.
+    def __init__(self, config: Step3VisionEncoderConfig):
+        super().__init__()
+        self.config = config
+        self.image_size = config.image_size
+        self.embeddings = Step3VisionEmbeddings(config)
+        self.transformer = Step3VisionEncoder(config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        # Parameter dtype of the patchifier; used by callers to cast inputs.
+        return self.embeddings.patch_embedding.weight.dtype
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ):
+        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.transformer(inputs_embeds=hidden_states)
+        return hidden_states
+
+
+class Step3VisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Step3VisionEncoderLayer`].
+
+    Args:
+        config: StepVisionEncoderConfig
+    """
+
+    def __init__(self, config: Step3VisionEncoderConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList(
+            [Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+        )
+
+    def forward(
+        self,
+        inputs_embeds,
+    ) -> torch.Tensor:
+
+        # Plain sequential pass; no attention mask or hidden-state collection.
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+            )
+
+        return hidden_states
+
+
+class Step3VLForConditionalGeneration(nn.Module):
+    # Step3 vision-language model: vision tower -> conv downsamplers ->
+    # linear projector -> text backbone -> LM head.
+
+    def __init__(
+        self,
+        config: Step3VLConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Step3TextModel(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+        self.vision_model = Step3VisionTransformer(config.vision_config)
+
+        # Two-stage spatial downsampling of vision features before projection.
+        self.vit_downsampler = nn.Conv2d(
+            config.vision_config.hidden_size,
+            config.vision_config.output_hidden_size,
+            kernel_size=2,
+            stride=config.understand_projector_stride,
+        )
+        self.vit_downsampler2 = nn.Conv2d(
+            config.vision_config.output_hidden_size,
+            config.vision_config.output_hidden_size * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        )
+        self.vit_large_projector = nn.Linear(
+            config.vision_config.output_hidden_size * 2,
+            config.hidden_size,
+            bias=config.projector_bias,
+        )
+
+        # TODO: support shared experts fusion
+        # self.n_shared_experts = 1
+        # self.num_fused_shared_experts = (
+        #     0
+        #     if global_server_args_dict["disable_shared_experts_fusion"]
+        #     else self.n_shared_experts
+        # )
+        self.num_fused_shared_experts = 0
+        # NOTE(review): tie_word_embeddings is forced to False here, so the
+        # branch below always builds a separate ParallelLMHead; the tied
+        # branch is currently dead code. Confirm whether this override is
+        # intentional for released checkpoints.
+        self.config.tie_word_embeddings = False
+        if getattr(self.config, "tie_word_embeddings", False):
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.text_config.vocab_size,
+                config.text_config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config.text_config)
+
+    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        # Drop the first 4 tokens: the padded CLS prefix added by
+        # Step3VisionEmbeddings (pad_tp_size == 4).
+        return self.vision_model(input_tensor)[:, 4:]
+
+    def _flatten_embeddings(self, embeddings) -> torch.Tensor:
+        # Recursively flatten a tensor or nested sequence of tensors into a
+        # single (N, dim) tensor.
+        if isinstance(embeddings, torch.Tensor):
+            # Flatten all but the last dimension.
+            return embeddings.flatten(0, -2)
+
+        return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings))
+
+    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
+        # (B, P, C) -> (B, C, H, W) with H == W == sqrt(P), downsample twice,
+        # then project to the text hidden size: returns (B, P', hidden).
+        B, P = image_features.shape[:2]
+        HW = int(sqrt(P))
+        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
+        image_features = self.vit_downsampler(image_features)
+        image_features = self.vit_downsampler2(image_features)
+        n_dim = image_features.size(1)
+        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
+        image_features = self.vit_large_projector(image_features)
+        return image_features
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Encode the main image(s) and optional crop patches, then interleave
+        # each image's patch features ahead of its global features.
+        assert len(items) == 1  # We only have images.
+
+        item = items[0]
+        pixel_values = item.feature.type(self.vision_model.dtype)
+        num_patches = item.model_specific_data.get("num_patches")
+        patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None)
+        if patch_pixel_values is not None:
+            patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype)
+
+        if patch_pixel_values is not None:
+            patch_pixel_values = patch_pixel_values.to("cuda")
+
+        image_features = self._get_vision_model_output(pixel_values)
+        patch_image_features = (
+            self._get_vision_model_output(patch_pixel_values)
+            if patch_pixel_values is not None
+            else None
+        )
+
+        image_features = self._process_image_features(image_features)
+        patch_image_features = (
+            self._process_image_features(patch_image_features)
+            if patch_image_features is not None
+            else None
+        )
+
+        merged_image_features = []
+        cur_patch_idx = 0
+        for i, num_patch in enumerate(num_patches):
+            cur_feature = []
+            if num_patch > 0:
+                # This image's crop-patch features come first...
+                patch_slice = patch_image_features[
+                    cur_patch_idx : cur_patch_idx + num_patch
+                ]
+                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
+            # ...followed by its global image features.
+            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
+            cur_patch_idx += num_patch
+            merged_image_features.append(
+                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
+            )
+        return self._flatten_embeddings(merged_image_features)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Replace multimodal placeholder tokens with per-item pad runs.
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # general_mm_embed_routine embeds text tokens and splices in image
+        # embeddings (via get_image_feature) before running the text model.
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
+            positions=positions,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # TODO:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", 0),
+            (".qkv_proj", ".k_proj", 1),
+            (".qkv_proj", ".v_proj", 2),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.text_config.moe_num_experts
+            + self.num_fused_shared_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params = set()
+
+        def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool:
+            # Compare the projection-name components of a parameter path and
+            # an expert-mapping weight path (fixed dotted positions 4 and 2).
+            name_parts = name_path.split(".")
+            weight_parts = weight_path.split(".")
+            shard_id_matches = name_parts[4] == weight_parts[2]
+            return shard_id_matches
+
+        for name, loaded_weight in weights:
+            if "vision_model" in name:
+                # Checkpoint names differ from our module tree: the vision
+                # attention lives under "self_attn.attn" (VisionAttention) and
+                # presumably names its output projection "proj" — hence the
+                # two renames below.
+                name = name.replace("self_attn", "self_attn.attn")
+                name = name.replace("out_proj", "proj")
+
+            # TODO: support vision model
+            if self.num_fused_shared_experts > 0 and "share" in name:
+                # Fused shared expert: route the shared-expert checkpoint
+                # weight into the extra (last) expert slot of the MoE.
+                name = name.replace("share_expert", "moe")
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if (
+                        expert_id != self.config.text_config.moe_num_experts
+                        or not match_expert_and_shard_ids(name, weight_name)
+                    ):
+                        continue
+
+                    part_name = weight_name.split(".")[-2]
+                    fake_weight_name = name.replace(part_name, weight_name[:-1])
+                    actual_param_name = name.replace(part_name + ".", param_name)
+                    param = params_dict[actual_param_name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "gate." not in name and "moe" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
+            else:
+                # for-else: no stacked mapping matched this weight.
+                if "moe" not in name:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+                else:
+                    if "gate." in name:
+                        # NOTE(review): param_name/weight_name here are the
+                        # leftover values from the exhausted for-loop above
+                        # (the last stacked_params_mapping entry) — this
+                        # replace only works by accident of iteration order;
+                        # verify the intended mapping for MoE router weights.
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(param, loaded_weight)
+                        loaded_params.add(name)
+                        continue
+
+                    # Routed expert weights: one checkpoint tensor holds all
+                    # experts for the layer; slice out each expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if expert_id == self.config.text_config.moe_num_experts:
+                            continue
+                        if not match_expert_and_shard_ids(name, weight_name):
+                            continue
+                        part_name = weight_name.split(".")[-2]
+                        fake_weight_name = name.replace(part_name, weight_name[:-1])
+                        actual_param_name = name.replace(part_name + ".", param_name)
+                        param = params_dict[actual_param_name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight[expert_id],
+                            name,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                        loaded_params.add(actual_param_name)
+                        # Don't break here, because this 'loaded_weight' includes all the weights for this layer
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config: Step3VLConfig):
+        # Expert-location planning metadata for EPLB.
+        return ModelConfigForExpertLocation(
+            num_layers=config.text_config.num_hidden_layers,
+            num_logical_experts=config.text_config.moe_num_experts,
+            num_groups=None,
+        )
diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index c98720652..06e5c0da0 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -176,6 +176,8 @@ class BaseMultimodalProcessor(ABC):
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
+ "num_patches": Modality.IMAGE,
+ "patch_pixel_values": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
diff --git a/python/sglang/srt/multimodal/processors/step3_vl.py b/python/sglang/srt/multimodal/processors/step3_vl.py
new file mode 100644
index 000000000..4ed09635b
--- /dev/null
+++ b/python/sglang/srt/multimodal/processors/step3_vl.py
@@ -0,0 +1,515 @@
+import math
+import re
+from itertools import product
+from typing import List, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from transformers import BatchFeature, TensorType
+
+from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
+from sglang.srt.multimodal.processors.base_processor import (
+ BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
+
+ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
+
+
+class GPUToTensor(torch.nn.Module):
+    """Convert a PIL image or HxWxC numpy array to a CxHxW float tensor.
+
+    PIL inputs go through torchvision's ToTensor (CPU). Numpy inputs are
+    moved to CUDA when available, permuted to channel-first, and uint8
+    data is rescaled to [0, 1].
+    """
+
+    def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
+        if isinstance(raw_image, Image.Image):
+            # PIL path: ToTensor already yields CxHxW float32 in [0, 1].
+            return transforms.ToTensor()(raw_image)
+        if raw_image.ndim == 2:
+            # Grayscale HxW -> HxWx3 by repeating the channel axis.
+            raw_image = raw_image[:, :, None].repeat(3, -1)
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+        image_tensor = torch.from_numpy(raw_image).to(device)
+        # HxWxC -> CxHxW for downstream torchvision transforms.
+        image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
+        if image_tensor.dtype == torch.uint8:
+            image_tensor = image_tensor.to(torch.float32).div(255)
+        return image_tensor
+
+
+class Step3VisionProcessor:
+    """Normalize + resize pipeline for full images and for crop patches.
+
+    Builds two torchvision pipelines with CLIP-style mean/std: one that
+    resizes to (size, size) and one that resizes to (patch_size, patch_size).
+    """
+
+    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
+        # CLIP normalization constants.
+        mean = [0.48145466, 0.4578275, 0.40821073]
+        std = [0.26862954, 0.26130258, 0.27577711]
+        patch_size = patch_size if patch_size is not None else size
+
+        self.transform = transforms.Compose(
+            [
+                GPUToTensor(),
+                transforms.Normalize(mean, std),
+                transforms.Resize(
+                    (size, size),
+                    interpolation=(
+                        InterpolationMode.BICUBIC
+                        if interpolation_mode == "bicubic"
+                        else InterpolationMode.BILINEAR
+                    ),
+                    antialias=True,
+                ),
+            ]
+        )
+
+        self.patch_transform = (
+            transforms.Compose(
+                [
+                    GPUToTensor(),
+                    transforms.Normalize(mean, std),
+                    transforms.Resize(
+                        (patch_size, patch_size),
+                        interpolation=(
+                            InterpolationMode.BICUBIC
+                            if interpolation_mode == "bicubic"
+                            else InterpolationMode.BILINEAR
+                        ),
+                        antialias=True,
+                    ),
+                ]
+            )
+            # NOTE(review): patch_size was defaulted to `size` above, so this
+            # condition is always true and the `else None` branch is dead.
+            if patch_size is not None
+            else None
+        )
+
+    def __call__(self, image, is_patch=False):
+        # Returns a dict with a 1xCxHxW tensor under "pixel_values".
+        if is_patch:
+            return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
+        else:
+            return {"pixel_values": self.transform(image).unsqueeze(0)}
+
+
+class ImagePatcher:
+    """Split large images into sliding-window patches for the vision encoder.
+
+    Decides a window size from the image aspect/size, optionally pads and
+    resizes, then tiles the image with non-overlapping windows. Also reports
+    per-row "newline" positions used to insert row-break tokens.
+    """
+
+    def determine_window_size(self, long: int, short: int) -> int:
+        """Pick the crop window size; 0 means "no patching needed"."""
+        if long <= 728:
+            # Small image: only patch when clearly non-square (ratio > 1.5).
+            return short if long / short > 1.5 else 0
+        return min(short, 504) if long / short > 4 else 504
+
+    def slide_window(
+        self,
+        width: int,
+        height: int,
+        sizes: list[tuple[int, int]],
+        steps: list[tuple[int, int]],
+        img_rate_thr: float = 0.6,
+    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
+        """Enumerate (x, y, w, h) windows covering the image.
+
+        Returns the window boxes plus the (x_num, y_num) grid shape.
+        NOTE(review): `img_rate_thr` is validated but otherwise unused, and
+        the assert message names "in_rate_thr" instead of "img_rate_thr".
+        """
+        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
+        windows = []
+        # Sliding windows.
+        for size, step in zip(sizes, steps):
+            size_w, size_h = size
+            step_w, step_h = step
+
+            x_num = 1 if width <= size_w else math.ceil((width - size_w) / step_w + 1)
+            x_start = [step_w * i for i in range(x_num)]
+            if len(x_start) > 1 and x_start[-1] + size_w > width:
+                # Clamp the last column so the final window stays in-bounds.
+                x_start[-1] = width - size_w
+
+            y_num = 1 if height <= size_h else math.ceil((height - size_h) / step_h + 1)
+            y_start = [step_h * i for i in range(y_num)]
+            if len(y_start) > 1 and y_start[-1] + size_h > height:
+                y_start[-1] = height - size_h
+
+            start = np.array(list(product(y_start, x_start)), dtype=int)
+            # product() yields (y, x); swap columns to (x, y) order.
+            start[:, [0, 1]] = start[:, [1, 0]]
+            windows.append(np.concatenate([start, start + size], axis=1))
+        windows = np.concatenate(windows, axis=0)
+
+        # NOTE(review): x_num/y_num leak from the last loop iteration — this
+        # assumes callers pass exactly one (size, step) pair, as they do here;
+        # with an empty `sizes` this would raise NameError.
+        return [
+            (int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1]))
+            for box in windows
+        ], (x_num, y_num)
+
+    def square_pad(self, img: Image.Image) -> Image.Image:
+        """Pad the image (bottom/right, zero fill) to a square."""
+        w, h = img.size
+        if w == h:
+            return img
+        size = max(w, h)
+        padded = Image.new(img.mode, (size, size), 0)
+        padded.paste(img, (0, 0))
+        return padded
+
+    def get_image_size_for_padding(
+        self, img_width: int, img_height: int
+    ) -> tuple[int, int]:
+        """Return a square target size for tiny, extremely elongated images."""
+        ratio = img_width / img_height
+        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
+            new_size = max(img_height, img_width)
+            return new_size, new_size
+        return img_width, img_height
+
+    def get_image_size_for_preprocess(
+        self, img_width: int, img_height: int
+    ) -> tuple[int, int]:
+        """Downscale so the long side is at most 3024, preserving aspect."""
+
+        if max(img_height, img_width) > 3024:
+            scale_factor = 3024 / max(img_height, img_width)
+            img_width = int(img_width * scale_factor)
+            img_height = int(img_height * scale_factor)
+            return img_width, img_height
+        else:
+            return img_width, img_height
+
+    def get_image_size_for_crop(
+        self, img_width: int, img_height: int, window_size: int
+    ):
+        """Snap each dimension to a multiple of window_size.
+
+        Rounds up only when the fractional overhang exceeds 0.2 windows;
+        dimensions smaller than one window are left unchanged.
+        """
+        w_ratio = img_width / window_size
+        h_ratio = img_height / window_size
+
+        if w_ratio < 1:
+            width_new = img_width
+        else:
+            decimal_w = w_ratio - img_width // window_size
+            w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
+            width_new = window_size * w_ratio
+        if h_ratio < 1:
+            height_new = img_height
+        else:
+            decimal_h = h_ratio - img_height // window_size
+            h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
+            height_new = window_size * h_ratio
+        return int(width_new), int(height_new)
+
+    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
+        """Crop a th x tw region whose top-left corner is (row i, col j)."""
+        target = img.crop((j, i, j + tw, i + th))
+        return target
+
+    def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]:
+        """Predict (num_patches, num_row_breaks) without materializing crops.
+
+        Mirrors the geometry of __call__ so token counts can be computed
+        before preprocessing.
+        """
+        img_width, img_height = self.get_image_size_for_padding(img_width, img_height)
+        img_width, img_height = self.get_image_size_for_preprocess(
+            img_width, img_height
+        )
+        window_size = self.determine_window_size(
+            max(img_height, img_width), min(img_height, img_width)
+        )
+        if window_size == 0:
+            return 0, 0
+        else:
+            img_width, img_height = self.get_image_size_for_crop(
+                img_width, img_height, window_size
+            )
+            center_list, (x_num, y_num) = self.slide_window(
+                img_width,
+                img_height,
+                [(window_size, window_size)],
+                [(window_size, window_size)],
+            )
+            # Row breaks = complete rows minus the trailing one (matches the
+            # newline pop in __call__ below).
+            full_rows = (len(center_list) - 1) // x_num + 1
+            if len(center_list) > 0 and len(center_list) % x_num == 0:
+                full_rows -= 1
+            return len(center_list), full_rows
+
+    def __call__(
+        self, img: Image.Image
+    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
+        """Return (resized full image, crop patches, per-patch newline mask).
+
+        The mask marks patches after which a row-break token is emitted;
+        it is None when no patching is performed.
+        """
+        img_width, img_height = img.size
+        new_img_width, new_img_height = self.get_image_size_for_padding(
+            img_width, img_height
+        )
+        if new_img_width != img_width or new_img_height != img_height:
+            img = self.square_pad(img)
+            img_width, img_height = img.size
+
+        new_img_width, new_img_height = self.get_image_size_for_preprocess(
+            img_width, img_height
+        )
+        img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR)
+        window_size = self.determine_window_size(
+            max(new_img_height, new_img_width), min(new_img_height, new_img_width)
+        )
+        if window_size == 0:
+            # Small / near-square image: no patches, mask is None.
+            return img, [], None
+        else:
+            new_img_width, new_img_height = self.get_image_size_for_crop(
+                new_img_width, new_img_height, window_size
+            )
+            if (new_img_width, new_img_height) != (img_width, img_height):
+                img_for_crop = img.resize(
+                    (new_img_width, new_img_height), Image.Resampling.BILINEAR
+                )
+            else:
+                img_for_crop = img
+
+            patches = []
+            newlines = []
+            center_list, (x_num, y_num) = self.slide_window(
+                new_img_width,
+                new_img_height,
+                [(window_size, window_size)],
+                [(window_size, window_size)],
+            )
+            for patch_id, center_lf_point in enumerate(center_list):
+                x, y, patch_w, patch_h = center_lf_point
+                big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
+                patches.append(big_patch)
+                if (patch_id + 1) % x_num == 0:
+                    # End of a grid row.
+                    newlines.append(patch_id)
+
+            # Drop the newline after the very last patch — no break needed
+            # at the end of the sequence.
+            if newlines and newlines[-1] == len(patches) - 1:
+                newlines.pop()
+
+            return (
+                img,
+                patches,
+                (
+                    [i in newlines for i in range(len(patches))]
+                    if len(patches) > 0
+                    else None
+                ),
+            )
+
+
+class Step3VLProcessor:
+    """HF-style processor: expands image placeholders and builds pixel tensors.
+
+    FIXME(review): several token literals below are empty strings ("") — they
+    look like angle-bracket special tokens (e.g. an <im_patch>-style token)
+    lost during extraction. Confirm the intended literals before relying on
+    this code.
+    """
+
+    def __init__(
+        self,
+        config,
+        tokenizer,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.tokenizer = tokenizer
+
+        # Full-image and patch resolutions fed to Step3VisionProcessor.
+        self.image_size = 728
+        self.patch_size = 504
+        self.image_preprocessor = Step3VisionProcessor(
+            self.image_size, "bilinear", self.patch_size
+        )
+
+        # Number of feature tokens emitted per full image / per patch.
+        self.num_image_feature_size = 169
+        self.num_patch_feature_size = 81
+        # FIXME(review): empty string — image token literal lost (see class note).
+        self.image_token = ""
+        self.image_feature_placeholder = self.image_token * self.num_image_feature_size
+        self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
+
+        self.patcher = ImagePatcher()
+
+    @property
+    def image_token_id(self) -> int:
+        # Looked up from the tokenizer vocab each access.
+        return self.tokenizer.get_vocab()[self.image_token]
+
+    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
+        """Total placeholder tokens for one image of the given size.
+
+        Each patch contributes its features plus 2 delimiters; the full image
+        adds its features plus 2 delimiters; row breaks add one token each.
+        """
+        num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
+
+        return (
+            num_patches * (self.num_patch_feature_size + 2)
+            + self.num_image_feature_size
+            + 2
+            + num_newlines
+        )
+
+    def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]:
+        """Run the patcher over each image."""
+        result = []
+        for img in images:
+            result.append(self.patcher(img))
+        return result
+
+    def _convert_images_to_pixel_values(
+        self,
+        images: list[Image.Image],
+        is_patch: bool = False,
+    ) -> list[torch.Tensor]:
+        """Preprocess each image to a 1xCxHxW tensor."""
+        return [
+            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
+            for img in images
+        ]
+
+    def _get_patch_repl(
+        self,
+        num_patches: int,
+        patch_newline_mask: list[bool] | None,
+    ) -> tuple[str, list[int]]:
+        """Build the text + token-id replacement for `num_patches` patches.
+
+        FIXME(review): the convert_tokens_to_ids("") calls and the newline
+        text append ("") are stripped token literals — see class note.
+        """
+        text = ""
+        token_ids = []
+        for i in range(num_patches):
+            assert len(patch_newline_mask) == num_patches
+            text += f"{self.patch_feature_placeholder}"
+            token_ids.extend(
+                [self.tokenizer.convert_tokens_to_ids("")]
+                + [self.image_token_id] * self.num_patch_feature_size
+                + [self.tokenizer.convert_tokens_to_ids("")]
+            )
+            if patch_newline_mask and patch_newline_mask[i]:
+                text += ""
+                token_ids.append(
+                    self.tokenizer.convert_tokens_to_ids("")
+                )
+        return text, token_ids
+
+    def _get_image_repl(
+        self,
+        num_images: int,
+    ) -> tuple[str, list[int]]:
+        """Replacement text + ids for `num_images` full-image feature spans."""
+        text = f"{self.image_feature_placeholder}"
+        token_ids = (
+            [self.tokenizer.convert_tokens_to_ids("")]
+            + [self.image_token_id] * self.num_image_feature_size
+            + [self.tokenizer.convert_tokens_to_ids("")]
+        )
+        return text * num_images, token_ids * num_images
+
+    def _get_image_repl_features(
+        self,
+        num_images: int,
+        num_patches: int,
+        patch_new_line_idx: Optional[list[bool]],
+    ) -> tuple[str, list[int]]:
+        """Patch replacement (if any) followed by the full-image replacement."""
+        if num_patches > 0:
+            patch_repl, patch_repl_ids = self._get_patch_repl(
+                num_patches, patch_new_line_idx
+            )
+        else:
+            patch_repl = ""
+            patch_repl_ids = []
+        image_repl, image_repl_ids = self._get_image_repl(num_images)
+        return patch_repl + image_repl, patch_repl_ids + image_repl_ids
+
+    def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
+        """Replace the i-th occurrence of `placeholder` with repls[i].
+
+        Raises ValueError when occurrence and replacement counts differ.
+        """
+        parts = text.split(placeholder)
+
+        if len(parts) - 1 != len(repls):
+            raise ValueError(
+                "The number of placeholders does not match the number of replacements."  # noqa: E501
+            )
+
+        result = [parts[0]]
+        for i, repl in enumerate(repls):
+            result.append(repl)
+            result.append(parts[i + 1])
+
+        return "".join(result)
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        *args,
+        **kwargs,
+    ) -> BatchFeature:
+        """Tokenize text and, when images are given, attach pixel tensors and
+        expand each image placeholder into its full feature-token span."""
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if len(images) == 0:
+            image_inputs = {}
+            text_inputs = self.tokenizer(text)
+        else:
+            splitted_images_data = self._split_images(images)
+            pixel_values_lst = []
+            patch_pixel_values_lst = []
+            patch_newline_mask_lst = []
+            image_repl_str_lst = []
+            image_repl_ids_lst = []
+            num_patches = []
+            for (
+                raw_img,
+                img_patches,
+                patch_newline_mask,
+            ) in splitted_images_data:  # noqa: E501
+                pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
+
+                if len(img_patches) > 0:
+                    patch_pixel_values_lst.extend(
+                        self._convert_images_to_pixel_values(img_patches, is_patch=True)
+                    )
+                # One entry per image, even when it produced zero patches.
+                num_patches.append(len(img_patches))
+
+                image_repl_str, image_repl_ids = self._get_image_repl_features(
+                    1, len(img_patches), patch_newline_mask
+                )
+                image_repl_str_lst.append(image_repl_str)
+                image_repl_ids_lst.extend(image_repl_ids)
+
+                if patch_newline_mask is not None:
+                    patch_newline_mask_lst.extend(patch_newline_mask)
+
+            image_inputs = {
+                "pixel_values": torch.cat(pixel_values_lst),
+                "num_patches": num_patches,
+            }
+            if patch_pixel_values_lst:
+                image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
+            if patch_newline_mask_lst:
+                image_inputs["patch_newline_mask"] = torch.tensor(
+                    patch_newline_mask_lst, dtype=torch.bool
+                )
+
+            # Swap each placeholder occurrence for that image's expansion.
+            text = [
+                self.replace_placeholder(t, self.image_token, image_repl_str_lst)
+                for t in text
+            ]
+            text_inputs = self.tokenizer(text)
+
+        return BatchFeature(
+            {
+                **text_inputs,
+                **image_inputs,
+            },
+            tensor_type=return_tensors,
+        )
+
+
+################################################
+
+
+class Step3VLImageProcessor(SGLangBaseProcessor):
+    """sglang multimodal processor adapter for Step3-VL.
+
+    FIXME(review): the empty image_token literal and the "(?:)" regex look
+    like angle-bracket token text lost during extraction — confirm intended
+    literals against the model's tokenizer config.
+    """
+    # Model classes this processor serves.
+    models = [Step3VLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        # TODO, check _processor is tokenizer or processor.
+        processor = Step3VLProcessor(hf_config, _processor)
+        super().__init__(hf_config, server_args, processor, *args, **kwargs)
+        # Hard-coded image token id; duplicated in mm_tokens below.
+        self.IM_TOKEN_ID = 128001
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="",
+            image_token_id=128001,
+            image_token_regex=re.compile(r"(?:)"),
+        ).build(_processor)
+
+    # NOTE(review): class-level CLIP constants placed after __init__; they are
+    # not referenced anywhere in this class as written.
+    mean = [0.48145466, 0.4578275, 0.40821073]
+    std = [0.26862954, 0.26130258, 0.27577711]
+
+    def preprocess(self, image):
+        # FIXME(review): self.transform is never defined on this class — this
+        # method would raise AttributeError if called; likely dead code.
+        return {"pixel_values": self.transform(image).unsqueeze(0)}
+
+    def __call__(self, image):
+        return self.preprocess(image)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text: str | List[int],
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        """Load image/video data, expand placeholders, and return the dict
+        the sglang runtime expects (input_ids, mm_items, im_token_id)."""
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        # NOTE(review): `ret` is unused.
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.mm_tokens.image_token_id,
+        }
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index b5b737856..a2561a18d 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -105,7 +105,7 @@ class BaseReasoningFormatDetector:
# If we're not in a reasoning block return as normal text
if not self._in_reasoning:
self._buffer = ""
- return StreamingParseResult(normal_text=new_text)
+ return StreamingParseResult(normal_text=current_text)
return StreamingParseResult()
@@ -233,6 +233,7 @@ class ReasoningParser:
"qwen3-thinking": Qwen3ThinkingDetector,
"glm45": Qwen3Detector,
"kimi": KimiDetector,
+ "step3": DeepSeekR1Detector,
}
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 992905437..9e673a9f4 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1117,9 +1117,10 @@ class ServerArgs:
"kimi_k2",
"qwen3_coder",
"glm45",
+ "step3",
],
default=ServerArgs.tool_call_parser,
- help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
+ help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
)
# Data parallelism
diff --git a/test/srt/test_reasoning_parser.py b/test/srt/test_reasoning_parser.py
index 7f3359144..97eea82b4 100644
--- a/test/srt/test_reasoning_parser.py
+++ b/test/srt/test_reasoning_parser.py
@@ -493,5 +493,117 @@ class TestIntegrationScenarios(CustomTestCase):
self.assertIn("final answer", all_normal)
+class TestBufferLossBugFix(CustomTestCase):
+    """Test cases for the buffer loss bug fix in parse_streaming_increment.
+
+    FIXME(review): the empty-string literals throughout this class appear to
+    be angle-bracket tag literals (e.g. "<think>" / "</think>") lost during
+    extraction; restore them before trusting these assertions.
+    """
+
+    def test_partial_end_tag_buffer_loss_bug(self):
+        """
+        Test the bug where partial end tag fragments are lost when followed by normal text.
+
+        Bug scenario:
+        1. _in_reasoning is False
+        2. new_text is "" (part of closing thinking tag)
+        3. Fragment is stored in buffer and empty string is returned
+        4. Next step: new_text is "answer", _in_reasoning still False
+        5. Buffer is cleared and "answer" is returned directly
+        6. The "" from previous step is lost
+
+        This test verifies the fix where line 108 was changed from:
+            return StreamingParseResult(normal_text=new_text)
+        to:
+            return StreamingParseResult(normal_text=current_text)
+        """
+        detector = BaseReasoningFormatDetector("", "")
+
+        # Step 1: Send partial end tag when not in reasoning mode
+        # This should be buffered since it could be start of ""
+        result1 = detector.parse_streaming_increment("")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        # Step 2: Send normal text that doesn't complete the end tag
+        # Before fix: would return only "answer", losing the ""
+        # After fix: should return the complete buffered content ""
+
+        # Send partial start tag
+        # FIXME(review): the original argument literal (a partial tag) was
+        # destroyed in extraction, leaving an unterminated string that made
+        # this file a syntax error; restored to a syntactically valid call.
+        result1 = detector.parse_streaming_increment("")
+
+        # Enter reasoning mode
+        detector.parse_streaming_increment("")
+        detector.parse_streaming_increment("some reasoning")
+
+        # Send partial end tag
+        result1 = detector.parse_streaming_increment("")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        # Complete the end tag with normal text
+        result2 = detector.parse_streaming_increment("think>normal text")
+        self.assertEqual(result2.normal_text, "normal text")
+        # The reasoning text should be empty since buffer was cleared when end tag was processed
+        self.assertEqual(result2.reasoning_text, "")
+
+    def test_multiple_partial_fragments(self):
+        """
+        Test handling of multiple partial fragments that don't match any tokens.
+        """
+        detector = BaseReasoningFormatDetector("", "")
+
+        # Send multiple partial fragments
+        result1 = detector.parse_streaming_increment("<")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        result2 = detector.parse_streaming_increment("/")
+        self.assertEqual(result2.normal_text, "")
+        self.assertEqual(result2.reasoning_text, "")
+
+        result3 = detector.parse_streaming_increment("random>")
+        self.assertEqual(result3.normal_text, "")
+        self.assertEqual(result3.reasoning_text, "")
+
+    def test_edge_case_exact_token_match(self):
+        """
+        Test edge case where buffer content exactly matches a token.
+        """
+        detector = BaseReasoningFormatDetector("", "")
+
+        # Build up the exact start token character by character
+        detector.parse_streaming_increment("<")
+        detector.parse_streaming_increment("t")
+        detector.parse_streaming_increment("h")
+        detector.parse_streaming_increment("i")
+        detector.parse_streaming_increment("n")
+        result = detector.parse_streaming_increment("k>")
+
+        # Should enter reasoning mode
+        self.assertEqual(result.normal_text, "")
+        self.assertEqual(result.reasoning_text, "")
+        self.assertTrue(detector._in_reasoning)
+        self.assertTrue(detector.stripped_think_start)
+
if __name__ == "__main__":
unittest.main()
|