From 51c38163c19cb64aee7727a60363d4f44108809b Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 31 Jul 2025 02:41:00 -0700 Subject: [PATCH] model: support Step3V (#8583) Signed-off-by: Xinyuan Tong Co-authored-by: nnnobody-code Co-authored-by: ispobock Co-authored-by: Qiaolin-Yu Co-authored-by: Qiaolin-Yu Co-authored-by: Xinyuan Tong Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> --- docs/backend/server_arguments.md | 2 +- python/sglang/srt/configs/__init__.py | 8 + python/sglang/srt/configs/model_config.py | 3 + python/sglang/srt/configs/step3_vl.py | 172 +++ python/sglang/srt/conversation.py | 23 + .../srt/function_call/function_call_parser.py | 2 + .../srt/function_call/step3_detector.py | 436 ++++++++ python/sglang/srt/hf_transformers_utils.py | 2 + python/sglang/srt/jinja_template_utils.py | 5 +- .../sglang/srt/managers/template_manager.py | 81 +- python/sglang/srt/models/step3_vl.py | 994 ++++++++++++++++++ .../multimodal/processors/base_processor.py | 2 + .../srt/multimodal/processors/step3_vl.py | 515 +++++++++ python/sglang/srt/reasoning_parser.py | 3 +- python/sglang/srt/server_args.py | 3 +- test/srt/test_reasoning_parser.py | 112 ++ 16 files changed, 2340 insertions(+), 23 deletions(-) create mode 100644 python/sglang/srt/configs/step3_vl.py create mode 100644 python/sglang/srt/function_call/step3_detector.py create mode 100644 python/sglang/srt/models/step3_vl.py create mode 100644 python/sglang/srt/multimodal/processors/step3_vl.py diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 985596292..636bb4f1b 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -148,7 +148,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--file-storage-path` | The path of the file storage in backend. | sglang_storage | | `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False | | `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None | -| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None | +| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. 
| None | ## Data parallelism diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 49d59b6f7..9c3008572 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig +from sglang.srt.configs.step3_vl import ( + Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig, +) __all__ = [ "ExaoneConfig", @@ -14,4 +19,7 @@ __all__ = [ "MultiModalityConfig", "KimiVLConfig", "MoonViTConfig", + "Step3VLConfig", + "Step3TextConfig", + "Step3VisionEncoderConfig", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 37722c492..37fbf07c7 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -335,6 +335,8 @@ class ModelConfig: "num_key_value_heads", # For ChatGLM: "multi_query_group_num", + # For Step3 + "num_attention_groups", ] for attr in attributes: num_kv_heads = getattr(self.hf_text_config, attr, None) @@ -644,6 +646,7 @@ multimodal_model_archs = [ "InternS1ForConditionalGeneration", "Phi4MMForCausalLM", "VILAForConditionalGeneration", + "Step3VLForConditionalGeneration", ] diff --git a/python/sglang/srt/configs/step3_vl.py b/python/sglang/srt/configs/step3_vl.py new file mode 100644 index 000000000..5519605c6 --- /dev/null +++ b/python/sglang/srt/configs/step3_vl.py @@ -0,0 +1,172 @@ +from typing import Any, Optional, Union + +from transformers.configuration_utils import PretrainedConfig + + +class Step3VisionEncoderConfig(PretrainedConfig): + model_type = "step3_vision_encoder" + + def __init__( + self, + hidden_size=1792, + intermediate_size=3072, + output_hidden_size=4096, + num_hidden_layers=63, + num_attention_heads=16, + num_channels=3, + image_size=728, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.output_hidden_size = output_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + super().__init__(**kwargs) + + +class Step3TextConfig(PretrainedConfig): + model_type = "step3_text" + architectures = ["Step3TextForCausalLM"] + + def __init__( + self, + hidden_size: int = 7168, + intermediate_size: int = 18432, + num_attention_heads: int = 64, + num_attention_groups: int = 1, + num_hidden_layers: int = 61, + max_seq_len: int = 65536, + vocab_size: int = 128815, + rms_norm_eps: float = 1e-5, + moe_intermediate_size: int = 5120, + moe_num_experts: int = 48, + moe_top_k: int = 3, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embedding: int = 65536, + share_expert_dim: int = 5120, + share_q_dim: int = 2048, + head_dim: int = 256, + norm_expert_weight: bool = False, + moe_layers_enum: tuple[int] = ( + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 
55, + 56, + 57, + 58, + 59, + ), + **kwargs, + ) -> None: + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_attention_groups = num_attention_groups + self.num_hidden_layers = num_hidden_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.rms_norm_eps = rms_norm_eps + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.max_position_embedding = max_position_embedding + self.share_expert_dim = share_expert_dim + self.share_q_dim = share_q_dim + self.head_dim = head_dim + self.norm_expert_weight = norm_expert_weight + self.moe_layers_enum = moe_layers_enum + + super().__init__(**kwargs) + + +class Step3VLConfig(PretrainedConfig): + model_type = "step3_vl" + + def __init__( + self, + vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None, + text_config: Optional[Union[dict, Step3TextConfig]] = None, + understand_projector_stride: int = 1, + projector_bias: bool = True, + image_token_id: int = 128001, + **kwargs, + ) -> None: + if vision_config is None: + vision_config = Step3VisionEncoderConfig() + elif isinstance(vision_config, dict): + vision_config = Step3VisionEncoderConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = Step3TextConfig() + elif isinstance(text_config, dict): + text_config = Step3TextConfig(**text_config) + self.text_config = text_config + + self.understand_projector_stride = understand_projector_stride + self.projector_bias = projector_bias + self.hidden_size = text_config.hidden_size + self.image_token_id = image_token_id + + super().__init__(**kwargs) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 81e406eb7..c34527591 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -994,6 +994,23 @@ register_conv_template( ) ) +register_conv_template( + Conversation( + name="step3-vl", + system_message="<|begin▁of▁sentence|>You are a helpful assistant", + system_template="{system_message}\n", + roles=( + "<|BOT|>user\n", + "<|BOT|>assistant\n\n", + ), + sep="<|EOT|>", + sep_style=SeparatorStyle.NO_COLON_SINGLE, + stop_str="<|EOT|>", + image_token="", + # add_bos=True, + ) +) + @register_conv_template_matching_function def match_internvl(model_path: str): @@ -1103,3 +1120,9 @@ def match_vila(model_path: str): def match_mimo_vl(model_path: str): if re.search(r"mimo.*vl", model_path, re.IGNORECASE): return "mimo-vl" + + +# @register_conv_template_matching_function +# def match_step3(model_path: str): +# if re.search(r"step3", model_path, re.IGNORECASE): +# return "step3-vl" diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index bf6a3d959..6f6403de0 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector +from sglang.srt.function_call.step3_detector import Step3Detector logger = logging.getLogger(__name__) @@ -39,6 +40,7 @@ class 
FunctionCallParser: "kimi_k2": KimiK2Detector, "qwen3_coder": Qwen3CoderDetector, "glm45": Glm4MoeDetector, + "step3": Step3Detector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/step3_detector.py b/python/sglang/srt/function_call/step3_detector.py new file mode 100644 index 000000000..b46f4544f --- /dev/null +++ b/python/sglang/srt/function_call/step3_detector.py @@ -0,0 +1,436 @@ +import ast +import json +import logging +import re +from typing import Any, Dict, List + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer + +logger = logging.getLogger(__name__) + + +def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> str: + """Get the expected type for a function argument from tool schema.""" + name2tool = {tool.function.name: tool for tool in defined_tools} + if func_name not in name2tool: + return None + tool = name2tool[func_name] + parameters = tool.function.parameters or {} + properties = parameters.get("properties", {}) + if arg_key not in properties: + return None + return properties[arg_key].get("type", None) + + +def parse_arguments(value: str) -> tuple[Any, bool]: + """Parse a string value to appropriate type. Returns (parsed_value, success).""" + try: + try: + parsed_value = json.loads(value) + except: + parsed_value = ast.literal_eval(value) + return parsed_value, True + except: + return value, False + + +class Step3Detector(BaseFormatDetector): + """ + Detector for Step3 model function call format. + + The Step3 format uses special Unicode tokens to delimit function calls + with steptml XML format for invocations. 
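+
+    An illustrative sketch of a single invocation body (the steptml tag names
+    are assumed from the naming used by this detector; the function and
+    parameter names below are placeholders):
+
+        <steptml:invoke name="get_weather">
+        <steptml:parameter name="city">Beijing</steptml:parameter>
+        </steptml:invoke>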
+ + Format Structure: + ``` + <|tool_calls_begin|> + <|tool_call_begin|>function<|tool_sep|> + value1 + value2 + <|tool_call_end|> + <|tool_calls_end|> + ``` + """ + + def __init__(self): + super().__init__() + self.bot_token = "<|tool_calls_begin|>" + self.eot_token = "<|tool_calls_end|>" + self.tool_call_begin = "<|tool_call_begin|>" + self.tool_call_end = "<|tool_call_end|>" + self.tool_sep = "<|tool_sep|>" + + # Regex for parsing steptml invocations + self.invoke_regex = re.compile( + r'(.+?)', re.DOTALL + ) + self.param_regex = re.compile( + r'([^<]*)', re.DOTALL + ) + + # Streaming state variables + self._in_tool_block: bool = False + self._tool_block_finished: bool = False + self._current_function_name: str = "" + self._current_parameters: Dict[str, Any] = {} + self._in_tool_call: bool = False + self._function_name_sent: bool = False + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a Step3 format tool call.""" + return self.bot_token in text + + def _parse_steptml_invoke( + self, text: str, tools: List[Tool] = None + ) -> tuple[str, dict]: + """Parse steptml invoke format to extract function name and parameters.""" + invoke_match = self.invoke_regex.search(text) + if not invoke_match: + return None, {} + + func_name = invoke_match.group(1) + params_text = invoke_match.group(2) + + params = {} + for param_match in self.param_regex.finditer(params_text): + param_name = param_match.group(1) + param_value = param_match.group(2).strip() + + # If tools provided, use schema-aware parsing + if tools: + arg_type = get_argument_type(func_name, param_name, tools) + if arg_type and arg_type != "string": + parsed_value, _ = parse_arguments(param_value) + params[param_name] = parsed_value + else: + params[param_name] = param_value + else: + # Fallback to generic parsing if no tools provided + parsed_value, _ = parse_arguments(param_value) + params[param_name] = parsed_value + + return func_name, params + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. 
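+
+        Rough usage sketch (variable names are illustrative):
+
+            result = detector.detect_and_parse(model_output, tools)
+            # result.normal_text -> text outside the tool-call block
+            # result.calls       -> one ToolCallItem per parsed invocation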
+ """ + if self.bot_token not in text: + return StreamingParseResult(normal_text=text, calls=[]) + + try: + pre_text, rest = text.split(self.bot_token, 1) + + # If no end token, return everything as normal text + if self.eot_token not in rest: + return StreamingParseResult(normal_text=text, calls=[]) + + tool_section, post_text = rest.split(self.eot_token, 1) + + # Find all individual tool calls using regex + calls = [] + tool_call_pattern = ( + f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}" + ) + + for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL): + call_content = match.group(1) + + # Check if it's a function call + if self.tool_sep not in call_content: + continue + + type_part, invoke_part = call_content.split(self.tool_sep, 1) + if type_part.strip() != "function": + continue + + func_name, params = self._parse_steptml_invoke(invoke_part, tools) + if func_name: + # Use parse_base_json to create the ToolCallItem + action = {"name": func_name, "arguments": params} + calls.extend(self.parse_base_json(action, tools)) + + # Combine pre and post text + normal_text = pre_text + post_text + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}") + # Return the original text if parsing fails + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing for Step3 format. + """ + self._buffer += new_text + + # Build tool indices for validation + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + # If we've finished the tool block, everything is normal text + if self._tool_block_finished: + normal_text = self._buffer + self._buffer = "" + return StreamingParseResult(normal_text=normal_text) + + # Check if tool block hasn't started yet + if not self._in_tool_block: + if self.bot_token in self._buffer: + idx = self._buffer.find(self.bot_token) + normal_text = self._buffer[:idx] + self._buffer = self._buffer[idx + len(self.bot_token) :] + self._in_tool_block = True + return StreamingParseResult(normal_text=normal_text) + else: + # Check if we might have a partial bot_token + partial_len = self._ends_with_partial_token( + self._buffer, self.bot_token + ) + if partial_len: + return StreamingParseResult() # Wait for more text + else: + normal_text = self._buffer + self._buffer = "" + return StreamingParseResult(normal_text=normal_text) + + # We're inside the tool block + calls: List[ToolCallItem] = [] + + # Check if tool block is ending + if self.eot_token in self._buffer: + idx = self._buffer.find(self.eot_token) + + # If we're in the middle of a tool call, we need to handle it + if self._in_tool_call: + # The buffer before eot_token might contain the end of the current tool call + before_eot = self._buffer[:idx] + if self.tool_call_end in before_eot: + # Parse this final tool call + result = self._parse_partial_tool_call(tools) + calls.extend(result.calls) + else: + # Incomplete tool call - log warning + logger.warning("Tool block ended with incomplete tool call") + + remaining = self._buffer[idx + len(self.eot_token) :] + self._buffer = "" + self._tool_block_finished = True + + # Reset any partial tool call state + self._reset_streaming_state() + + return StreamingParseResult(normal_text=remaining, calls=calls) + + # Check if we're in a tool call or need to start one + if not self._in_tool_call: + 
if self.tool_call_begin in self._buffer: + idx = self._buffer.find(self.tool_call_begin) + # Remove any content before tool call begin (shouldn't happen but be safe) + self._buffer = self._buffer[idx + len(self.tool_call_begin) :] + self._in_tool_call = True + self._function_name_sent = False + self._current_function_name = "" + self._current_parameters = {} + # Fall through to parse the partial tool call + else: + # Wait for tool call to begin + return StreamingParseResult() + + # Parse partial tool call + if self._in_tool_call: + return self._parse_partial_tool_call(tools) + + return StreamingParseResult() + + def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult: + """Parse partial tool call for streaming scenarios.""" + calls = [] + + # Check if we have tool_sep (means we're past the type declaration) + if self.tool_sep not in self._buffer: + return StreamingParseResult(calls=calls) # Wait for more text + + type_part, invoke_part = self._buffer.split(self.tool_sep, 1) + if type_part.strip() != "function": + # Invalid tool type, skip this tool call + self._reset_streaming_state() + return StreamingParseResult(calls=calls) + + # Try to extract function name if not sent yet + if not self._function_name_sent: + name_match = re.search(r'', invoke_part) + if name_match: + func_name = name_match.group(1) + + # Validate function name + if func_name in self._tool_indices: + self._current_function_name = func_name + self._function_name_sent = True + + # Initialize tool tracking + if self.current_tool_id == -1: + self.current_tool_id = 0 + + # Ensure tracking arrays are large enough + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Store tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + # Send tool name with empty parameters + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + else: + # Invalid function name + logger.warning(f"Invalid function name: {func_name}") + self._reset_streaming_state() + return StreamingParseResult(calls=calls) + else: + # Function name not complete yet + return StreamingParseResult(calls=calls) + + # Parse parameters incrementally + if self._function_name_sent: + # Extract all complete parameters + new_params = {} + for param_match in self.param_regex.finditer(invoke_part): + param_name = param_match.group(1) + param_value = param_match.group(2).strip() + + # Use schema-aware parsing + arg_type = get_argument_type( + self._current_function_name, param_name, tools + ) + if arg_type and arg_type != "string": + parsed_value, _ = parse_arguments(param_value) + new_params[param_name] = parsed_value + else: + new_params[param_name] = param_value + + # Check if we have new parameters to stream + if new_params != self._current_parameters: + # Build the JSON content without the closing brace for streaming + if not self._current_parameters: + # First parameters - send opening brace and content + params_content = json.dumps(new_params, ensure_ascii=False) + if len(params_content) > 2: # More than just "{}" + # Send everything except the closing brace + diff = params_content[:-1] + else: + diff = "{" + else: + # Subsequent parameters - calculate the incremental diff + old_json = json.dumps(self._current_parameters, ensure_ascii=False) + new_json = json.dumps(new_params, 
ensure_ascii=False) + + # Remove closing braces for comparison + old_without_brace = old_json[:-1] + new_without_brace = new_json[:-1] + + # The new content should extend the old content + if new_without_brace.startswith(old_without_brace): + diff = new_without_brace[len(old_without_brace) :] + else: + # Parameters changed in unexpected way - shouldn't happen in normal streaming + diff = "" + + if diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + parameters=diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + + # Update current state + self._current_parameters = new_params + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params + + # Check if tool call is complete + if self.tool_call_end in self._buffer: + # Send closing brace if we've sent any parameters + if self.streamed_args_for_tool[self.current_tool_id]: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + parameters="}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += "}" + + # Find the end position + end_idx = self._buffer.find(self.tool_call_end) + # Remove the processed tool call from buffer + self._buffer = self._buffer[end_idx + len(self.tool_call_end) :] + + # Reset state for next tool call + self._reset_streaming_state() + self.current_tool_id += 1 + + return StreamingParseResult(calls=calls) + + def _reset_streaming_state(self): + """Reset streaming state for the next tool call""" + self._in_tool_call = False + self._function_name_sent = False + self._current_function_name = "" + self._current_parameters = {} + + def supports_structural_tag(self) -> bool: + """Return True if this detector supports structural tag format.""" + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError() + + def build_ebnf(self, tools: List[Tool]) -> str: + """ + Build EBNF grammar for Step3 tool call format. 
+ """ + # Custom call rule for steptml format + call_rule_fmt = ( + '"function" "<|tool_sep|>" "" ' + '{arguments_rule} ""' + ) + + # Custom key-value rule for steptml parameters + key_value_rule_fmt = ( + '"" {valrule} ""' + ) + + return EBNFComposer.build_ebnf( + tools, + sequence_start_token=self.bot_token, + sequence_end_token=self.eot_token, + individual_call_start_token=self.tool_call_begin, + individual_call_end_token=self.tool_call_end, + tool_call_separator="", + function_format="xml", + call_rule_fmt=call_rule_fmt, + key_value_rule_fmt=key_value_rule_fmt, + key_value_separator="", + ) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 7c056acdd..bf16addc5 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -41,6 +41,7 @@ from sglang.srt.configs import ( ExaoneConfig, KimiVLConfig, MultiModalityConfig, + Step3VLConfig, ) from sglang.srt.configs.internvl import InternVLChatConfig from sglang.srt.connector import create_remote_connector @@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { MultiModalityConfig.model_type: MultiModalityConfig, KimiVLConfig.model_type: KimiVLConfig, InternVLChatConfig.model_type: InternVLChatConfig, + Step3VLConfig.model_type: Step3VLConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/python/sglang/srt/jinja_template_utils.py b/python/sglang/srt/jinja_template_utils.py index 9a944c994..ac55699dc 100644 --- a/python/sglang/srt/jinja_template_utils.py +++ b/python/sglang/srt/jinja_template_utils.py @@ -165,7 +165,7 @@ def process_content_for_template_format( new_msg["content"] = processed_content_parts return new_msg - else: # content_format == "string" + elif content_format == "string": # String format: flatten to text only (for templates like DeepSeek) text_parts = [] for chunk in msg_dict["content"]: @@ -179,3 +179,6 @@ def process_content_for_template_format( new_msg["content"] = " ".join(text_parts) if text_parts else "" new_msg = {k: v for k, v in new_msg.items() if v is not None} return new_msg + + else: + raise ValueError(f"Invalid content format: {content_format}") diff --git a/python/sglang/srt/managers/template_manager.py b/python/sglang/srt/managers/template_manager.py index 4684bf1a0..e340f65f0 100644 --- a/python/sglang/srt/managers/template_manager.py +++ b/python/sglang/srt/managers/template_manager.py @@ -53,7 +53,7 @@ class TemplateManager: def __init__(self): self._chat_template_name: Optional[str] = None self._completion_template_name: Optional[str] = None - self._jinja_template_content_format: Optional[str] = None + self._jinja_template_content_format: Optional[str] = "openai" @property def chat_template_name(self) -> Optional[str]: @@ -71,31 +71,60 @@ class TemplateManager: return self._jinja_template_content_format def load_chat_template( - self, tokenizer_manager, chat_template_arg: str, model_path: str + self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str ) -> None: """ Load a chat template from various sources. 
Args: tokenizer_manager: The tokenizer manager instance - chat_template_arg: Template name or file path + chat_template_arg: Template name, file path, or None to auto-detect model_path: Path to the model """ - logger.info(f"Loading chat template: {chat_template_arg}") + if chat_template_arg: + self._load_explicit_chat_template(tokenizer_manager, chat_template_arg) + else: + # Try HuggingFace template first + hf_template = self._resolve_hf_chat_template(tokenizer_manager) + if hf_template: + self._jinja_template_content_format = ( + detect_jinja_template_content_format(hf_template) + ) + logger.info( + f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}" + ) + return - if not chat_template_exists(chat_template_arg): - if not os.path.exists(chat_template_arg): - raise RuntimeError( - f"Chat template {chat_template_arg} is not a built-in template name " - "or a valid chat template file path." + # Fallback to SGLang template guessing + self.guess_chat_template_from_model_path(model_path) + + # Set default format if no template was found + if self._chat_template_name is None: + self._jinja_template_content_format = "string" + logger.info( + "No chat template found, defaulting to 'string' content format" ) - if chat_template_arg.endswith(".jinja"): - self._load_jinja_template(tokenizer_manager, chat_template_arg) - else: - self._load_json_chat_template(chat_template_arg) - else: + def _load_explicit_chat_template( + self, tokenizer_manager, chat_template_arg: str + ) -> None: + """Load explicitly specified chat template.""" + logger.info(f"Loading chat template from argument: {chat_template_arg}") + + if chat_template_exists(chat_template_arg): self._chat_template_name = chat_template_arg + return + + if not os.path.exists(chat_template_arg): + raise RuntimeError( + f"Chat template {chat_template_arg} is not a built-in template name " + "or a valid chat template file path." + ) + + if chat_template_arg.endswith(".jinja"): + self._load_jinja_template(tokenizer_manager, chat_template_arg) + else: + self._load_json_chat_template(chat_template_arg) def guess_chat_template_from_model_path(self, model_path: str) -> None: """ @@ -146,10 +175,7 @@ class TemplateManager: completion_template: Optional completion template name/path """ # Load chat template - if chat_template: - self.load_chat_template(tokenizer_manager, chat_template, model_path) - else: - self.guess_chat_template_from_model_path(model_path) + self.load_chat_template(tokenizer_manager, chat_template, model_path) # Load completion template if completion_template: @@ -166,7 +192,7 @@ class TemplateManager: chat_template ) logger.info( - f"Detected chat template content format: {self._jinja_template_content_format}" + f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}" ) def _load_json_chat_template(self, template_path: str) -> None: @@ -224,3 +250,20 @@ class TemplateManager: override=True, ) self._completion_template_name = template["name"] + + def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]: + """ + Resolve HuggingFace chat template. + + Returns the chat template string if found, None otherwise. 
+ """ + tokenizer = tokenizer_manager.tokenizer + + # Try to get AutoTokenizer chat template + try: + return tokenizer.get_chat_template() + except Exception as e: + logger.debug(f"Error getting chat template via get_chat_template(): {e}") + + logger.debug("No HuggingFace chat template found") + return None diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py new file mode 100644 index 000000000..3ed0a153f --- /dev/null +++ b/python/sglang/srt/models/step3_vl.py @@ -0,0 +1,994 @@ +import logging +import math +from collections.abc import Iterable +from math import sqrt +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from torch.nn import functional as F +from transformers import PretrainedConfig +from transformers.activations import ACT2FN + +from sglang.srt.configs.step3_vl import ( + Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig, +) +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, + global_server_args_dict, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers + +logger = logging.getLogger(__name__) + + +""" +Text Model +""" + + +class Step3TextMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Step3TextMoEMLP(nn.Module): + # Native + def __init__( + self, + layer_id: int, + config: Step3TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.layer_id = layer_id + if self.tp_size > config.moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.moe_num_experts}." + ) + + self.topk = TopK( + top_k=config.moe_top_k, + renormalize=config.norm_expert_weight, + use_grouped_topk=False, + ) + + self.experts = get_moe_impl_class()( + num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("experts", prefix), + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + output_size=config.moe_num_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), + ) + + if global_server_args_dict["enable_deepep_moe"]: + raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits, _ = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, topk_output=topk_output + ) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Step3TextAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + share_q_dim: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + rms_norm_eps=None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + + self.all_tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + self.attn_tp_rank = attn_tp_rank + self.layer_id = layer_id + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
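+        # (Step3 passes num_kv_heads=1, so each attention TP rank keeps a full
+        # replica of the single KV head.)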
+ assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) + self.head_dim = head_dim + self.q_size = share_q_dim if share_q_dim else head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = MergedColumnParallelLinear( + hidden_size, + [self.q_size, self.kv_size, self.kv_size], + bias=False, + quant_config=quant_config, + tp_rank=0, # In fact, we need a MergedReplicatedLinear + tp_size=1, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + reduce_results=False, + prefix=add_prefix("o_proj", prefix), + ) + + self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps) + + self.wq = ColumnParallelLinear( + self.q_size, + self.head_dim * self.total_num_heads, + bias=False, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=add_prefix("wq", prefix), + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.inter_norm(q.contiguous()) + q, _ = self.wq(q) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Step3TextDecoderLayer(nn.Module): + def __init__( + self, + config: Step3TextConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + # TODO: support shared experts fusion + # self.n_shared_experts = 1 + # self.num_fused_shared_experts = ( + # 0 + # if global_server_args_dict["disable_shared_experts_fusion"] + # else self.n_shared_experts + # ) + self.num_fused_shared_experts = 0 + rms_norm_eps = config.rms_norm_eps + self.self_attn = Step3TextAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=1, + head_dim=head_dim, + share_q_dim=config.share_q_dim, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=rms_norm_eps, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] + else: + # Default to 1dense. 
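+            # i.e. layer 0 stays dense and every remaining layer is MoE.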
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + + self.use_moe = False + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_id = layer_id + self.is_layer_sparse = True if layer_id in moe_layers_idx else False + self.is_previous_layer_sparse = ( + True if layer_id - 1 in moe_layers_idx else False + ) + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=self.is_previous_layer_sparse, + ) + + if not self.is_layer_sparse: + self.mlp = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + else: + self.use_moe = True + if self.num_fused_shared_experts == 0: + self.moe = Step3TextMoEMLP( + layer_id=layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.share_expert = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + quant_config=quant_config, + prefix=add_prefix("share_expert", prefix), + ) + else: + self.moe = Step3TextMoEMLP( + layer_id=layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + ) + + def moe_mlp_forward(self, hidden_states): + if not self.num_fused_shared_experts: + h = hidden_states.clone() + hidden_states = self.moe(hidden_states) + hidden_states += self.share_expert(h) + else: + hidden_states = self.moe(hidden_states) + return hidden_states + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + if self.use_moe: + hidden_states = self.moe_mlp_forward(hidden_states) + else: + hidden_states = self.mlp(hidden_states) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +class Step3TextModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], + prefix=add_prefix("embed_tokens", prefix), + ) + + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Step3TextDecoderLayer( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("layers", prefix), + ) + self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + + def get_input_embeddings(self): + return self.embed_tokens + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +""" +Vision Model +""" + + +def get_abs_pos(abs_pos, tgt_size): + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) + return vision_pos_embed + else: + return abs_pos + + +class Step3VisionMLP(nn.Module): + def __init__( + self, + dim: int, + intermediate_size: int, + bias: bool = True, + hidden_act="quick_gelu", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.fc1 = ColumnParallelLinear( + dim, + intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_proj", prefix), + ) + self.act = ACT2FN[hidden_act] # quick_gelu + self.fc2 = RowParallelLinear( + intermediate_size, + dim, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Step3VisionAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 16, + qkv_backend="fa3", + quant_config=None, + prefix: str = "", + ) -> None: + + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.out_proj = RowParallelLinear( + dim, + dim, + bias=True, + quant_config=quant_config, + prefix=add_prefix("out_proj", prefix), + ) + self.scale = self.head_dim**-0.5 + + self.attn = VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + use_qkv_parallel=True, + rotary_embed="normal", + proj_bias=True, + qkv_backend=qkv_backend, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + attn_output = self.attn(hidden_states) + return attn_output + + +class Step3VisionEmbeddings(nn.Module): + + def __init__(self, config: Step3VisionEncoderConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + 
self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.pad_tp_size = 4 # hard code for padding + # To load the pretrained weights, we still use P+1 as the seqlen + self.position_embedding = torch.nn.Embedding( + self.num_patches + 1, self.embed_dim + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_patches + 1).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + # pad + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + get_abs_pos( + self.position_embedding(self.position_ids), patch_embeds.size(1) + ) + embeddings = torch.cat( + [ + embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1), + embeddings, + ], + dim=1, + ) + return embeddings + + +class Step3VisionEncoderLayer(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6) + self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6) + + self.self_attn = Step3VisionAttention( + self.embed_dim, num_heads=config.num_attention_heads + ) + self.mlp = Step3VisionMLP( + dim=self.embed_dim, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states)) + hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states)) + return hidden_states + + +class Step3VisionTransformer(nn.Module): + def __init__(self, config: Step3VisionEncoderConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + self.embeddings = Step3VisionEmbeddings(config) + self.transformer = Step3VisionEncoder(config) + + @property + def dtype(self) -> torch.dtype: + return self.embeddings.patch_embedding.weight.dtype + + def forward( + self, + pixel_values: torch.Tensor, + ): + hidden_states = self.embeddings(pixel_values) + hidden_states = self.transformer(inputs_embeds=hidden_states) + return hidden_states + + +class Step3VisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Step3VisionEncoderLayer`]. 
+ + Args: + config: StepVisionEncoderConfig + """ + + def __init__(self, config: Step3VisionEncoderConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def forward( + self, + inputs_embeds, + ) -> torch.Tensor: + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + ) + + return hidden_states + + +class Step3VLForConditionalGeneration(nn.Module): + + def __init__( + self, + config: Step3VLConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Step3TextModel( + config.text_config, quant_config, prefix=add_prefix("model", prefix) + ) + + self.vision_model = Step3VisionTransformer(config.vision_config) + + self.vit_downsampler = nn.Conv2d( + config.vision_config.hidden_size, + config.vision_config.output_hidden_size, + kernel_size=2, + stride=config.understand_projector_stride, + ) + self.vit_downsampler2 = nn.Conv2d( + config.vision_config.output_hidden_size, + config.vision_config.output_hidden_size * 2, + kernel_size=3, + stride=2, + padding=1, + ) + self.vit_large_projector = nn.Linear( + config.vision_config.output_hidden_size * 2, + config.hidden_size, + bias=config.projector_bias, + ) + + # TODO: support shared experts fusion + # self.n_shared_experts = 1 + # self.num_fused_shared_experts = ( + # 0 + # if global_server_args_dict["disable_shared_experts_fusion"] + # else self.n_shared_experts + # ) + self.num_fused_shared_experts = 0 + self.config.tie_word_embeddings = False + if getattr(self.config, "tie_word_embeddings", False): + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config.text_config) + + def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor: + return self.vision_model(input_tensor)[:, 4:] + + def _flatten_embeddings(self, embeddings) -> torch.Tensor: + + if isinstance(embeddings, torch.Tensor): + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) + + return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings)) + + def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor: + B, P = image_features.shape[:2] + HW = int(sqrt(P)) + image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW) + image_features = self.vit_downsampler(image_features) + image_features = self.vit_downsampler2(image_features) + n_dim = image_features.size(1) + image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1) + image_features = self.vit_large_projector(image_features) + return image_features + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + assert len(items) == 1 # We only have images. 
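+
+        # A single MultimodalDataItem carries the resized full image plus any
+        # sliding-window crops ("patch_pixel_values"); both go through the
+        # vision tower and are merged per image below.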
+ + item = items[0] + pixel_values = item.feature.type(self.vision_model.dtype) + num_patches = item.model_specific_data.get("num_patches") + patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None) + if patch_pixel_values is not None: + patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype) + + if patch_pixel_values is not None: + patch_pixel_values = patch_pixel_values.to("cuda") + + image_features = self._get_vision_model_output(pixel_values) + patch_image_features = ( + self._get_vision_model_output(patch_pixel_values) + if patch_pixel_values is not None + else None + ) + + image_features = self._process_image_features(image_features) + patch_image_features = ( + self._process_image_features(patch_image_features) + if patch_image_features is not None + else None + ) + + merged_image_features = [] + cur_patch_idx = 0 + for i, num_patch in enumerate(num_patches): + cur_feature = [] + if num_patch > 0: + patch_slice = patch_image_features[ + cur_patch_idx : cur_patch_idx + num_patch + ] + cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1])) + cur_feature.append(image_features[i].view(-1, image_features.shape[-1])) + cur_patch_idx += num_patch + merged_image_features.append( + torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0] + ) + return self._flatten_embeddings(merged_image_features) + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + data_embedding_funcs={ + Modality.IMAGE: self.get_image_feature, + }, + positions=positions, + ) + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # TODO: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", 0), + (".qkv_proj", ".k_proj", 1), + (".qkv_proj", ".v_proj", 2), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.text_config.moe_num_experts + + self.num_fused_shared_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params = set() + + def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool: + name_parts = name_path.split(".") + weight_parts = weight_path.split(".") + shard_id_matches = name_parts[4] == weight_parts[2] + return shard_id_matches + + for name, loaded_weight in weights: + if "vision_model" in name: + # 1.It’s not great, but let’s leave it like this for now + name = name.replace("self_attn", "self_attn.attn") + # 2. 
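+                #    Rename "out_proj" so it maps onto VisionAttention's
+                #    projection parameter ("proj").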
+ name = name.replace("out_proj", "proj") + + # TODO: support vision model + if self.num_fused_shared_experts > 0 and "share" in name: + # assert False + name = name.replace("share_expert", "moe") + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if ( + expert_id != self.config.text_config.moe_num_experts + or not match_expert_and_shard_ids(name, weight_name) + ): + continue + + part_name = weight_name.split(".")[-2] + fake_weight_name = name.replace(part_name, weight_name[:-1]) + actual_param_name = name.replace(part_name + ".", param_name) + param = params_dict[actual_param_name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "gate." not in name and "moe" in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if "moe" not in name: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + else: + if "gate." in name: + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if expert_id == self.config.text_config.moe_num_experts: + continue + if not match_expert_and_shard_ids(name, weight_name): + continue + part_name = weight_name.split(".")[-2] + fake_weight_name = name.replace(part_name, weight_name[:-1]) + actual_param_name = name.replace(part_name + ".", param_name) + param = params_dict[actual_param_name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight[expert_id], + name, + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(actual_param_name) + # Don't break here, because this 'loaded_weight' includes all the weights for this layer + + @classmethod + def get_model_config_for_expert_location(cls, config: Step3VLConfig): + return ModelConfigForExpertLocation( + num_layers=config.text_config.num_hidden_layers, + num_logical_experts=config.text_config.moe_num_experts, + num_groups=None, + ) + + +EntryClass = Step3VLForConditionalGeneration diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index c98720652..06e5c0da0 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -176,6 +176,8 @@ class BaseMultimodalProcessor(ABC): "image_grid_hws": Modality.IMAGE, "aspect_ratio_ids": Modality.IMAGE, "aspect_ratio_mask": Modality.IMAGE, + "num_patches": Modality.IMAGE, + "patch_pixel_values": Modality.IMAGE, # Audio-related attributes "audio_features": Modality.AUDIO, "audio_feature_lens": Modality.AUDIO, diff --git a/python/sglang/srt/multimodal/processors/step3_vl.py b/python/sglang/srt/multimodal/processors/step3_vl.py new file mode 100644 index 000000000..4ed09635b --- /dev/null +++ b/python/sglang/srt/multimodal/processors/step3_vl.py @@ -0,0 +1,515 @@ +import math +import re +from itertools import product +from 
typing import List, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +from PIL import Image +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import BatchFeature, TensorType + +from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor as SGLangBaseProcessor, +) +from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens + +ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None] + + +class GPUToTensor(torch.nn.Module): + + def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor: + if isinstance(raw_image, Image.Image): + return transforms.ToTensor()(raw_image) + if raw_image.ndim == 2: + raw_image = raw_image[:, :, None].repeat(3, -1) + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + image_tensor = torch.from_numpy(raw_image).to(device) + image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous() + if image_tensor.dtype == torch.uint8: + image_tensor = image_tensor.to(torch.float32).div(255) + return image_tensor + + +class Step3VisionProcessor: + def __init__(self, size, interpolation_mode="bicubic", patch_size=None): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + patch_size = patch_size if patch_size is not None else size + + self.transform = transforms.Compose( + [ + GPUToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (size, size), + interpolation=( + InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR + ), + antialias=True, + ), + ] + ) + + self.patch_transform = ( + transforms.Compose( + [ + GPUToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (patch_size, patch_size), + interpolation=( + InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR + ), + antialias=True, + ), + ] + ) + if patch_size is not None + else None + ) + + def __call__(self, image, is_patch=False): + if is_patch: + return {"pixel_values": self.patch_transform(image).unsqueeze(0)} + else: + return {"pixel_values": self.transform(image).unsqueeze(0)} + + +class ImagePatcher: + + def determine_window_size(self, long: int, short: int) -> int: + if long <= 728: + return short if long / short > 1.5 else 0 + return min(short, 504) if long / short > 4 else 504 + + def slide_window( + self, + width: int, + height: int, + sizes: list[tuple[int, int]], + steps: list[tuple[int, int]], + img_rate_thr: float = 0.6, + ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]: + assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1" + windows = [] + # Sliding windows. 
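+        # For each window size/step pair: compute how many windows fit along each
+        # axis, clamp the last start offset so the final window stays inside the
+        # image, then collect (x1, y1, x2, y2) boxes that are converted below to
+        # (x, y, w, h) tuples for the caller.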
+ for size, step in zip(sizes, steps): + size_w, size_h = size + step_w, step_h = step + + x_num = 1 if width <= size_w else math.ceil((width - size_w) / step_w + 1) + x_start = [step_w * i for i in range(x_num)] + if len(x_start) > 1 and x_start[-1] + size_w > width: + x_start[-1] = width - size_w + + y_num = 1 if height <= size_h else math.ceil((height - size_h) / step_h + 1) + y_start = [step_h * i for i in range(y_num)] + if len(y_start) > 1 and y_start[-1] + size_h > height: + y_start[-1] = height - size_h + + start = np.array(list(product(y_start, x_start)), dtype=int) + start[:, [0, 1]] = start[:, [1, 0]] + windows.append(np.concatenate([start, start + size], axis=1)) + windows = np.concatenate(windows, axis=0) + + return [ + (int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1])) + for box in windows + ], (x_num, y_num) + + def square_pad(self, img: Image.Image) -> Image.Image: + w, h = img.size + if w == h: + return img + size = max(w, h) + padded = Image.new(img.mode, (size, size), 0) + padded.paste(img, (0, 0)) + return padded + + def get_image_size_for_padding( + self, img_width: int, img_height: int + ) -> tuple[int, int]: + ratio = img_width / img_height + if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4): + new_size = max(img_height, img_width) + return new_size, new_size + return img_width, img_height + + def get_image_size_for_preprocess( + self, img_width: int, img_height: int + ) -> tuple[int, int]: + + if max(img_height, img_width) > 3024: + scale_factor = 3024 / max(img_height, img_width) + img_width = int(img_width * scale_factor) + img_height = int(img_height * scale_factor) + return img_width, img_height + else: + return img_width, img_height + + def get_image_size_for_crop( + self, img_width: int, img_height: int, window_size: int + ): + w_ratio = img_width / window_size + h_ratio = img_height / window_size + + if w_ratio < 1: + width_new = img_width + else: + decimal_w = w_ratio - img_width // window_size + w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio) + width_new = window_size * w_ratio + if h_ratio < 1: + height_new = img_height + else: + decimal_h = h_ratio - img_height // window_size + h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio) + height_new = window_size * h_ratio + return int(width_new), int(height_new) + + def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int): + target = img.crop((j, i, j + tw, i + th)) + return target + + def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]: + img_width, img_height = self.get_image_size_for_padding(img_width, img_height) + img_width, img_height = self.get_image_size_for_preprocess( + img_width, img_height + ) + window_size = self.determine_window_size( + max(img_height, img_width), min(img_height, img_width) + ) + if window_size == 0: + return 0, 0 + else: + img_width, img_height = self.get_image_size_for_crop( + img_width, img_height, window_size + ) + center_list, (x_num, y_num) = self.slide_window( + img_width, + img_height, + [(window_size, window_size)], + [(window_size, window_size)], + ) + full_rows = (len(center_list) - 1) // x_num + 1 + if len(center_list) > 0 and len(center_list) % x_num == 0: + full_rows -= 1 + return len(center_list), full_rows + + def __call__( + self, img: Image.Image + ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]: + img_width, img_height = img.size + new_img_width, new_img_height = self.get_image_size_for_padding( + img_width, img_height + ) + if new_img_width 
!= img_width or new_img_height != img_height:
+            img = self.square_pad(img)
+            img_width, img_height = img.size
+
+        new_img_width, new_img_height = self.get_image_size_for_preprocess(
+            img_width, img_height
+        )
+        img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR)
+        window_size = self.determine_window_size(
+            max(new_img_height, new_img_width), min(new_img_height, new_img_width)
+        )
+        if window_size == 0:
+            return img, [], None
+        else:
+            new_img_width, new_img_height = self.get_image_size_for_crop(
+                new_img_width, new_img_height, window_size
+            )
+            if (new_img_width, new_img_height) != (img_width, img_height):
+                img_for_crop = img.resize(
+                    (new_img_width, new_img_height), Image.Resampling.BILINEAR
+                )
+            else:
+                img_for_crop = img
+
+            patches = []
+            newlines = []
+            center_list, (x_num, y_num) = self.slide_window(
+                new_img_width,
+                new_img_height,
+                [(window_size, window_size)],
+                [(window_size, window_size)],
+            )
+            for patch_id, center_lf_point in enumerate(center_list):
+                x, y, patch_w, patch_h = center_lf_point
+                big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
+                patches.append(big_patch)
+                if (patch_id + 1) % x_num == 0:
+                    newlines.append(patch_id)
+
+            if newlines and newlines[-1] == len(patches) - 1:
+                newlines.pop()
+
+            return (
+                img,
+                patches,
+                (
+                    [i in newlines for i in range(len(patches))]
+                    if len(patches) > 0
+                    else None
+                ),
+            )
+
+
+class Step3VLProcessor:
+    def __init__(
+        self,
+        config,
+        tokenizer,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.tokenizer = tokenizer
+
+        self.image_size = 728
+        self.patch_size = 504
+        self.image_preprocessor = Step3VisionProcessor(
+            self.image_size, "bilinear", self.patch_size
+        )
+
+        self.num_image_feature_size = 169
+        self.num_patch_feature_size = 81
+        self.image_token = "<im_patch>"
+        self.image_feature_placeholder = self.image_token * self.num_image_feature_size
+        self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
+
+        self.patcher = ImagePatcher()
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.get_vocab()[self.image_token]
+
+    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
+        num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
+
+        return (
+            num_patches * (self.num_patch_feature_size + 2)
+            + self.num_image_feature_size
+            + 2
+            + num_newlines
+        )
+
+    def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]:
+        result = []
+        for img in images:
+            result.append(self.patcher(img))
+        return result
+
+    def _convert_images_to_pixel_values(
+        self,
+        images: list[Image.Image],
+        is_patch: bool = False,
+    ) -> list[torch.Tensor]:
+        return [
+            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
+            for img in images
+        ]
+
+    def _get_patch_repl(
+        self,
+        num_patches: int,
+        patch_newline_mask: list[bool] | None,
+    ) -> tuple[str, list[int]]:
+        text = ""
+        token_ids = []
+        for i in range(num_patches):
+            assert len(patch_newline_mask) == num_patches
+            text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
+            token_ids.extend(
+                [self.tokenizer.convert_tokens_to_ids("<patch_start>")]
+                + [self.image_token_id] * self.num_patch_feature_size
+                + [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
+            )
+            if patch_newline_mask and patch_newline_mask[i]:
+                text += "<patch_newline>"
+                token_ids.append(
+                    self.tokenizer.convert_tokens_to_ids("<patch_newline>")
+                )
+        return text, token_ids
+
+    def _get_image_repl(
+        self,
+        num_images: int,
+    ) -> tuple[str, list[int]]:
+        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
+        token_ids = (
+            [self.tokenizer.convert_tokens_to_ids("<im_start>")]
+            + [self.image_token_id] * self.num_image_feature_size
+            + [self.tokenizer.convert_tokens_to_ids("<im_end>")]
+        )
+        return text * num_images, token_ids * num_images
+
+    def _get_image_repl_features(
+        self,
+        num_images: int,
+        num_patches: int,
+        patch_new_line_idx: Optional[list[bool]],
+    ) -> tuple[str, list[int]]:
+        if num_patches > 0:
+            patch_repl, patch_repl_ids = self._get_patch_repl(
+                num_patches, patch_new_line_idx
+            )
+        else:
+            patch_repl = ""
+            patch_repl_ids = []
+        image_repl, image_repl_ids = self._get_image_repl(num_images)
+        return patch_repl + image_repl, patch_repl_ids + image_repl_ids
+
+    def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
+        parts = text.split(placeholder)
+
+        if len(parts) - 1 != len(repls):
+            raise ValueError(
+                "The number of placeholders does not match the number of replacements."  # noqa: E501
+            )
+
+        result = [parts[0]]
+        for i, repl in enumerate(repls):
+            result.append(repl)
+            result.append(parts[i + 1])
+
+        return "".join(result)
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        *args,
+        **kwargs,
+    ) -> BatchFeature:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if len(images) == 0:
+            image_inputs = {}
+            text_inputs = self.tokenizer(text)
+        else:
+            splitted_images_data = self._split_images(images)
+            pixel_values_lst = []
+            patch_pixel_values_lst = []
+            patch_newline_mask_lst = []
+            image_repl_str_lst = []
+            image_repl_ids_lst = []
+            num_patches = []
+            for (
+                raw_img,
+                img_patches,
+                patch_newline_mask,
+            ) in splitted_images_data:  # noqa: E501
+                pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
+
+                if len(img_patches) > 0:
+                    patch_pixel_values_lst.extend(
+                        self._convert_images_to_pixel_values(img_patches, is_patch=True)
+                    )
+                num_patches.append(len(img_patches))
+
+                image_repl_str, image_repl_ids = self._get_image_repl_features(
+                    1, len(img_patches), patch_newline_mask
+                )
+                image_repl_str_lst.append(image_repl_str)
+                image_repl_ids_lst.extend(image_repl_ids)
+
+                if patch_newline_mask is not None:
+                    patch_newline_mask_lst.extend(patch_newline_mask)
+
+            image_inputs = {
+                "pixel_values": torch.cat(pixel_values_lst),
+                "num_patches": num_patches,
+            }
+            if patch_pixel_values_lst:
+                image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
+            if patch_newline_mask_lst:
+                image_inputs["patch_newline_mask"] = torch.tensor(
+                    patch_newline_mask_lst, dtype=torch.bool
+                )
+
+            text = [
+                self.replace_placeholder(t, self.image_token, image_repl_str_lst)
+                for t in text
+            ]
+            text_inputs = self.tokenizer(text)
+
+        return BatchFeature(
+            {
+                **text_inputs,
+                **image_inputs,
+            },
+            tensor_type=return_tensors,
+        )
+
+
+################################################
+
+
+class Step3VLImageProcessor(SGLangBaseProcessor):
+    models = [Step3VLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        # TODO, check _processor is tokenizer or processor.
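+        # Step3VLProcessor (defined above) expands every image placeholder token in
+        # the prompt into the per-image feature block plus one block per cropped
+        # patch, and returns pixel_values / patch_pixel_values / num_patches /
+        # patch_newline_mask alongside the tokenized text.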
+        processor = Step3VLProcessor(hf_config, _processor)
+        super().__init__(hf_config, server_args, processor, *args, **kwargs)
+        self.IM_TOKEN_ID = 128001
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<im_patch>",
+            image_token_id=128001,
+            image_token_regex=re.compile(r"(?:<im_patch>)"),
+        ).build(_processor)
+
+        mean = [0.48145466, 0.4578275, 0.40821073]
+        std = [0.26862954, 0.26130258, 0.27577711]
+
+    def preprocess(self, image):
+        return {"pixel_values": self.transform(image).unsqueeze(0)}
+
+    def __call__(self, image):
+        return self.preprocess(image)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text: str | List[int],
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.mm_tokens.image_token_id,
+        }
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index b5b737856..a2561a18d 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -105,7 +105,7 @@ class BaseReasoningFormatDetector:
         # If we're not in a reasoning block return as normal text
         if not self._in_reasoning:
             self._buffer = ""
-            return StreamingParseResult(normal_text=new_text)
+            return StreamingParseResult(normal_text=current_text)
 
         return StreamingParseResult()
 
@@ -233,6 +233,7 @@ class ReasoningParser:
         "qwen3-thinking": Qwen3ThinkingDetector,
         "glm45": Qwen3Detector,
         "kimi": KimiDetector,
+        "step3": DeepSeekR1Detector,
     }
 
     def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 992905437..9e673a9f4 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1117,9 +1117,10 @@ class ServerArgs:
                 "kimi_k2",
                 "qwen3_coder",
                 "glm45",
+                "step3",
             ],
             default=ServerArgs.tool_call_parser,
-            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
         )
 
     # Data parallelism
diff --git a/test/srt/test_reasoning_parser.py b/test/srt/test_reasoning_parser.py
index 7f3359144..97eea82b4 100644
--- a/test/srt/test_reasoning_parser.py
+++ b/test/srt/test_reasoning_parser.py
@@ -493,5 +493,117 @@ class TestIntegrationScenarios(CustomTestCase):
         self.assertIn("final answer", all_normal)
 
 
+class TestBufferLossBugFix(CustomTestCase):
+    """Test cases for the buffer loss bug fix in parse_streaming_increment."""
+
+    def test_partial_end_tag_buffer_loss_bug(self):
+        """
+        Test the bug where partial end tag fragments are lost when followed by normal text.
+
+        Bug scenario:
+        1. _in_reasoning is False
+        2. new_text is a partial end tag such as "</think", so it is buffered
+        3. The next chunk is normal text; the buffered fragment must be returned with it
+        """
+        detector = BaseReasoningFormatDetector("<think>", "</think>")
+
+        # Step 1: Send partial end tag when not in reasoning mode
+        # This should be buffered since it could be start of "</think>"
+        result1 = detector.parse_streaming_increment("</think")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        # Step 2: Send normal text; the buffered "</think" must not be dropped
+        result2 = detector.parse_streaming_increment(" normal text")
+        self.assertEqual(result2.normal_text, "</think normal text")
+        self.assertEqual(result2.reasoning_text, "")
+
+        # Send partial start tag
+        result1 = detector.parse_streaming_increment("<think")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        # Enter reasoning mode
+        detector.parse_streaming_increment(">")
+        detector.parse_streaming_increment("some reasoning")
+
+        # Send partial end tag
+        result1 = detector.parse_streaming_increment("</think")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        # Complete the end tag followed by normal text
+        result2 = detector.parse_streaming_increment(">normal text")
+        self.assertEqual(result2.normal_text, "normal text")
+        # The reasoning text should be empty since buffer was cleared when end tag was processed
+        self.assertEqual(result2.reasoning_text, "")
+
+    def test_multiple_partial_fragments(self):
+        """
+        Test handling of multiple partial fragments that don't match any tokens.
+        """
+        detector = BaseReasoningFormatDetector("<think>", "</think>")
+
+        # Send multiple partial fragments
+        result1 = detector.parse_streaming_increment("<")
+        self.assertEqual(result1.normal_text, "")
+        self.assertEqual(result1.reasoning_text, "")
+
+        result2 = detector.parse_streaming_increment("/")
+        self.assertEqual(result2.normal_text, "")
+        self.assertEqual(result2.reasoning_text, "")
+
+        result3 = detector.parse_streaming_increment("random>")
+        self.assertEqual(result3.normal_text, "</random>")
+        self.assertEqual(result3.reasoning_text, "")
+
+    def test_edge_case_exact_token_match(self):
+        """
+        Test edge case where buffer content exactly matches a token.
+        """
+        detector = BaseReasoningFormatDetector("<think>", "</think>")
+
+        # Build up the exact start token character by character
+        detector.parse_streaming_increment("<")
+        detector.parse_streaming_increment("t")
+        detector.parse_streaming_increment("h")
+        detector.parse_streaming_increment("i")
+        detector.parse_streaming_increment("n")
+        result = detector.parse_streaming_increment("k>")
+
+        # Should enter reasoning mode
+        self.assertEqual(result.normal_text, "")
+        self.assertEqual(result.reasoning_text, "")
+        self.assertTrue(detector._in_reasoning)
+        self.assertTrue(detector.stripped_think_start)
+
+
 if __name__ == "__main__":
     unittest.main()
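
A minimal sketch of the placeholder-token accounting implemented by Step3VLProcessor.get_num_image_tokens() in this patch. The constants (169 global image features, 81 features per crop, 2 delimiter tokens around each block, one newline token per completed patch row) come from the code above; the helper name and the sample image size are illustrative assumptions only.

    # Mirrors get_num_image_tokens(): every crop costs its features plus two
    # delimiters, every completed patch row adds one newline token, and the
    # global view costs 169 features plus two delimiters.
    def expected_image_tokens(num_patches: int, num_newlines: int) -> int:
        return num_patches * (81 + 2) + 169 + 2 + num_newlines

    # A 2016x1512 input crops into a 4x3 grid of 504px windows: 12 patches and
    # 2 newline tokens (the last row emits none), i.e. 12 * 83 + 171 + 2 = 1169.
    assert expected_image_tokens(12, 2) == 1169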
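A quick smoke test of the new "step3" reasoning-parser entry, which this patch maps to DeepSeekR1Detector so that <think> ... </think> blocks are split the same way as DeepSeek-R1 output. This assumes ReasoningParser.parse_non_stream keeps its existing (reasoning_text, normal_text) return order; the sample strings are illustrative.

    from sglang.srt.reasoning_parser import ReasoningParser

    # "step3" now resolves to DeepSeekR1Detector in ReasoningParser.DetectorMap.
    parser = ReasoningParser(model_type="step3")
    reasoning_text, normal_text = parser.parse_non_stream(
        "<think>compare the two options</think>Option B is cheaper."
    )
    # reasoning_text -> "compare the two options"
    # normal_text    -> "Option B is cheaper."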