model: support Step3V (#8583)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: nnnobody-code <nnnobody@foxmail.com> Co-authored-by: ispobock <ispobaoke@gmail.com> Co-authored-by: Qiaolin-Yu <qy254@cornell.edu> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com> Co-authored-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -148,7 +148,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
||||
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
|
||||
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
|
||||
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
|
||||
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
|
||||
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. | None |
|
||||
|
||||
## Data parallelism
|
||||
|
||||
|
||||
@@ -5,6 +5,11 @@ from sglang.srt.configs.exaone import ExaoneConfig
|
||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
from sglang.srt.configs.step3_vl import (
|
||||
Step3TextConfig,
|
||||
Step3VisionEncoderConfig,
|
||||
Step3VLConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ExaoneConfig",
|
||||
@@ -14,4 +19,7 @@ __all__ = [
|
||||
"MultiModalityConfig",
|
||||
"KimiVLConfig",
|
||||
"MoonViTConfig",
|
||||
"Step3VLConfig",
|
||||
"Step3TextConfig",
|
||||
"Step3VisionEncoderConfig",
|
||||
]
|
||||
|
||||
@@ -335,6 +335,8 @@ class ModelConfig:
|
||||
"num_key_value_heads",
|
||||
# For ChatGLM:
|
||||
"multi_query_group_num",
|
||||
# For Step3
|
||||
"num_attention_groups",
|
||||
]
|
||||
for attr in attributes:
|
||||
num_kv_heads = getattr(self.hf_text_config, attr, None)
|
||||
@@ -644,6 +646,7 @@ multimodal_model_archs = [
|
||||
"InternS1ForConditionalGeneration",
|
||||
"Phi4MMForCausalLM",
|
||||
"VILAForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
|
||||
|
||||
|
||||
172
python/sglang/srt/configs/step3_vl.py
Normal file
172
python/sglang/srt/configs/step3_vl.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Step3VisionEncoderConfig(PretrainedConfig):
|
||||
model_type = "step3_vision_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1792,
|
||||
intermediate_size=3072,
|
||||
output_hidden_size=4096,
|
||||
num_hidden_layers=63,
|
||||
num_attention_heads=16,
|
||||
num_channels=3,
|
||||
image_size=728,
|
||||
patch_size=14,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-5,
|
||||
**kwargs,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.output_hidden_size = output_hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3TextConfig(PretrainedConfig):
|
||||
model_type = "step3_text"
|
||||
architectures = ["Step3TextForCausalLM"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 7168,
|
||||
intermediate_size: int = 18432,
|
||||
num_attention_heads: int = 64,
|
||||
num_attention_groups: int = 1,
|
||||
num_hidden_layers: int = 61,
|
||||
max_seq_len: int = 65536,
|
||||
vocab_size: int = 128815,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
moe_intermediate_size: int = 5120,
|
||||
moe_num_experts: int = 48,
|
||||
moe_top_k: int = 3,
|
||||
rope_theta: float = 500000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embedding: int = 65536,
|
||||
share_expert_dim: int = 5120,
|
||||
share_q_dim: int = 2048,
|
||||
head_dim: int = 256,
|
||||
norm_expert_weight: bool = False,
|
||||
moe_layers_enum: tuple[int] = (
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
),
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_attention_groups = num_attention_groups
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.max_seq_len = max_seq_len
|
||||
self.vocab_size = vocab_size
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.moe_num_experts = moe_num_experts
|
||||
self.moe_top_k = moe_top_k
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.max_position_embedding = max_position_embedding
|
||||
self.share_expert_dim = share_expert_dim
|
||||
self.share_q_dim = share_q_dim
|
||||
self.head_dim = head_dim
|
||||
self.norm_expert_weight = norm_expert_weight
|
||||
self.moe_layers_enum = moe_layers_enum
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3VLConfig(PretrainedConfig):
|
||||
model_type = "step3_vl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
|
||||
text_config: Optional[Union[dict, Step3TextConfig]] = None,
|
||||
understand_projector_stride: int = 1,
|
||||
projector_bias: bool = True,
|
||||
image_token_id: int = 128001,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if vision_config is None:
|
||||
vision_config = Step3VisionEncoderConfig()
|
||||
elif isinstance(vision_config, dict):
|
||||
vision_config = Step3VisionEncoderConfig(**vision_config)
|
||||
self.vision_config = vision_config
|
||||
|
||||
if text_config is None:
|
||||
text_config = Step3TextConfig()
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = Step3TextConfig(**text_config)
|
||||
self.text_config = text_config
|
||||
|
||||
self.understand_projector_stride = understand_projector_stride
|
||||
self.projector_bias = projector_bias
|
||||
self.hidden_size = text_config.hidden_size
|
||||
self.image_token_id = image_token_id
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -994,6 +994,23 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="step3-vl",
|
||||
system_message="<|begin▁of▁sentence|>You are a helpful assistant",
|
||||
system_template="{system_message}\n",
|
||||
roles=(
|
||||
"<|BOT|>user\n",
|
||||
"<|BOT|>assistant\n<think>\n",
|
||||
),
|
||||
sep="<|EOT|>",
|
||||
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
||||
stop_str="<|EOT|>",
|
||||
image_token="<im_patch>",
|
||||
# add_bos=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_internvl(model_path: str):
|
||||
@@ -1103,3 +1120,9 @@ def match_vila(model_path: str):
|
||||
def match_mimo_vl(model_path: str):
|
||||
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
|
||||
return "mimo-vl"
|
||||
|
||||
|
||||
# @register_conv_template_matching_function
|
||||
# def match_step3(model_path: str):
|
||||
# if re.search(r"step3", model_path, re.IGNORECASE):
|
||||
# return "step3-vl"
|
||||
|
||||
@@ -17,6 +17,7 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +40,7 @@ class FunctionCallParser:
|
||||
"kimi_k2": KimiK2Detector,
|
||||
"qwen3_coder": Qwen3CoderDetector,
|
||||
"glm45": Glm4MoeDetector,
|
||||
"step3": Step3Detector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
|
||||
436
python/sglang/srt/function_call/step3_detector.py
Normal file
436
python/sglang/srt/function_call/step3_detector.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
ToolCallItem,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_argument_type(func_name: str, arg_key: str, defined_tools: List[Tool]) -> str:
|
||||
"""Get the expected type for a function argument from tool schema."""
|
||||
name2tool = {tool.function.name: tool for tool in defined_tools}
|
||||
if func_name not in name2tool:
|
||||
return None
|
||||
tool = name2tool[func_name]
|
||||
parameters = tool.function.parameters or {}
|
||||
properties = parameters.get("properties", {})
|
||||
if arg_key not in properties:
|
||||
return None
|
||||
return properties[arg_key].get("type", None)
|
||||
|
||||
|
||||
def parse_arguments(value: str) -> tuple[Any, bool]:
|
||||
"""Parse a string value to appropriate type. Returns (parsed_value, success)."""
|
||||
try:
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
except:
|
||||
parsed_value = ast.literal_eval(value)
|
||||
return parsed_value, True
|
||||
except:
|
||||
return value, False
|
||||
|
||||
|
||||
class Step3Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Step3 model function call format.
|
||||
|
||||
The Step3 format uses special Unicode tokens to delimit function calls
|
||||
with steptml XML format for invocations.
|
||||
|
||||
Format Structure:
|
||||
```
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="function_name">
|
||||
<steptml:parameter name="param1">value1</steptml:parameter>
|
||||
<steptml:parameter name="param2">value2</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = "<|tool_calls_begin|>"
|
||||
self.eot_token = "<|tool_calls_end|>"
|
||||
self.tool_call_begin = "<|tool_call_begin|>"
|
||||
self.tool_call_end = "<|tool_call_end|>"
|
||||
self.tool_sep = "<|tool_sep|>"
|
||||
|
||||
# Regex for parsing steptml invocations
|
||||
self.invoke_regex = re.compile(
|
||||
r'<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>', re.DOTALL
|
||||
)
|
||||
self.param_regex = re.compile(
|
||||
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', re.DOTALL
|
||||
)
|
||||
|
||||
# Streaming state variables
|
||||
self._in_tool_block: bool = False
|
||||
self._tool_block_finished: bool = False
|
||||
self._current_function_name: str = ""
|
||||
self._current_parameters: Dict[str, Any] = {}
|
||||
self._in_tool_call: bool = False
|
||||
self._function_name_sent: bool = False
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Step3 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def _parse_steptml_invoke(
|
||||
self, text: str, tools: List[Tool] = None
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse steptml invoke format to extract function name and parameters."""
|
||||
invoke_match = self.invoke_regex.search(text)
|
||||
if not invoke_match:
|
||||
return None, {}
|
||||
|
||||
func_name = invoke_match.group(1)
|
||||
params_text = invoke_match.group(2)
|
||||
|
||||
params = {}
|
||||
for param_match in self.param_regex.finditer(params_text):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
|
||||
# If tools provided, use schema-aware parsing
|
||||
if tools:
|
||||
arg_type = get_argument_type(func_name, param_name, tools)
|
||||
if arg_type and arg_type != "string":
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
params[param_name] = parsed_value
|
||||
else:
|
||||
params[param_name] = param_value
|
||||
else:
|
||||
# Fallback to generic parsing if no tools provided
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
params[param_name] = parsed_value
|
||||
|
||||
return func_name, params
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
"""
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
try:
|
||||
pre_text, rest = text.split(self.bot_token, 1)
|
||||
|
||||
# If no end token, return everything as normal text
|
||||
if self.eot_token not in rest:
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
tool_section, post_text = rest.split(self.eot_token, 1)
|
||||
|
||||
# Find all individual tool calls using regex
|
||||
calls = []
|
||||
tool_call_pattern = (
|
||||
f"{re.escape(self.tool_call_begin)}(.*?){re.escape(self.tool_call_end)}"
|
||||
)
|
||||
|
||||
for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL):
|
||||
call_content = match.group(1)
|
||||
|
||||
# Check if it's a function call
|
||||
if self.tool_sep not in call_content:
|
||||
continue
|
||||
|
||||
type_part, invoke_part = call_content.split(self.tool_sep, 1)
|
||||
if type_part.strip() != "function":
|
||||
continue
|
||||
|
||||
func_name, params = self._parse_steptml_invoke(invoke_part, tools)
|
||||
if func_name:
|
||||
# Use parse_base_json to create the ToolCallItem
|
||||
action = {"name": func_name, "arguments": params}
|
||||
calls.extend(self.parse_base_json(action, tools))
|
||||
|
||||
# Combine pre and post text
|
||||
normal_text = pre_text + post_text
|
||||
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in detect_and_parse: {e}")
|
||||
# Return the original text if parsing fails
|
||||
return StreamingParseResult(normal_text=text)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing for Step3 format.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
|
||||
# Build tool indices for validation
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
self._tool_indices = self._get_tool_indices(tools)
|
||||
|
||||
# If we've finished the tool block, everything is normal text
|
||||
if self._tool_block_finished:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# Check if tool block hasn't started yet
|
||||
if not self._in_tool_block:
|
||||
if self.bot_token in self._buffer:
|
||||
idx = self._buffer.find(self.bot_token)
|
||||
normal_text = self._buffer[:idx]
|
||||
self._buffer = self._buffer[idx + len(self.bot_token) :]
|
||||
self._in_tool_block = True
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
else:
|
||||
# Check if we might have a partial bot_token
|
||||
partial_len = self._ends_with_partial_token(
|
||||
self._buffer, self.bot_token
|
||||
)
|
||||
if partial_len:
|
||||
return StreamingParseResult() # Wait for more text
|
||||
else:
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
|
||||
# We're inside the tool block
|
||||
calls: List[ToolCallItem] = []
|
||||
|
||||
# Check if tool block is ending
|
||||
if self.eot_token in self._buffer:
|
||||
idx = self._buffer.find(self.eot_token)
|
||||
|
||||
# If we're in the middle of a tool call, we need to handle it
|
||||
if self._in_tool_call:
|
||||
# The buffer before eot_token might contain the end of the current tool call
|
||||
before_eot = self._buffer[:idx]
|
||||
if self.tool_call_end in before_eot:
|
||||
# Parse this final tool call
|
||||
result = self._parse_partial_tool_call(tools)
|
||||
calls.extend(result.calls)
|
||||
else:
|
||||
# Incomplete tool call - log warning
|
||||
logger.warning("Tool block ended with incomplete tool call")
|
||||
|
||||
remaining = self._buffer[idx + len(self.eot_token) :]
|
||||
self._buffer = ""
|
||||
self._tool_block_finished = True
|
||||
|
||||
# Reset any partial tool call state
|
||||
self._reset_streaming_state()
|
||||
|
||||
return StreamingParseResult(normal_text=remaining, calls=calls)
|
||||
|
||||
# Check if we're in a tool call or need to start one
|
||||
if not self._in_tool_call:
|
||||
if self.tool_call_begin in self._buffer:
|
||||
idx = self._buffer.find(self.tool_call_begin)
|
||||
# Remove any content before tool call begin (shouldn't happen but be safe)
|
||||
self._buffer = self._buffer[idx + len(self.tool_call_begin) :]
|
||||
self._in_tool_call = True
|
||||
self._function_name_sent = False
|
||||
self._current_function_name = ""
|
||||
self._current_parameters = {}
|
||||
# Fall through to parse the partial tool call
|
||||
else:
|
||||
# Wait for tool call to begin
|
||||
return StreamingParseResult()
|
||||
|
||||
# Parse partial tool call
|
||||
if self._in_tool_call:
|
||||
return self._parse_partial_tool_call(tools)
|
||||
|
||||
return StreamingParseResult()
|
||||
|
||||
def _parse_partial_tool_call(self, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""Parse partial tool call for streaming scenarios."""
|
||||
calls = []
|
||||
|
||||
# Check if we have tool_sep (means we're past the type declaration)
|
||||
if self.tool_sep not in self._buffer:
|
||||
return StreamingParseResult(calls=calls) # Wait for more text
|
||||
|
||||
type_part, invoke_part = self._buffer.split(self.tool_sep, 1)
|
||||
if type_part.strip() != "function":
|
||||
# Invalid tool type, skip this tool call
|
||||
self._reset_streaming_state()
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
# Try to extract function name if not sent yet
|
||||
if not self._function_name_sent:
|
||||
name_match = re.search(r'<steptml:invoke name="([^"]+)">', invoke_part)
|
||||
if name_match:
|
||||
func_name = name_match.group(1)
|
||||
|
||||
# Validate function name
|
||||
if func_name in self._tool_indices:
|
||||
self._current_function_name = func_name
|
||||
self._function_name_sent = True
|
||||
|
||||
# Initialize tool tracking
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
|
||||
# Ensure tracking arrays are large enough
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
# Store tool call info
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
}
|
||||
|
||||
# Send tool name with empty parameters
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name=func_name,
|
||||
parameters="",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Invalid function name
|
||||
logger.warning(f"Invalid function name: {func_name}")
|
||||
self._reset_streaming_state()
|
||||
return StreamingParseResult(calls=calls)
|
||||
else:
|
||||
# Function name not complete yet
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
# Parse parameters incrementally
|
||||
if self._function_name_sent:
|
||||
# Extract all complete parameters
|
||||
new_params = {}
|
||||
for param_match in self.param_regex.finditer(invoke_part):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
|
||||
# Use schema-aware parsing
|
||||
arg_type = get_argument_type(
|
||||
self._current_function_name, param_name, tools
|
||||
)
|
||||
if arg_type and arg_type != "string":
|
||||
parsed_value, _ = parse_arguments(param_value)
|
||||
new_params[param_name] = parsed_value
|
||||
else:
|
||||
new_params[param_name] = param_value
|
||||
|
||||
# Check if we have new parameters to stream
|
||||
if new_params != self._current_parameters:
|
||||
# Build the JSON content without the closing brace for streaming
|
||||
if not self._current_parameters:
|
||||
# First parameters - send opening brace and content
|
||||
params_content = json.dumps(new_params, ensure_ascii=False)
|
||||
if len(params_content) > 2: # More than just "{}"
|
||||
# Send everything except the closing brace
|
||||
diff = params_content[:-1]
|
||||
else:
|
||||
diff = "{"
|
||||
else:
|
||||
# Subsequent parameters - calculate the incremental diff
|
||||
old_json = json.dumps(self._current_parameters, ensure_ascii=False)
|
||||
new_json = json.dumps(new_params, ensure_ascii=False)
|
||||
|
||||
# Remove closing braces for comparison
|
||||
old_without_brace = old_json[:-1]
|
||||
new_without_brace = new_json[:-1]
|
||||
|
||||
# The new content should extend the old content
|
||||
if new_without_brace.startswith(old_without_brace):
|
||||
diff = new_without_brace[len(old_without_brace) :]
|
||||
else:
|
||||
# Parameters changed in unexpected way - shouldn't happen in normal streaming
|
||||
diff = ""
|
||||
|
||||
if diff:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
parameters=diff,
|
||||
)
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
|
||||
# Update current state
|
||||
self._current_parameters = new_params
|
||||
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
|
||||
|
||||
# Check if tool call is complete
|
||||
if self.tool_call_end in self._buffer:
|
||||
# Send closing brace if we've sent any parameters
|
||||
if self.streamed_args_for_tool[self.current_tool_id]:
|
||||
calls.append(
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
parameters="}",
|
||||
)
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += "}"
|
||||
|
||||
# Find the end position
|
||||
end_idx = self._buffer.find(self.tool_call_end)
|
||||
# Remove the processed tool call from buffer
|
||||
self._buffer = self._buffer[end_idx + len(self.tool_call_end) :]
|
||||
|
||||
# Reset state for next tool call
|
||||
self._reset_streaming_state()
|
||||
self.current_tool_id += 1
|
||||
|
||||
return StreamingParseResult(calls=calls)
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset streaming state for the next tool call"""
|
||||
self._in_tool_call = False
|
||||
self._function_name_sent = False
|
||||
self._current_function_name = ""
|
||||
self._current_parameters = {}
|
||||
|
||||
def supports_structural_tag(self) -> bool:
|
||||
"""Return True if this detector supports structural tag format."""
|
||||
return False
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
raise NotImplementedError()
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build EBNF grammar for Step3 tool call format.
|
||||
"""
|
||||
# Custom call rule for steptml format
|
||||
call_rule_fmt = (
|
||||
'"function" "<|tool_sep|>" "<steptml:invoke name=\\"{name}\\">" '
|
||||
'{arguments_rule} "</steptml:invoke>"'
|
||||
)
|
||||
|
||||
# Custom key-value rule for steptml parameters
|
||||
key_value_rule_fmt = (
|
||||
'"<steptml:parameter name=\\"{key}\\">" {valrule} "</steptml:parameter>"'
|
||||
)
|
||||
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
sequence_start_token=self.bot_token,
|
||||
sequence_end_token=self.eot_token,
|
||||
individual_call_start_token=self.tool_call_begin,
|
||||
individual_call_end_token=self.tool_call_end,
|
||||
tool_call_separator="",
|
||||
function_format="xml",
|
||||
call_rule_fmt=call_rule_fmt,
|
||||
key_value_rule_fmt=key_value_rule_fmt,
|
||||
key_value_separator="",
|
||||
)
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.configs import (
|
||||
ExaoneConfig,
|
||||
KimiVLConfig,
|
||||
MultiModalityConfig,
|
||||
Step3VLConfig,
|
||||
)
|
||||
from sglang.srt.configs.internvl import InternVLChatConfig
|
||||
from sglang.srt.connector import create_remote_connector
|
||||
@@ -54,6 +55,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
MultiModalityConfig.model_type: MultiModalityConfig,
|
||||
KimiVLConfig.model_type: KimiVLConfig,
|
||||
InternVLChatConfig.model_type: InternVLChatConfig,
|
||||
Step3VLConfig.model_type: Step3VLConfig,
|
||||
}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
|
||||
@@ -165,7 +165,7 @@ def process_content_for_template_format(
|
||||
new_msg["content"] = processed_content_parts
|
||||
return new_msg
|
||||
|
||||
else: # content_format == "string"
|
||||
elif content_format == "string":
|
||||
# String format: flatten to text only (for templates like DeepSeek)
|
||||
text_parts = []
|
||||
for chunk in msg_dict["content"]:
|
||||
@@ -179,3 +179,6 @@ def process_content_for_template_format(
|
||||
new_msg["content"] = " ".join(text_parts) if text_parts else ""
|
||||
new_msg = {k: v for k, v in new_msg.items() if v is not None}
|
||||
return new_msg
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid content format: {content_format}")
|
||||
|
||||
@@ -53,7 +53,7 @@ class TemplateManager:
|
||||
def __init__(self):
|
||||
self._chat_template_name: Optional[str] = None
|
||||
self._completion_template_name: Optional[str] = None
|
||||
self._jinja_template_content_format: Optional[str] = None
|
||||
self._jinja_template_content_format: Optional[str] = "openai"
|
||||
|
||||
@property
|
||||
def chat_template_name(self) -> Optional[str]:
|
||||
@@ -71,31 +71,60 @@ class TemplateManager:
|
||||
return self._jinja_template_content_format
|
||||
|
||||
def load_chat_template(
|
||||
self, tokenizer_manager, chat_template_arg: str, model_path: str
|
||||
self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
|
||||
) -> None:
|
||||
"""
|
||||
Load a chat template from various sources.
|
||||
|
||||
Args:
|
||||
tokenizer_manager: The tokenizer manager instance
|
||||
chat_template_arg: Template name or file path
|
||||
chat_template_arg: Template name, file path, or None to auto-detect
|
||||
model_path: Path to the model
|
||||
"""
|
||||
logger.info(f"Loading chat template: {chat_template_arg}")
|
||||
if chat_template_arg:
|
||||
self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
|
||||
else:
|
||||
# Try HuggingFace template first
|
||||
hf_template = self._resolve_hf_chat_template(tokenizer_manager)
|
||||
if hf_template:
|
||||
self._jinja_template_content_format = (
|
||||
detect_jinja_template_content_format(hf_template)
|
||||
)
|
||||
logger.info(
|
||||
f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
|
||||
)
|
||||
return
|
||||
|
||||
if not chat_template_exists(chat_template_arg):
|
||||
if not os.path.exists(chat_template_arg):
|
||||
raise RuntimeError(
|
||||
f"Chat template {chat_template_arg} is not a built-in template name "
|
||||
"or a valid chat template file path."
|
||||
# Fallback to SGLang template guessing
|
||||
self.guess_chat_template_from_model_path(model_path)
|
||||
|
||||
# Set default format if no template was found
|
||||
if self._chat_template_name is None:
|
||||
self._jinja_template_content_format = "string"
|
||||
logger.info(
|
||||
"No chat template found, defaulting to 'string' content format"
|
||||
)
|
||||
|
||||
if chat_template_arg.endswith(".jinja"):
|
||||
self._load_jinja_template(tokenizer_manager, chat_template_arg)
|
||||
else:
|
||||
self._load_json_chat_template(chat_template_arg)
|
||||
else:
|
||||
def _load_explicit_chat_template(
|
||||
self, tokenizer_manager, chat_template_arg: str
|
||||
) -> None:
|
||||
"""Load explicitly specified chat template."""
|
||||
logger.info(f"Loading chat template from argument: {chat_template_arg}")
|
||||
|
||||
if chat_template_exists(chat_template_arg):
|
||||
self._chat_template_name = chat_template_arg
|
||||
return
|
||||
|
||||
if not os.path.exists(chat_template_arg):
|
||||
raise RuntimeError(
|
||||
f"Chat template {chat_template_arg} is not a built-in template name "
|
||||
"or a valid chat template file path."
|
||||
)
|
||||
|
||||
if chat_template_arg.endswith(".jinja"):
|
||||
self._load_jinja_template(tokenizer_manager, chat_template_arg)
|
||||
else:
|
||||
self._load_json_chat_template(chat_template_arg)
|
||||
|
||||
def guess_chat_template_from_model_path(self, model_path: str) -> None:
|
||||
"""
|
||||
@@ -146,10 +175,7 @@ class TemplateManager:
|
||||
completion_template: Optional completion template name/path
|
||||
"""
|
||||
# Load chat template
|
||||
if chat_template:
|
||||
self.load_chat_template(tokenizer_manager, chat_template, model_path)
|
||||
else:
|
||||
self.guess_chat_template_from_model_path(model_path)
|
||||
self.load_chat_template(tokenizer_manager, chat_template, model_path)
|
||||
|
||||
# Load completion template
|
||||
if completion_template:
|
||||
@@ -166,7 +192,7 @@ class TemplateManager:
|
||||
chat_template
|
||||
)
|
||||
logger.info(
|
||||
f"Detected chat template content format: {self._jinja_template_content_format}"
|
||||
f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}"
|
||||
)
|
||||
|
||||
def _load_json_chat_template(self, template_path: str) -> None:
|
||||
@@ -224,3 +250,20 @@ class TemplateManager:
|
||||
override=True,
|
||||
)
|
||||
self._completion_template_name = template["name"]
|
||||
|
||||
def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]:
|
||||
"""
|
||||
Resolve HuggingFace chat template.
|
||||
|
||||
Returns the chat template string if found, None otherwise.
|
||||
"""
|
||||
tokenizer = tokenizer_manager.tokenizer
|
||||
|
||||
# Try to get AutoTokenizer chat template
|
||||
try:
|
||||
return tokenizer.get_chat_template()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting chat template via get_chat_template(): {e}")
|
||||
|
||||
logger.debug("No HuggingFace chat template found")
|
||||
return None
|
||||
|
||||
994
python/sglang/srt/models/step3_vl.py
Normal file
994
python/sglang/srt/models/step3_vl.py
Normal file
@@ -0,0 +1,994 @@
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from math import sqrt
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torch.nn import functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from sglang.srt.configs.step3_vl import (
|
||||
Step3TextConfig,
|
||||
Step3VisionEncoderConfig,
|
||||
Step3VLConfig,
|
||||
)
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
"""
|
||||
Text Model
|
||||
"""
|
||||
|
||||
|
||||
class Step3TextMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Step3TextMoEMLP(nn.Module):
|
||||
# Native
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
config: Step3TextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layer_id = layer_id
|
||||
if self.tp_size > config.moe_num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.moe_num_experts}."
|
||||
)
|
||||
|
||||
self.topk = TopK(
|
||||
top_k=config.moe_top_k,
|
||||
renormalize=config.norm_expert_weight,
|
||||
use_grouped_topk=False,
|
||||
)
|
||||
|
||||
self.experts = get_moe_impl_class()(
|
||||
num_experts=config.moe_num_experts,
|
||||
top_k=config.moe_top_k,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
output_size=config.moe_num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
if global_server_args_dict["enable_deepep_moe"]:
|
||||
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class Step3TextAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
share_q_dim: int,
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rms_norm_eps=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
self.all_tp_rank = get_tensor_model_parallel_rank()
|
||||
self.total_num_heads = num_heads
|
||||
self.attn_tp_rank = attn_tp_rank
|
||||
self.layer_id = layer_id
|
||||
assert self.total_num_heads % attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // attn_tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||
self.head_dim = head_dim
|
||||
self.q_size = share_q_dim if share_q_dim else head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[self.q_size, self.kv_size, self.kv_size],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=0, # In fact, we need a MergedReplicatedLinear
|
||||
tp_size=1,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps)
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
self.q_size,
|
||||
self.head_dim * self.total_num_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("wq", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = self.inter_norm(q.contiguous())
|
||||
q, _ = self.wq(q)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Step3TextDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Step3TextConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
head_dim = getattr(
|
||||
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
# TODO: support shared experts fusion
|
||||
# self.n_shared_experts = 1
|
||||
# self.num_fused_shared_experts = (
|
||||
# 0
|
||||
# if global_server_args_dict["disable_shared_experts_fusion"]
|
||||
# else self.n_shared_experts
|
||||
# )
|
||||
self.num_fused_shared_experts = 0
|
||||
rms_norm_eps = config.rms_norm_eps
|
||||
self.self_attn = Step3TextAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=1,
|
||||
head_dim=head_dim,
|
||||
share_q_dim=config.share_q_dim,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
moe_layers_enum = getattr(config, "moe_layers_enum", None)
|
||||
if moe_layers_enum is not None:
|
||||
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
|
||||
else:
|
||||
# Default to 1dense.
|
||||
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
|
||||
|
||||
self.use_moe = False
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.is_layer_sparse = True if layer_id in moe_layers_idx else False
|
||||
self.is_previous_layer_sparse = (
|
||||
True if layer_id - 1 in moe_layers_idx else False
|
||||
)
|
||||
|
||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||
layer_id=layer_id,
|
||||
num_layers=config.num_hidden_layers,
|
||||
is_layer_sparse=self.is_layer_sparse,
|
||||
is_previous_layer_sparse=self.is_previous_layer_sparse,
|
||||
)
|
||||
|
||||
if not self.is_layer_sparse:
|
||||
self.mlp = Step3TextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
else:
|
||||
self.use_moe = True
|
||||
if self.num_fused_shared_experts == 0:
|
||||
self.moe = Step3TextMoEMLP(
|
||||
layer_id=layer_id,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.share_expert = Step3TextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.share_expert_dim,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("share_expert", prefix),
|
||||
)
|
||||
else:
|
||||
self.moe = Step3TextMoEMLP(
|
||||
layer_id=layer_id,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
self.layer_communicator = LayerCommunicator(
|
||||
layer_scatter_modes=self.layer_scatter_modes,
|
||||
input_layernorm=self.input_layernorm,
|
||||
post_attention_layernorm=self.post_attention_layernorm,
|
||||
)
|
||||
|
||||
def moe_mlp_forward(self, hidden_states):
|
||||
if not self.num_fused_shared_experts:
|
||||
h = hidden_states.clone()
|
||||
hidden_states = self.moe(hidden_states)
|
||||
hidden_states += self.share_expert(h)
|
||||
else:
|
||||
hidden_states = self.moe(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
if self.use_moe:
|
||||
hidden_states = self.moe_mlp_forward(hidden_states)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class Step3TextModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda idx, prefix: Step3TextDecoderLayer(
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
|
||||
if hidden_states.shape[0] != 0:
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
"""
|
||||
Vision Model
|
||||
"""
|
||||
|
||||
|
||||
def get_abs_pos(abs_pos, tgt_size):
|
||||
dim = abs_pos.size(-1)
|
||||
abs_pos_new = abs_pos.squeeze(0)
|
||||
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
|
||||
|
||||
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
|
||||
tgt_size = int(math.sqrt(tgt_size))
|
||||
dtype = abs_pos.dtype
|
||||
|
||||
if src_size != tgt_size:
|
||||
old_pos_embed = (
|
||||
old_pos_embed.view(1, src_size, src_size, dim)
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
old_pos_embed = old_pos_embed.to(torch.float32)
|
||||
new_pos_embed = F.interpolate(
|
||||
old_pos_embed,
|
||||
size=(tgt_size, tgt_size),
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
align_corners=False,
|
||||
).to(dtype)
|
||||
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
|
||||
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
|
||||
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
|
||||
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
|
||||
return vision_pos_embed
|
||||
else:
|
||||
return abs_pos
|
||||
|
||||
|
||||
class Step3VisionMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_size: int,
|
||||
bias: bool = True,
|
||||
hidden_act="quick_gelu",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
dim,
|
||||
intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_proj", prefix),
|
||||
)
|
||||
self.act = ACT2FN[hidden_act] # quick_gelu
|
||||
self.fc2 = RowParallelLinear(
|
||||
intermediate_size,
|
||||
dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step3VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 16,
|
||||
qkv_backend="fa3",
|
||||
quant_config=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.out_proj = RowParallelLinear(
|
||||
dim,
|
||||
dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("out_proj", prefix),
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.attn = VisionAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=True,
|
||||
rotary_embed="normal",
|
||||
proj_bias=True,
|
||||
qkv_backend=qkv_backend,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attn_output = self.attn(hidden_states)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Step3VisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: Step3VisionEncoderConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.pad_tp_size = 4 # hard code for padding
|
||||
# To load the pretrained weights, we still use P+1 as the seqlen
|
||||
self.position_embedding = torch.nn.Embedding(
|
||||
self.num_patches + 1, self.embed_dim
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_patches + 1).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values
|
||||
) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
# pad
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + get_abs_pos(
|
||||
self.position_embedding(self.position_ids), patch_embeds.size(1)
|
||||
)
|
||||
embeddings = torch.cat(
|
||||
[
|
||||
embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
|
||||
embeddings,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
class Step3VisionEncoderLayer(nn.Module):
|
||||
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6)
|
||||
self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6)
|
||||
|
||||
self.self_attn = Step3VisionAttention(
|
||||
self.embed_dim, num_heads=config.num_attention_heads
|
||||
)
|
||||
self.mlp = Step3VisionMLP(
|
||||
dim=self.embed_dim,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
|
||||
hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step3VisionTransformer(nn.Module):
|
||||
def __init__(self, config: Step3VisionEncoderConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.image_size = config.image_size
|
||||
self.embeddings = Step3VisionEmbeddings(config)
|
||||
self.transformer = Step3VisionEncoder(config)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.embeddings.patch_embedding.weight.dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
):
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step3VisionEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`Step3VisionEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: StepVisionEncoderConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: Step3VisionEncoderConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
return hidden_states


class Step3VLForConditionalGeneration(nn.Module):

    def __init__(
        self,
        config: Step3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Step3TextModel(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )

        self.vision_model = Step3VisionTransformer(config.vision_config)

        self.vit_downsampler = nn.Conv2d(
            config.vision_config.hidden_size,
            config.vision_config.output_hidden_size,
            kernel_size=2,
            stride=config.understand_projector_stride,
        )
        self.vit_downsampler2 = nn.Conv2d(
            config.vision_config.output_hidden_size,
            config.vision_config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.vit_large_projector = nn.Linear(
            config.vision_config.output_hidden_size * 2,
            config.hidden_size,
            bias=config.projector_bias,
        )

        # TODO: support shared experts fusion
        # self.n_shared_experts = 1
        # self.num_fused_shared_experts = (
        #     0
        #     if global_server_args_dict["disable_shared_experts_fusion"]
        #     else self.n_shared_experts
        # )
        self.num_fused_shared_experts = 0
        self.config.tie_word_embeddings = False
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.text_config.vocab_size,
                config.text_config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config.text_config)

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)[:, 4:]

    def _flatten_embeddings(self, embeddings) -> torch.Tensor:
        if isinstance(embeddings, torch.Tensor):
            # Flatten all but the last dimension.
            return embeddings.flatten(0, -2)
        return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings))

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features
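
    # A worked sketch of the shapes above (not part of the original file; numbers assume
    # the Step3 defaults used elsewhere in this PR: a 728px global view, 504px window
    # crops, patch_size 14, understand_projector_stride == 2, and the 4 extra ViT tokens
    # dropped by _get_vision_model_output):
    #   728x728 image -> 52x52 = 2704 patch tokens -> vit_downsampler (k=2, s=2) -> 26x26
    #   -> vit_downsampler2 (k=3, s=2, p=1) -> 13x13 = 169 features per image
    #   504x504 crop  -> 36x36 = 1296 patch tokens -> 18x18 -> 9x9 = 81 features per crop
    # These match num_image_feature_size = 169 and num_patch_feature_size = 81 in the
    # processor further below.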

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        assert len(items) == 1  # We only have images.

        item = items[0]
        pixel_values = item.feature.type(self.vision_model.dtype)
        num_patches = item.model_specific_data.get("num_patches")
        patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None)
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype)
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.to("cuda")

        image_features = self._get_vision_model_output(pixel_values)
        patch_image_features = (
            self._get_vision_model_output(patch_pixel_values)
            if patch_pixel_values is not None
            else None
        )

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return self._flatten_embeddings(merged_image_features)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            positions=positions,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # TODO:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", 0),
            (".qkv_proj", ".k_proj", 1),
            (".qkv_proj", ".v_proj", 2),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.text_config.moe_num_experts
            + self.num_fused_shared_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params = set()

        def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool:
            name_parts = name_path.split(".")
            weight_parts = weight_path.split(".")
            shard_id_matches = name_parts[4] == weight_parts[2]
            return shard_id_matches

        for name, loaded_weight in weights:
            if "vision_model" in name:
                # 1. It's not great, but let's leave it like this for now:
                #    map the checkpoint's "self_attn" onto the nested attention module.
                name = name.replace("self_attn", "self_attn.attn")
                # 2. The checkpoint's "out_proj" corresponds to this module's "proj".
                name = name.replace("out_proj", "proj")

            # TODO: support vision model
            if self.num_fused_shared_experts > 0 and "share" in name:
                # assert False
                name = name.replace("share_expert", "moe")
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if (
                        expert_id != self.config.text_config.moe_num_experts
                        or not match_expert_and_shard_ids(name, weight_name)
                    ):
                        continue

                    part_name = weight_name.split(".")[-2]
                    fake_weight_name = name.replace(part_name, weight_name[:-1])
                    actual_param_name = name.replace(part_name + ".", param_name)
                    param = params_dict[actual_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "gate." not in name and "moe" in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if "moe" not in name:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                else:
                    if "gate." in name:
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
                        continue

                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if expert_id == self.config.text_config.moe_num_experts:
                            continue
                        if not match_expert_and_shard_ids(name, weight_name):
                            continue
                        part_name = weight_name.split(".")[-2]
                        fake_weight_name = name.replace(part_name, weight_name[:-1])
                        actual_param_name = name.replace(part_name + ".", param_name)
                        param = params_dict[actual_param_name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                        loaded_params.add(actual_param_name)
                        # Don't break here, because this 'loaded_weight' includes all the weights for this layer.

    @classmethod
    def get_model_config_for_expert_location(cls, config: Step3VLConfig):
        return ModelConfigForExpertLocation(
            num_layers=config.text_config.num_hidden_layers,
            num_logical_experts=config.text_config.moe_num_experts,
            num_groups=None,
        )


EntryClass = Step3VLForConditionalGeneration

@@ -176,6 +176,8 @@ class BaseMultimodalProcessor(ABC):
        "image_grid_hws": Modality.IMAGE,
        "aspect_ratio_ids": Modality.IMAGE,
        "aspect_ratio_mask": Modality.IMAGE,
        "num_patches": Modality.IMAGE,
        "patch_pixel_values": Modality.IMAGE,
        # Audio-related attributes
        "audio_features": Modality.AUDIO,
        "audio_feature_lens": Modality.AUDIO,
515
python/sglang/srt/multimodal/processors/step3_vl.py
Normal file
@@ -0,0 +1,515 @@
import math
import re
from itertools import product
from typing import List, Literal, Optional, TypedDict, Union

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, TensorType

from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens

ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]


class GPUToTensor(torch.nn.Module):

    def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
        if isinstance(raw_image, Image.Image):
            return transforms.ToTensor()(raw_image)
        if raw_image.ndim == 2:
            raw_image = raw_image[:, :, None].repeat(3, -1)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        image_tensor = torch.from_numpy(raw_image).to(device)
        image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
        if image_tensor.dtype == torch.uint8:
            image_tensor = image_tensor.to(torch.float32).div(255)
        return image_tensor


class Step3VisionProcessor:
    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        patch_size = patch_size if patch_size is not None else size

        self.transform = transforms.Compose(
            [
                GPUToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize(
                    (size, size),
                    interpolation=(
                        InterpolationMode.BICUBIC
                        if interpolation_mode == "bicubic"
                        else InterpolationMode.BILINEAR
                    ),
                    antialias=True,
                ),
            ]
        )

        self.patch_transform = (
            transforms.Compose(
                [
                    GPUToTensor(),
                    transforms.Normalize(mean, std),
                    transforms.Resize(
                        (patch_size, patch_size),
                        interpolation=(
                            InterpolationMode.BICUBIC
                            if interpolation_mode == "bicubic"
                            else InterpolationMode.BILINEAR
                        ),
                        antialias=True,
                    ),
                ]
            )
            if patch_size is not None
            else None
        )

    def __call__(self, image, is_patch=False):
        if is_patch:
            return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
        else:
            return {"pixel_values": self.transform(image).unsqueeze(0)}


class ImagePatcher:

    def determine_window_size(self, long: int, short: int) -> int:
        if long <= 728:
            return short if long / short > 1.5 else 0
        return min(short, 504) if long / short > 4 else 504
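
    # Illustrative values (not in the original file): for a 600x500 image the long side
    # is <= 728 and 600/500 <= 1.5, so the window size is 0 and the image is not patched;
    # for a 1600x900 image the ratio is ~1.78 <= 4, so the 504px window is used; for a
    # 3000x600 image the ratio is 5 > 4, so min(600, 504) = 504 is used as well.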

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"
        windows = []
        # Sliding windows.
        for size, step in zip(sizes, steps):
            size_w, size_h = size
            step_w, step_h = step

            x_num = 1 if width <= size_w else math.ceil((width - size_w) / step_w + 1)
            x_start = [step_w * i for i in range(x_num)]
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            y_num = 1 if height <= size_h else math.ceil((height - size_h) / step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            start = np.array(list(product(y_start, x_start)), dtype=int)
            start[:, [0, 1]] = start[:, [1, 0]]
            windows.append(np.concatenate([start, start + size], axis=1))
        windows = np.concatenate(windows, axis=0)

        return [
            (int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1]))
            for box in windows
        ], (x_num, y_num)
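
    # Example (not in the original file): slide_window(1008, 504, [(504, 504)], [(504, 504)])
    # yields two windows, (0, 0, 504, 504) and (504, 0, 504, 504), with (x_num, y_num) = (2, 1),
    # i.e. two crops side by side in a single row.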

    def square_pad(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        if w == h:
            return img
        size = max(w, h)
        padded = Image.new(img.mode, (size, size), 0)
        padded.paste(img, (0, 0))
        return padded

    def get_image_size_for_padding(
        self, img_width: int, img_height: int
    ) -> tuple[int, int]:
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            new_size = max(img_height, img_width)
            return new_size, new_size
        return img_width, img_height

    def get_image_size_for_preprocess(
        self, img_width: int, img_height: int
    ) -> tuple[int, int]:
        if max(img_height, img_width) > 3024:
            scale_factor = 3024 / max(img_height, img_width)
            img_width = int(img_width * scale_factor)
            img_height = int(img_height * scale_factor)
            return img_width, img_height
        else:
            return img_width, img_height

    def get_image_size_for_crop(
        self, img_width: int, img_height: int, window_size: int
    ):
        w_ratio = img_width / window_size
        h_ratio = img_height / window_size

        if w_ratio < 1:
            width_new = img_width
        else:
            decimal_w = w_ratio - img_width // window_size
            w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
            width_new = window_size * w_ratio
        if h_ratio < 1:
            height_new = img_height
        else:
            decimal_h = h_ratio - img_height // window_size
            h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
            height_new = window_size * h_ratio
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        target = img.crop((j, i, j + tw, i + th))
        return target

    def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]:
        img_width, img_height = self.get_image_size_for_padding(img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height
        )
        window_size = self.determine_window_size(
            max(img_height, img_width), min(img_height, img_width)
        )
        if window_size == 0:
            return 0, 0
        else:
            img_width, img_height = self.get_image_size_for_crop(
                img_width, img_height, window_size
            )
            center_list, (x_num, y_num) = self.slide_window(
                img_width,
                img_height,
                [(window_size, window_size)],
                [(window_size, window_size)],
            )
            full_rows = (len(center_list) - 1) // x_num + 1
            if len(center_list) > 0 and len(center_list) % x_num == 0:
                full_rows -= 1
            return len(center_list), full_rows
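
    # Example (not in the original file): a 1008x1008 input keeps its size through
    # padding and preprocessing, gets a 504px window, and is cut into a 2x2 grid, so
    # get_num_patches returns (4, 1): 4 crops and 1 patch newline (a newline after each
    # completed row except the last).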

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        img_width, img_height = img.size
        new_img_width, new_img_height = self.get_image_size_for_padding(
            img_width, img_height
        )
        if new_img_width != img_width or new_img_height != img_height:
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height
        )
        img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width), min(new_img_height, new_img_width)
        )
        if window_size == 0:
            return img, [], None
        else:
            new_img_width, new_img_height = self.get_image_size_for_crop(
                new_img_width, new_img_height, window_size
            )
            if (new_img_width, new_img_height) != (img_width, img_height):
                img_for_crop = img.resize(
                    (new_img_width, new_img_height), Image.Resampling.BILINEAR
                )
            else:
                img_for_crop = img

            patches = []
            newlines = []
            center_list, (x_num, y_num) = self.slide_window(
                new_img_width,
                new_img_height,
                [(window_size, window_size)],
                [(window_size, window_size)],
            )
            for patch_id, center_lf_point in enumerate(center_list):
                x, y, patch_w, patch_h = center_lf_point
                big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
                patches.append(big_patch)
                if (patch_id + 1) % x_num == 0:
                    newlines.append(patch_id)

            if newlines and newlines[-1] == len(patches) - 1:
                newlines.pop()

            return (
                img,
                patches,
                (
                    [i in newlines for i in range(len(patches))]
                    if len(patches) > 0
                    else None
                ),
            )


class Step3VLProcessor:
    def __init__(
        self,
        config,
        tokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = Step3VisionProcessor(
            self.image_size, "bilinear", self.patch_size
        )

        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = self.image_token * self.num_image_feature_size
        self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
        return (
            num_patches * (self.num_patch_feature_size + 2)
            + self.num_image_feature_size
            + 2
            + num_newlines
        )
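
    # Worked example (not in the original file): for the 1008x1008 case above,
    # get_num_patches returns (4, 1), so the prompt reserves
    # 4 * (81 + 2) + 169 + 2 + 1 = 504 tokens: 81 <im_patch> tokens plus
    # <patch_start>/<patch_end> per crop, 169 <im_patch> tokens plus
    # <im_start>/<im_end> for the global view, and one <patch_newline>.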

    def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]:
        result = []
        for img in images:
            result.append(self.patcher(img))
        return result

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        text = ""
        token_ids = []
        for i in range(num_patches):
            assert len(patch_newline_mask) == num_patches
            text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
            token_ids.extend(
                [self.tokenizer.convert_tokens_to_ids("<patch_start>")]
                + [self.image_token_id] * self.num_patch_feature_size
                + [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
            )
            if patch_newline_mask and patch_newline_mask[i]:
                text += "<patch_newline>"
                token_ids.append(
                    self.tokenizer.convert_tokens_to_ids("<patch_newline>")
                )
        return text, token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = (
            [self.tokenizer.convert_tokens_to_ids("<im_start>")]
            + [self.image_token_id] * self.num_image_feature_size
            + [self.tokenizer.convert_tokens_to_ids("<im_end>")]
        )
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx
            )
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."  # noqa: E501
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        *args,
        **kwargs,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for (
                raw_img,
                img_patches,
                patch_newline_mask,
            ) in splitted_images_data:  # noqa: E501
                pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches, is_patch=True)
                    )
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask
                )
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool
                )

            text = [
                self.replace_placeholder(t, self.image_token, image_repl_str_lst)
                for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


################################################


class Step3VLImageProcessor(SGLangBaseProcessor):
    models = [Step3VLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # TODO, check _processor is tokenizer or processor.
        processor = Step3VLProcessor(hf_config, _processor)
        super().__init__(hf_config, server_args, processor, *args, **kwargs)
        self.IM_TOKEN_ID = 128001
        self.mm_tokens = MultimodalSpecialTokens(
            image_token="<im_patch>",
            image_token_id=128001,
            image_token_regex=re.compile(r"(?:<im_patch>)"),
        ).build(_processor)

    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]

    def preprocess(self, image):
        return {"pixel_values": self.transform(image).unsqueeze(0)}

    def __call__(self, image):
        return self.preprocess(image)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text: str | List[int],
        request_obj,
        *args,
        **kwargs,
    ):
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            video_data=request_obj.video_data,
            multimodal_tokens=self.mm_tokens,
        )

        mm_items, input_ids, ret = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.mm_tokens.image_token_id,
        }

@@ -105,7 +105,7 @@ class BaseReasoningFormatDetector:
            # If we're not in a reasoning block, return as normal text.
            if not self._in_reasoning:
                self._buffer = ""
                return StreamingParseResult(normal_text=new_text)
                return StreamingParseResult(normal_text=current_text)

        return StreamingParseResult()

@@ -233,6 +233,7 @@ class ReasoningParser:
        "qwen3-thinking": Qwen3ThinkingDetector,
        "glm45": Qwen3Detector,
        "kimi": KimiDetector,
        "step3": DeepSeekR1Detector,
    }

    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):

@@ -1117,9 +1117,10 @@ class ServerArgs:
                "kimi_k2",
                "qwen3_coder",
                "glm45",
                "step3",
            ],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
        )

        # Data parallelism

@@ -493,5 +493,117 @@ class TestIntegrationScenarios(CustomTestCase):
        self.assertIn("final answer", all_normal)


class TestBufferLossBugFix(CustomTestCase):
    """Test cases for the buffer loss bug fix in parse_streaming_increment."""

    def test_partial_end_tag_buffer_loss_bug(self):
        """
        Test the bug where partial end tag fragments are lost when followed by normal text.

        Bug scenario:
        1. _in_reasoning is False
        2. new_text is "</" (part of closing thinking tag)
        3. Fragment is stored in buffer and empty string is returned
        4. Next step: new_text is "answer", _in_reasoning still False
        5. Buffer is cleared and "answer" is returned directly
        6. The "</" from previous step is lost

        This test verifies the fix where line 108 was changed from:
            return StreamingParseResult(normal_text=new_text)
        to:
            return StreamingParseResult(normal_text=current_text)
        """
        detector = BaseReasoningFormatDetector("<think>", "</think>")

        # Step 1: Send partial end tag when not in reasoning mode.
        # This should be buffered since it could be the start of "</think>".
        result1 = detector.parse_streaming_increment("</")
        self.assertEqual(result1.normal_text, "")
        self.assertEqual(result1.reasoning_text, "")

        # Step 2: Send normal text that doesn't complete the end tag.
        # Before the fix: would return only "answer", losing the "</".
        # After the fix: should return the complete buffered content "</answer".
        result2 = detector.parse_streaming_increment("answer")
        self.assertEqual(result2.normal_text, "</answer")
        self.assertEqual(result2.reasoning_text, "")

    def test_partial_start_tag_buffer_preservation(self):
        """
        Test that partial start tag fragments are properly preserved.
        """
        detector = BaseReasoningFormatDetector("<think>", "</think>")

        # Send partial start tag
        result1 = detector.parse_streaming_increment("<th")
        self.assertEqual(result1.normal_text, "")
        self.assertEqual(result1.reasoning_text, "")

        # Complete with non-matching text
        result2 = detector.parse_streaming_increment("is is text")
        self.assertEqual(result2.normal_text, "<this is text")
        self.assertEqual(result2.reasoning_text, "")

    def test_partial_end_tag_in_reasoning_mode(self):
        """
        Test partial end tag handling when already in reasoning mode.
        """
        detector = BaseReasoningFormatDetector("<think>", "</think>")

        # Enter reasoning mode
        detector.parse_streaming_increment("<think>")
        detector.parse_streaming_increment("some reasoning")

        # Send partial end tag
        result1 = detector.parse_streaming_increment("</")
        self.assertEqual(result1.normal_text, "")
        self.assertEqual(result1.reasoning_text, "")

        # Complete the end tag with normal text
        result2 = detector.parse_streaming_increment("think>normal text")
        self.assertEqual(result2.normal_text, "normal text")
        # The reasoning text should be empty since the buffer was cleared when the end tag was processed.
        self.assertEqual(result2.reasoning_text, "")

    def test_multiple_partial_fragments(self):
        """
        Test handling of multiple partial fragments that don't match any tokens.
        """
        detector = BaseReasoningFormatDetector("<think>", "</think>")

        # Send multiple partial fragments
        result1 = detector.parse_streaming_increment("<")
        self.assertEqual(result1.normal_text, "")
        self.assertEqual(result1.reasoning_text, "")

        result2 = detector.parse_streaming_increment("/")
        self.assertEqual(result2.normal_text, "")
        self.assertEqual(result2.reasoning_text, "")

        result3 = detector.parse_streaming_increment("random>")
        self.assertEqual(result3.normal_text, "</random>")
        self.assertEqual(result3.reasoning_text, "")

    def test_edge_case_exact_token_match(self):
        """
        Test edge case where buffer content exactly matches a token.
        """
        detector = BaseReasoningFormatDetector("<think>", "</think>")

        # Build up the exact start token character by character
        detector.parse_streaming_increment("<")
        detector.parse_streaming_increment("t")
        detector.parse_streaming_increment("h")
        detector.parse_streaming_increment("i")
        detector.parse_streaming_increment("n")
        result = detector.parse_streaming_increment("k>")

        # Should enter reasoning mode
        self.assertEqual(result.normal_text, "")
        self.assertEqual(result.reasoning_text, "")
        self.assertTrue(detector._in_reasoning)
        self.assertTrue(detector.stripped_think_start)


if __name__ == "__main__":
    unittest.main()
Block a user