model: qwen3-omni (thinker-only) (#10911)
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -853,6 +853,7 @@ multimodal_model_archs = [
|
|||||||
"Qwen2_5_VLForConditionalGeneration",
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
"Qwen3VLForConditionalGeneration",
|
"Qwen3VLForConditionalGeneration",
|
||||||
"Qwen3VLMoeForConditionalGeneration",
|
"Qwen3VLMoeForConditionalGeneration",
|
||||||
|
"Qwen3OmniMoeForConditionalGeneration",
|
||||||
"KimiVLForConditionalGeneration",
|
"KimiVLForConditionalGeneration",
|
||||||
"InternVLChatModel",
|
"InternVLChatModel",
|
||||||
"InternS1ForConditionalGeneration",
|
"InternS1ForConditionalGeneration",
|
||||||
|
|||||||
613
python/sglang/srt/configs/qwen3_omni.py
Normal file
613
python/sglang/srt/configs/qwen3_omni.py
Normal file
@@ -0,0 +1,613 @@
|
|||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.configuration_utils import layer_type_validation
|
||||||
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
|
||||||
|
from sglang.utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE audio encoder.

    Each named constructor argument is stored unchanged on the instance
    under the same attribute name. ``encoder_layers`` is additionally
    mirrored as ``num_hidden_layers``. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3_omni_moe_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        d_model=1280,
        dropout=0,
        attention_dropout=0,
        activation_function="gelu",
        activation_dropout=0,
        scale_embedding=False,
        initializer_range=0.02,
        max_source_positions=1500,
        n_window=100,
        output_dim=3584,
        n_window_infer=400,
        conv_chunksize=500,
        downsample_hidden_size=480,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Input feature size and encoder dimensions.
        self.num_mel_bins = num_mel_bins
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim

        # Dropout rates and activation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout

        # Same value as ``encoder_layers``, exposed under a second name.
        self.num_hidden_layers = encoder_layers
        self.initializer_range = initializer_range

        # scale factor will be sqrt(d_model) if True
        self.scale_embedding = scale_embedding

        # Remaining settings, stored verbatim.
        self.max_source_positions = max_source_positions
        self.n_window = n_window
        self.output_dim = output_dim
        self.n_window_infer = n_window_infer
        self.conv_chunksize = conv_chunksize
        self.downsample_hidden_size = downsample_hidden_size
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE vision encoder.

    Each named constructor argument is stored on the instance under the
    same attribute name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        deepstack_visual_indexes: List stored as ``deepstack_visual_indexes``.
            When omitted (``None``) a fresh ``[8, 16, 24]`` list is created
            for this instance, so default-constructed configs never share
            (and cannot accidentally mutate) one list object.
    """

    model_type = "qwen3_omni_moe_vision_encoder"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=None,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        # A list literal in the signature would be created once and shared by
        # every instance; build the default per call instead.
        self.deepstack_visual_indexes = (
            [8, 16, 24]
            if deepstack_visual_indexes is None
            else deepstack_visual_indexes
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeTextConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE text model.

    Named constructor arguments are stored on the instance under the same
    attribute names. ``tie_word_embeddings`` and any extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``. A legacy ``"type"`` key
    in ``rope_scaling`` is copied to ``"rope_type"`` (in place, on the dict
    that was passed in) before the rotary settings are validated.
    ``mlp_only_layers`` defaults to a new empty list when ``None``.
    """

    model_type = "qwen3_omni_moe_text"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3OmniMoeText`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=3584,
        hidden_size=2048,
        intermediate_size=18944,
        num_hidden_layers=28,
        num_attention_heads=28,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        attention_dropout=0,
        decoder_sparse_step=1,
        moe_intermediate_size=768,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Dense transformer settings.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Backward compatibility: older configs spell the rotary kind as
        # 'type'; mirror it to 'rope_type' before validating.
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            scaling["rope_type"] = scaling["type"]
        rope_config_validation(self)

        # Mixture-of-experts settings.
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers if mlp_only_layers is not None else []
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeThinkerConfig(PretrainedConfig):
    """Composite config for the Qwen3-Omni-MoE "thinker".

    Bundles an audio encoder config, a vision encoder config and a text
    config, plus a handful of token ids and scalar settings. Each sub-config
    argument may be ``None`` (the sub-config class is built with its
    defaults), a ``dict`` (expanded as keyword arguments to the sub-config
    class), or any other object, which is stored as given.
    """

    model_type = "qwen3_omni_moe_thinker"
    # Name aliases handed to PretrainedConfig (see its ``attribute_map``).
    attribute_map = {
        "image_token_id": "image_token_index",
        "video_token_id": "video_token_index",
        "audio_token_id": "audio_token_index",
    }
    sub_configs = {
        "audio_config": Qwen3OmniMoeAudioEncoderConfig,
        "vision_config": Qwen3OmniMoeVisionEncoderConfig,
        "text_config": Qwen3OmniMoeTextConfig,
    }

    def __init__(
        self,
        audio_config=None,
        vision_config=None,
        text_config=None,
        audio_token_id=151646,
        image_token_id=151655,
        video_token_id=151656,
        position_id_per_seconds=25,
        audio_start_token_id=151647,
        user_token_id=872,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.user_token_id = user_token_id
        self.position_id_per_seconds = position_id_per_seconds
        self.audio_start_token_id = audio_start_token_id
        self.initializer_range = initializer_range

        def _materialize(raw, config_cls):
            # None -> defaults, dict -> keyword expansion, else pass through.
            if raw is None:
                return config_cls()
            if isinstance(raw, dict):
                return config_cls(**raw)
            return raw

        self.vision_config = _materialize(
            vision_config, Qwen3OmniMoeVisionEncoderConfig
        )
        self.audio_config = _materialize(audio_config, Qwen3OmniMoeAudioEncoderConfig)
        self.text_config = _materialize(text_config, Qwen3OmniMoeTextConfig)

        self.audio_token_id = audio_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE talker code predictor.

    Named constructor arguments are stored on the instance under the same
    attribute names. ``tie_word_embeddings`` and any extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    Notes:
        * ``num_key_value_heads=None`` falls back to ``num_attention_heads``.
        * A legacy ``"type"`` key in ``rope_scaling`` is copied to
          ``"rope_type"`` (in place) before the rotary settings are validated.
        * When ``layer_types`` is ``None`` it is derived: every layer is
          ``"full_attention"`` if ``sliding_window`` is ``None``; otherwise
          layers with index ``>= max_window_layers`` are
          ``"sliding_attention"``. ``max_window_layers`` is not a declared
          parameter of this class; it is only present when supplied through
          ``**kwargs``. If it is absent, 0 is used (every layer slides)
          instead of failing with ``AttributeError``.
    """

    model_type = "qwen3_omni_moe_talker_code_predictor"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=2048,
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        layer_types=None,
        attention_dropout=0,
        num_code_groups=32,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            if self.sliding_window is None:
                self.layer_types = ["full_attention"] * self.num_hidden_layers
            else:
                # `max_window_layers` only exists if it came in via **kwargs;
                # fall back to 0 rather than raising AttributeError.
                first_sliding_layer = getattr(self, "max_window_layers", 0)
                self.layer_types = [
                    (
                        "sliding_attention"
                        if i >= first_sliding_layer
                        else "full_attention"
                    )
                    for i in range(self.num_hidden_layers)
                ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)
        self.num_code_groups = num_code_groups
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE talker text model.

    Named constructor arguments are stored on the instance under the same
    attribute names. ``tie_word_embeddings`` and any extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``. A legacy ``"type"`` key
    in ``rope_scaling`` is copied to ``"rope_type"`` (in place) before the
    rotary settings are validated. ``mlp_only_layers`` defaults to a new
    empty list when ``None``.
    """

    model_type = "qwen3_omni_moe_talker_text"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        attention_dropout=0,
        decoder_sparse_step=1,
        moe_intermediate_size=384,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Dense transformer settings.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Backward compatibility: older configs spell the rotary kind as
        # 'type'; mirror it to 'rope_type' before validating.
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            scaling["rope_type"] = scaling["type"]
        rope_config_validation(self)

        # Mixture-of-experts settings.
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers if mlp_only_layers is not None else []
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeTalkerConfig(PretrainedConfig):
    """Composite config for the Qwen3-Omni-MoE "talker".

    Holds a code-predictor sub-config and a talker text sub-config together
    with codec / multimodal token ids and a few scalar settings. Each
    sub-config argument may be an instance of the expected config class
    (stored as given), ``None`` (built with defaults, and an info message is
    logged), or anything else, which is expanded as keyword arguments to the
    sub-config class.
    """

    sub_configs = {
        "code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig,
        "text_config": Qwen3OmniMoeTalkerTextConfig,
    }

    def __init__(
        self,
        code_predictor_config=None,
        text_config=None,
        num_code_groups=32,
        thinker_hidden_size=2048,
        codec_eos_token_id=4198,
        accept_hidden_layer=18,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        audio_token_id=151646,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        position_id_per_seconds=25,
        audio_start_token_id=151669,
        speaker_id=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Code predictor sub-config: instance, defaults, or keyword mapping.
        if isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        elif code_predictor_config is None:
            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()
            logger.info(
                "code_predictor_config is None. Initializing code_predictor_config model with default values"
            )
        else:
            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(
                **code_predictor_config
            )

        # Talker text sub-config: same three cases.
        if isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):
            self.text_config = text_config
        elif text_config is None:
            self.text_config = Qwen3OmniMoeTalkerTextConfig()
            logger.info(
                "talker text_config is None. Initializing talker text model with default values"
            )
        else:
            self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)

        # Scalar settings and token ids, stored verbatim.
        self.num_code_groups = num_code_groups
        self.thinker_hidden_size = thinker_hidden_size
        self.codec_eos_token_id = codec_eos_token_id
        self.accept_hidden_layer = accept_hidden_layer
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id
        self.audio_token_id = audio_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.position_id_per_seconds = position_id_per_seconds
        self.audio_start_token_id = audio_start_token_id
        self.vision_start_token_id = vision_start_token_id
        self.speaker_id = speaker_id
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeCode2WavConfig(PretrainedConfig):
    """Hyper-parameter container for the Qwen3-Omni-MoE code2wav module.

    Each named constructor argument is stored unchanged on the instance
    under the same attribute name; extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. ``layer_types`` is a read-only property
    derived from ``num_hidden_layers``.
    """

    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.codebook_size = codebook_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window

        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps
        self.num_hidden_layers = num_hidden_layers

        self.num_quantizers = num_quantizers
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim
        self.attention_dropout = attention_dropout

    @property
    def layer_types(self):
        """Return one ``"sliding_attention"`` entry per hidden layer.

        Every layer in code2wav uses sliding attention.
        """
        return ["sliding_attention" for _ in range(self.num_hidden_layers)]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeConfig(PretrainedConfig):
    """Top-level composite config for Qwen3-Omni-MoE.

    Aggregates the thinker, talker and code2wav sub-configs together with a
    set of special token ids and the ``enable_audio_output`` flag.

    Args:
        thinker_config: ``None`` (defaults are used and an info message is
            logged), a mapping of keyword arguments for
            ``Qwen3OmniMoeThinkerConfig``, or an already-built
            ``Qwen3OmniMoeThinkerConfig`` instance (stored as given).
        talker_config: Same three forms, for ``Qwen3OmniMoeTalkerConfig``.
        code2wav_config: Same three forms, for ``Qwen3OmniMoeCode2WavConfig``.
        enable_audio_output, im_start_token_id, im_end_token_id,
        tts_pad_token_id, tts_bos_token_id, tts_eos_token_id,
        system_token_id, user_token_id, assistant_token_id: Stored unchanged
            under the same attribute names.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3_omni_moe"
    sub_configs = {
        "thinker_config": Qwen3OmniMoeThinkerConfig,
        "talker_config": Qwen3OmniMoeTalkerConfig,
        "code2wav_config": Qwen3OmniMoeCode2WavConfig,
    }

    def __init__(
        self,
        thinker_config=None,
        talker_config=None,
        code2wav_config=None,
        enable_audio_output=True,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        system_token_id=8948,
        user_token_id=872,
        assistant_token_id=77091,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if thinker_config is None:
            thinker_config = {}
            logger.info(
                "thinker_config is None. Initializing thinker model with default values"
            )

        if talker_config is None:
            talker_config = {}
            logger.info(
                "talker_config is None. Initializing talker model with default values"
            )

        if code2wav_config is None:
            code2wav_config = {}
            logger.info(
                "code2wav_config is None. Initializing code2wav model with default values"
            )

        def _as_sub_config(value, config_cls):
            # Accept an already-built sub-config as-is; `**value` on a config
            # instance would raise TypeError because it is not a mapping.
            if isinstance(value, config_cls):
                return value
            return config_cls(**value)

        self.thinker_config = _as_sub_config(thinker_config, Qwen3OmniMoeThinkerConfig)
        self.talker_config = _as_sub_config(talker_config, Qwen3OmniMoeTalkerConfig)
        self.code2wav_config = _as_sub_config(
            code2wav_config, Qwen3OmniMoeCode2WavConfig
        )
        self.enable_audio_output = enable_audio_output
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
        self.system_token_id = system_token_id
        self.user_token_id = user_token_id
        self.assistant_token_id = assistant_token_id

    def get_text_config(self, decoder=False) -> "PretrainedConfig":
        """Return the config to use for text IO.

        This composite config is nested, so the lookup is delegated to
        ``self.thinker_config.get_text_config()``.

        Args:
            decoder (`Optional[bool]`, *optional*, defaults to `False`):
                Accepted for signature compatibility with the base method.
                It is not forwarded to the thinker config.
        """
        # Overridden because the text config is not a direct child of this
        # object; it sits under the thinker sub-config.
        return self.thinker_config.get_text_config()
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from transformers.modeling_rope_utils import rope_config_validation
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
|
||||||
@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
|
|||||||
self.vision_start_token_id = vision_start_token_id
|
self.vision_start_token_id = vision_start_token_id
|
||||||
self.vision_end_token_id = vision_end_token_id
|
self.vision_end_token_id = vision_end_token_id
|
||||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Qwen3VLMoeConfig",
|
|
||||||
"Qwen3VLMoeVisionConfig",
|
|
||||||
"Qwen3VLConfig",
|
|
||||||
"Qwen3VLVisionConfig",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if model_type == "qwen3_omni_moe":
|
||||||
|
# For qwen3-omni
|
||||||
|
return MRotaryEmbedding.get_rope_index_qwen3_omni(
|
||||||
|
spatial_merge_size,
|
||||||
|
image_token_id,
|
||||||
|
video_token_id,
|
||||||
|
vision_start_token_id,
|
||||||
|
tokens_per_second,
|
||||||
|
input_ids,
|
||||||
|
image_grid_thw,
|
||||||
|
video_grid_thw,
|
||||||
|
second_per_grid_ts,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
|
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
|
||||||
) and video_grid_thw is not None:
|
) and video_grid_thw is not None:
|
||||||
@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
video_grid_thw, video_grid_thw[:, 0], dim=0
|
video_grid_thw, video_grid_thw[:, 0], dim=0
|
||||||
)
|
)
|
||||||
video_grid_thw[:, 0] = 1
|
video_grid_thw[:, 0] = 1
|
||||||
|
|
||||||
mrope_position_deltas = []
|
mrope_position_deltas = []
|
||||||
if input_ids is not None and (
|
if input_ids is not None and (
|
||||||
image_grid_thw is not None or video_grid_thw is not None
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
time_tensor_long = time_tensor.long()
|
time_tensor_long = time_tensor.long()
|
||||||
t_index = time_tensor_long.flatten()
|
t_index = time_tensor_long.flatten()
|
||||||
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
|
elif model_type in (
|
||||||
|
"qwen2_vl",
|
||||||
|
"qwen3_vl",
|
||||||
|
"qwen3_vl_moe",
|
||||||
|
):
|
||||||
t_index = (
|
t_index = (
|
||||||
torch.arange(llm_grid_t)
|
torch.arange(llm_grid_t)
|
||||||
.view(-1, 1)
|
.view(-1, 1)
|
||||||
@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
.flatten()
|
.flatten()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unimplemented")
|
raise RuntimeError(f"Unimplemented model type: {model_type}")
|
||||||
h_index = (
|
h_index = (
|
||||||
torch.arange(llm_grid_h)
|
torch.arange(llm_grid_h)
|
||||||
.view(1, -1, 1)
|
.view(1, -1, 1)
|
||||||
@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
mrope_position_deltas = max_position_ids + 1 - s
|
mrope_position_deltas = max_position_ids + 1 - s
|
||||||
return position_ids, mrope_position_deltas
|
return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
|
@staticmethod
def get_rope_index_qwen3_omni(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    tokens_per_second: Optional[int] = None,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build 3-axis (t, h, w) rotary position ids for Qwen3-Omni prompts.

    Each row of ``input_ids`` is walked segment by segment: plain text,
    then a begin marker, then the audio / image / video / audio-in-video
    payload, then an end marker. Text, markers and audio get the same
    running index on all three axes; image and video payloads get
    per-axis indices from ``_get_llm_pos_ids_for_vision``.

    Args:
        spatial_merge_size: Divisor applied to the h/w grid sizes (and
            squared, to the number of patches) to get LLM token counts.
        image_token_id: Placeholder token id for image patches.
        video_token_id: Placeholder token id for video patches.
        vision_start_token_id: Token id that opens a vision segment.
        tokens_per_second: Accepted for signature parity; not read here.
        input_ids: 2-D token ids, indexed as ``shape[0]`` x ``shape[1]``.
            Despite the ``Optional`` annotation, the fallback path reads
            ``input_ids.shape``, so ``None`` is not supported.
        image_grid_thw: Per-image ``(t, h, w)`` grid sizes.
        video_grid_thw: Per-video ``(t, h, w)`` grid sizes.
        second_per_grid_ts: Per-video seconds per temporal grid step.
        **kwargs: Must contain ``audio_token_id``, ``audio_start_token_id``
            and ``position_id_per_seconds``; may contain
            ``use_audio_in_video`` (default ``False``) and
            ``audio_seqlens`` (needed as soon as an audio segment is met).

    Returns:
        ``(position_ids, mrope_position_deltas)``. On the multimodal path
        ``position_ids`` is a float tensor of shape
        ``(3, input_ids.shape[0], input_ids.shape[1])`` and the deltas hold
        ``max_position + 1 - sequence_length`` per row.

    Raises:
        KeyError: A required entry is missing from ``kwargs``.
        ValueError: An audio segment is met without ``audio_seqlens``, or a
            video segment without ``second_per_grid_ts``.
    """
    # For qwen3-omni
    audio_token_id = kwargs["audio_token_id"]
    audio_start_token_id = kwargs["audio_start_token_id"]
    position_id_per_seconds = kwargs["position_id_per_seconds"]
    use_audio_in_video = kwargs.get("use_audio_in_video", False)
    audio_seqlens = kwargs.get("audio_seqlens", None)
    second_per_grids = second_per_grid_ts

    def _audio_token_len(idx):
        # Number of LLM tokens produced by the idx-th audio clip.
        if audio_seqlens is None:
            raise ValueError(
                "audio_seqlens must be provided when the prompt contains audio"
            )
        return MRotaryEmbedding._get_feat_extract_output_lengths(
            audio_seqlens[idx]
        )

    def _video_t_index(idx):
        # Temporal position of every temporal grid step of the idx-th video.
        grid_t = video_grid_thw[idx][0]
        if second_per_grids is None:
            raise ValueError(
                "second_per_grid_ts must be provided when the prompt contains video"
            )
        return (
            torch.arange(grid_t)
            * second_per_grids[idx].cpu().float()
            * position_id_per_seconds
        ).float()

    mrope_position_deltas = []
    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        total_input_ids = input_ids
        position_ids = torch.zeros(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=torch.float,
            device=input_ids.device,
        )
        # These counters run across rows: grids and audio lengths are
        # consumed in order over the whole batch.
        image_idx, video_idx, audio_idx = 0, 0, 0
        for i, current_input_ids in enumerate(total_input_ids):
            image_nums, video_nums, audio_nums = 0, 0, 0
            vision_start_indices = torch.argwhere(
                current_input_ids == vision_start_token_id
            ).squeeze(1)
            if vision_start_indices.numel() > 0:
                # The token right after a vision-start marker tells which
                # kind of segment follows.
                vision_tokens = current_input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (
                    (vision_tokens == audio_start_token_id).sum()
                    if use_audio_in_video
                    else (vision_tokens == video_token_id).sum()
                )
            audio_nums = torch.sum(current_input_ids == audio_start_token_id)
            input_tokens = current_input_ids.tolist()

            # Loop-invariant membership scans, done once per row instead of
            # once per multimodal item.
            has_vision_placeholder = (
                image_token_id in input_tokens or video_token_id in input_tokens
            )
            has_audio_placeholder = audio_token_id in input_tokens

            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos, remain_audios = (
                image_nums,
                video_nums,
                audio_nums,
            )
            multimodal_nums = (
                image_nums + audio_nums
                if use_audio_in_video
                else image_nums + video_nums + audio_nums
            )
            for _ in range(multimodal_nums):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                # len(input_tokens) + 1 acts as a "not present" sentinel.
                ed_vision_start = (
                    input_tokens.index(vision_start_token_id, st)
                    if (
                        has_vision_placeholder
                        and (remain_videos > 0 or remain_images > 0)
                    )
                    else len(input_tokens) + 1
                )
                ed_audio_start = (
                    input_tokens.index(audio_start_token_id, st)
                    if (has_audio_placeholder and remain_audios > 0)
                    else len(input_tokens) + 1
                )
                min_ed = min(ed_vision_start, ed_audio_start)

                # Text preceding the next multimodal segment.
                text_len = min_ed - st
                if text_len != 0:
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    st_idx += text_len

                # Audio in Video: vision-start is immediately followed by
                # audio-start, so there are two begin and two end markers.
                if (
                    min_ed == ed_vision_start
                    and ed_vision_start + 1 == ed_audio_start
                ):
                    bos_len, eos_len = 2, 2
                else:
                    bos_len, eos_len = 1, 1
                llm_pos_ids_list.append(
                    torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
                )
                st_idx += bos_len

                # Audio Only
                if min_ed == ed_audio_start:
                    audio_len = _audio_token_len(audio_idx)
                    llm_pos_ids = (
                        torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + audio_len + eos_len)
                    audio_idx += 1
                    remain_audios -= 1

                # Image Only
                elif (
                    min_ed == ed_vision_start
                    and current_input_ids[ed_vision_start + 1] == image_token_id
                ):
                    grid_t = image_grid_thw[image_idx][0]
                    grid_hs = image_grid_thw[:, 1]
                    grid_ws = image_grid_thw[:, 2]
                    t_index = (
                        torch.arange(grid_t) * 1 * position_id_per_seconds
                    ).float()
                    llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                        st_idx,
                        image_idx,
                        spatial_merge_size,
                        t_index,
                        grid_hs,
                        grid_ws,
                        input_ids.device,
                    )
                    image_len = image_grid_thw[image_idx].prod() // (
                        spatial_merge_size**2
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + image_len + eos_len)
                    image_idx += 1
                    remain_images -= 1

                # Video Only
                elif (
                    min_ed == ed_vision_start
                    and current_input_ids[ed_vision_start + 1] == video_token_id
                ):
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]
                    t_index = _video_t_index(video_idx)
                    llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                        st_idx,
                        video_idx,
                        spatial_merge_size,
                        t_index,
                        grid_hs,
                        grid_ws,
                        input_ids.device,
                    )
                    video_len = video_grid_thw[video_idx].prod() // (
                        spatial_merge_size**2
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + video_len + eos_len)
                    video_idx += 1
                    remain_videos -= 1

                # Audio in Video
                elif (
                    min_ed == ed_vision_start
                    and ed_vision_start + 1 == ed_audio_start
                ):
                    audio_len = _audio_token_len(audio_idx)
                    audio_llm_pos_ids = (
                        torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]

                    t_index = _video_t_index(video_idx)
                    video_llm_pos_ids = (
                        MRotaryEmbedding._get_llm_pos_ids_for_vision(
                            st_idx,
                            video_idx,
                            spatial_merge_size,
                            t_index,
                            grid_hs,
                            grid_ws,
                            input_ids.device,
                        )
                    )
                    # Merge the two streams by temporal position; on a tie
                    # the video token is emitted first.
                    video_data_index, audio_data_index = 0, 0
                    while (
                        video_data_index < video_llm_pos_ids.shape[-1]
                        and audio_data_index < audio_llm_pos_ids.shape[-1]
                    ):
                        if (
                            video_llm_pos_ids[0][video_data_index]
                            <= audio_llm_pos_ids[0][audio_data_index]
                        ):
                            llm_pos_ids_list.append(
                                video_llm_pos_ids[
                                    :, video_data_index : video_data_index + 1
                                ]
                            )
                            video_data_index += 1
                        else:
                            llm_pos_ids_list.append(
                                audio_llm_pos_ids[
                                    :, audio_data_index : audio_data_index + 1
                                ]
                            )
                            audio_data_index += 1
                    # Flush whichever stream still has tokens left.
                    if video_data_index < video_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            video_llm_pos_ids[
                                :, video_data_index : video_llm_pos_ids.shape[-1]
                            ]
                        )
                    if audio_data_index < audio_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            audio_llm_pos_ids[
                                :, audio_data_index : audio_llm_pos_ids.shape[-1]
                            ]
                        )
                    video_len = video_grid_thw[video_idx].prod() // (
                        spatial_merge_size**2
                    )

                    st += int(text_len + bos_len + audio_len + video_len + eos_len)

                    audio_idx += 1
                    video_idx += 1
                    remain_videos -= 1
                    remain_audios -= 1

                # End marker(s) continue from the largest position so far.
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
                )

            # Trailing text after the last multimodal segment.
            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

            llm_positions = torch.cat(
                [item.float() for item in llm_pos_ids_list], dim=1
            ).reshape(3, -1)

            position_ids[..., i, :] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(current_input_ids)
            )
        mrope_position_deltas = torch.tensor(
            mrope_position_deltas, device=input_ids.device
        ).unsqueeze(1)

        return position_ids, mrope_position_deltas
    else:
        # No image/video grids: plain sequential positions on all three axes.
        s = input_ids.shape[1]
        position_ids = torch.arange(s)
        position_ids = (
            position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
        )
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(
            -1, keepdim=True
        )[0]
        mrope_position_deltas = max_position_ids + 1 - s

        return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
|
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_rope_index_glm4v(
|
def get_rope_index_glm4v(
|
||||||
@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
return position_ids, mrope_position_deltas
|
return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
|
# For qwen3-omni
|
||||||
|
@staticmethod
|
||||||
|
def _get_feat_extract_output_lengths(input_lengths):
|
||||||
|
"""
|
||||||
|
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||||
|
"""
|
||||||
|
input_lengths_leave = input_lengths % 100
|
||||||
|
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||||
|
output_lengths = (
|
||||||
|
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
|
)
|
||||||
|
return output_lengths
|
||||||
|
|
||||||
|
# For qwen3-omni
|
||||||
|
@staticmethod
|
||||||
|
def _get_llm_pos_ids_for_vision(
|
||||||
|
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
|
||||||
|
):
|
||||||
|
grid_h = grid_hs[vision_idx] // spatial_merge_size
|
||||||
|
grid_w = grid_ws[vision_idx] // spatial_merge_size
|
||||||
|
|
||||||
|
h_index = (
|
||||||
|
torch.arange(grid_h, device=device)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(len(t_index), -1, grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(grid_w, device=device)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(len(t_index), grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
|
||||||
|
|
||||||
|
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
|
||||||
|
return llm_pos_ids
|
||||||
|
|
||||||
|
|
||||||
class DualChunkRotaryEmbedding(CustomOp):
|
class DualChunkRotaryEmbedding(CustomOp):
|
||||||
"""Rotary positional embedding for Dual Chunk Attention."""
|
"""Rotary positional embedding for Dual Chunk Attention."""
|
||||||
|
|||||||
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
|||||||
input_ids_tensor[input_ids_tensor == token_id] = pad_value
|
input_ids_tensor[input_ids_tensor == token_id] = pad_value
|
||||||
|
|
||||||
ret_input_ids = input_ids_tensor.tolist()
|
ret_input_ids = input_ids_tensor.tolist()
|
||||||
|
|
||||||
return ret_input_ids
|
return ret_input_ids
|
||||||
|
|
||||||
|
|
||||||
@@ -507,7 +506,7 @@ def embed_mm_inputs(
|
|||||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||||
use_deepstack: bool = False,
|
use_deepstack: Dict[Modality, bool] = {},
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Embed multimodal inputs and integrate them with text token embeddings.
|
Embed multimodal inputs and integrate them with text token embeddings.
|
||||||
@@ -533,7 +532,9 @@ def embed_mm_inputs(
|
|||||||
for mm_inputs in mm_inputs_list:
|
for mm_inputs in mm_inputs_list:
|
||||||
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
||||||
|
|
||||||
embeddings, masks, deepstack_embeddings = [], [], []
|
# deepstack_embeddings: per-modality
|
||||||
|
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
|
||||||
|
|
||||||
# 2. Get multimodal embedding separately
|
# 2. Get multimodal embedding separately
|
||||||
# Try get mm embedding if any
|
# Try get mm embedding if any
|
||||||
for modality in Modality.all():
|
for modality in Modality.all():
|
||||||
@@ -549,7 +550,8 @@ def embed_mm_inputs(
|
|||||||
# "image", "video", etc
|
# "image", "video", etc
|
||||||
modality_id = modality.name.lower()
|
modality_id = modality.name.lower()
|
||||||
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
|
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
|
||||||
if len(items) != 0 and embedder is not None:
|
if len(items) != 0:
|
||||||
|
assert embedder is not None, f"no embedding method found for {modality}"
|
||||||
placeholder_tensor = torch.as_tensor(
|
placeholder_tensor = torch.as_tensor(
|
||||||
[item.pad_value for item in items],
|
[item.pad_value for item in items],
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
@@ -580,11 +582,12 @@ def embed_mm_inputs(
|
|||||||
items_offset_list=items_offsets,
|
items_offset_list=items_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_deepstack and embedding is not None:
|
if use_deepstack.get(modality, None) and embedding is not None:
|
||||||
embedding, deepstack_embedding = (
|
embedding, deepstack_embedding = (
|
||||||
multimodal_model.separate_deepstack_embeds(embedding)
|
multimodal_model.separate_deepstack_embeds(embedding)
|
||||||
)
|
)
|
||||||
deepstack_embeddings += [deepstack_embedding]
|
deepstack_embeddings += [deepstack_embedding]
|
||||||
|
modalities += [modality]
|
||||||
embeddings += [embedding]
|
embeddings += [embedding]
|
||||||
masks += [mask]
|
masks += [mask]
|
||||||
|
|
||||||
@@ -597,17 +600,14 @@ def embed_mm_inputs(
|
|||||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||||
inputs_embeds = input_embedding(input_ids)
|
inputs_embeds = input_embedding(input_ids)
|
||||||
|
|
||||||
# 4. scatter embeddings into input embedding
|
|
||||||
|
|
||||||
# deepstack embedding
|
# deepstack embedding
|
||||||
if use_deepstack:
|
if use_deepstack:
|
||||||
num_deepstack_embeddings = (
|
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
|
||||||
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
|
|
||||||
)
|
|
||||||
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
|
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
|
||||||
inputs_embeds.shape[-1] * num_deepstack_embeddings,
|
inputs_embeds.shape[-1] * num_deepstack_embeddings,
|
||||||
)
|
)
|
||||||
|
# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
|
||||||
input_deepstack_embeds = torch.zeros(
|
input_deepstack_embeds = torch.zeros(
|
||||||
deepstack_embedding_shape,
|
deepstack_embedding_shape,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
@@ -616,14 +616,16 @@ def embed_mm_inputs(
|
|||||||
|
|
||||||
other_info["input_deepstack_embeds"] = input_deepstack_embeds
|
other_info["input_deepstack_embeds"] = input_deepstack_embeds
|
||||||
|
|
||||||
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
|
# 4. scatter embeddings into input embedding
|
||||||
|
for i, modality, embedding, mask in zip(
|
||||||
|
range(len(embeddings)), modalities, embeddings, masks
|
||||||
|
):
|
||||||
if embedding is None or mask is None:
|
if embedding is None or mask is None:
|
||||||
continue
|
continue
|
||||||
# in-place update
|
# in-place update
|
||||||
indices = torch.where(mask.squeeze(dim=-1))[0]
|
indices = torch.where(mask.squeeze(dim=-1))[0]
|
||||||
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
if use_deepstack.get(modality, None):
|
||||||
if use_deepstack:
|
|
||||||
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
|
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
inputs_embeds.device, inputs_embeds.dtype
|
||||||
)
|
)
|
||||||
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
|
|||||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||||
use_deepstack: bool = False,
|
use_deepstack: Dict[Modality, bool] = {},
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
|
|||||||
language_model: Base language model to use
|
language_model: Base language model to use
|
||||||
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
||||||
placeholder_tokens: Token IDs for multimodal placeholders
|
placeholder_tokens: Token IDs for multimodal placeholders
|
||||||
use_deepstack: Whether to use deepstack embeddings
|
use_deepstack: Whether to use deepstack embeddings for each modality, default False
|
||||||
**kwargs: Additional arguments passed to language model
|
**kwargs: Additional arguments passed to language model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.mm_processor and obj.contains_mm_input():
|
if self.mm_processor and obj.contains_mm_input():
|
||||||
if not isinstance(obj.image_data, list) and obj.image_data:
|
if obj.image_data is not None and not isinstance(obj.image_data, list):
|
||||||
obj.image_data = [obj.image_data]
|
obj.image_data = [obj.image_data]
|
||||||
if not isinstance(obj.audio_data, list) and obj.audio_data:
|
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
|
||||||
obj.audio_data = [obj.audio_data]
|
obj.audio_data = [obj.audio_data]
|
||||||
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||||
image_data=obj.image_data,
|
image_data=obj.image_data,
|
||||||
|
|||||||
@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.pp_group = get_pp_group()
|
self.pp_group = get_pp_group()
|
||||||
|
|||||||
@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
|||||||
config: Qwen3MoeConfig,
|
config: Qwen3MoeConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
decoder_layer_type=Qwen3MoeDecoderLayer,
|
||||||
) -> None:
|
) -> None:
|
||||||
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
decoder_layer_type=Qwen3MoeDecoderLayer,
|
decoder_layer_type=decoder_layer_type,
|
||||||
alt_stream=alt_stream,
|
alt_stream=alt_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
661
python/sglang/srt/models/qwen3_omni_moe.py
Normal file
661
python/sglang/srt/models/qwen3_omni_moe.py
Normal file
@@ -0,0 +1,661 @@
|
|||||||
|
# Copyright 2025 Qwen Team
|
||||||
|
# Copyright 2025 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||||
|
import math
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.modeling_outputs import BaseModelOutput
|
||||||
|
|
||||||
|
from sglang.srt.configs.qwen3_omni import (
|
||||||
|
Qwen3OmniMoeAudioEncoderConfig,
|
||||||
|
Qwen3OmniMoeThinkerConfig,
|
||||||
|
Qwen3OmniMoeVisionEncoderConfig,
|
||||||
|
)
|
||||||
|
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
|
||||||
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
|
||||||
|
from sglang.srt.models.qwen3_vl_moe import (
|
||||||
|
Qwen3MoeLLMModel,
|
||||||
|
Qwen3VLMoeForConditionalGeneration,
|
||||||
|
load_fused_expert_weights,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import add_prefix, logger
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
    """One audio-encoder block: layer-normed self-attention followed by a
    layer-normed two-layer feed-forward network, each with a residual add."""

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.d_model
        ffn_size = config.encoder_ffn_dim

        self.embed_dim = hidden_size
        self.self_attn = VisionAttention(
            embed_dim=hidden_size,
            num_heads=config.encoder_attention_heads,
            projection_size=hidden_size,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend="fa3",
            softmax_in_single_precision=False,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)
        # The dropout rates are stored but not applied in forward().
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(hidden_size, ffn_size)
        self.fc2 = nn.Linear(ffn_size, hidden_size)
        self.final_layer_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): input activations; the last
                dimension is `embed_dim`.
            cu_seqlens (`torch.Tensor`): forwarded unchanged to the attention
                module (presumably cumulative sequence boundaries -- confirm
                against `VisionAttention`).
            **kwargs: accepted and ignored.

        Returns:
            A one-element tuple holding the updated hidden states.
        """
        # Self-attention sub-block (norm first, then residual add).
        attn_input = self.self_attn_layer_norm(hidden_states)
        attn_output = self.self_attn(
            x=attn_input,
            cu_seqlens=cu_seqlens,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block (norm first, then residual add).
        ffn_hidden = self.fc1(self.final_layer_norm(hidden_states))
        ffn_output = self.fc2(self.activation_fn(ffn_hidden))
        hidden_states = hidden_states + ffn_output

        # Keep fp16 activations away from the representable maximum.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        return (hidden_states,)
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed sinusoidal position table: ``[sin | cos]`` halves along the
    channel axis, stored as a non-persistent buffer."""

    def __init__(self, length, channels, max_timescale=10000):
        """
        Args:
            length: Number of positions (rows) in the table.
            channels: Embedding width; must be even.
            max_timescale: Largest timescale of the geometric progression.

        Raises:
            ValueError: If ``channels`` is odd.
        """
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        half_channels = channels // 2
        # With channels == 2 the original denominator is 0, which turns the
        # whole table into NaN; clamp to 1 so the single frequency is 1.0.
        # For channels >= 4 the value is unchanged.
        log_timescale_increment = np.log(max_timescale) / max(half_channels - 1, 1)
        inv_timescales = torch.exp(
            -log_timescale_increment * torch.arange(half_channels).float()
        )
        # Outer product: (length, 1) * (1, half_channels).
        scaled_time = (
            torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        )
        self.register_buffer(
            "positional_embedding",
            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
            persistent=False,
        )

    def forward(self, seqlen: int):
        """Return the first ``seqlen`` rows of the table."""
        return self.positional_embedding[:seqlen, :]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_feat_extract_output_lengths(input_lengths):
|
||||||
|
"""
|
||||||
|
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_lengths_leave = input_lengths % 100
|
||||||
|
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||||
|
output_lengths = (
|
||||||
|
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
|
)
|
||||||
|
return output_lengths
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
|
||||||
|
config: Qwen3OmniMoeAudioEncoderConfig
|
||||||
|
|
||||||
|
def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
    """Create the conv front-end, the encoder layers and the output projection."""
    super().__init__(config)
    hidden_size = config.d_model
    conv_channels = config.downsample_hidden_size

    self.dropout = config.dropout
    self.num_mel_bins = config.num_mel_bins
    self.max_source_positions = config.max_source_positions
    if config.scale_embedding:
        self.embed_scale = math.sqrt(hidden_size)
    else:
        self.embed_scale = 1.0
    self.n_window = config.n_window

    self.positional_embedding = SinusoidsPositionEmbedding(
        self.max_source_positions, hidden_size
    )
    encoder_layers = [
        Qwen3OmniMoeAudioEncoderLayer(config) for _ in range(config.encoder_layers)
    ]
    self.layers = nn.ModuleList(encoder_layers)
    self.ln_post = nn.LayerNorm(hidden_size)
    self.gradient_checkpointing = False

    # Three 3x3, stride-2 convolutions.
    self.conv2d1 = nn.Conv2d(1, conv_channels, 3, 2, padding=1)
    self.conv2d2 = nn.Conv2d(conv_channels, conv_channels, 3, 2, padding=1)
    self.conv2d3 = nn.Conv2d(conv_channels, conv_channels, 3, 2, padding=1)

    # Mel-bin axis length after the three stride-2 convolutions.
    mel_bins_after_conv = config.num_mel_bins
    for _ in range(3):
        mel_bins_after_conv = (mel_bins_after_conv + 1) // 2
    self.conv_out = nn.Linear(
        conv_channels * mel_bins_after_conv,
        hidden_size,
        bias=False,
    )

    self.proj1 = nn.Linear(hidden_size, hidden_size)
    self.act = ACT2FN[config.activation_function]
    self.proj2 = nn.Linear(hidden_size, config.output_dim)
    self.n_window_infer = self.config.n_window_infer
    self.conv_chunksize = self.config.conv_chunksize
|
||||||
|
|
||||||
|
def _freeze_parameters(self):
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self._requires_grad = False
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
    """Return the first convolution of the front-end.

    The constructor registers this layer as ``conv2d1``; there is no
    ``conv1`` attribute, so reading ``self.conv1`` raised AttributeError.
    """
    return self.conv2d1

def set_input_embeddings(self, value: nn.Module):
    """Replace the first convolution of the front-end (``conv2d1``)."""
    self.conv2d1 = value
|
||||||
|
|
||||||
|
def forward(
    self,
    input_features,
    feature_lens=None,
    aftercnn_lens=None,
):
    r"""
    Encode packed audio features into a flat sequence of projected embeddings.

    Args:
        input_features: audio features of all samples packed together.
            NOTE(review): the code transposes this tensor and splits it along
            the frame axis, so it presumably has shape
            `(num_mel_bins, total_frames)` — confirm against the caller.
        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length of each sample. Despite the `None` default it is used
            unconditionally, so callers must provide it.
        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length after cnn. The value passed in is ignored; it is
            recomputed from `feature_lens` on the first line of the body.

    Returns:
        `BaseModelOutput` whose `last_hidden_state` contains the hidden states
        of the valid (non-padded) positions after the final projection.
    """
    # The caller-supplied aftercnn_lens is overwritten here.
    aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
    # Number of chunks of `n_window * 2` frames needed for each sample.
    chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

    # Start with every chunk at the full length ...
    chunk_lengths = torch.tensor(
        [self.n_window * 2] * chunk_num.sum(),
        dtype=torch.long,
        device=feature_lens.device,
    )
    # ... then locate the last chunk of each sample (cumulative count - 1) ...
    tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
    # ... and shrink it to the remainder of that sample.
    chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
    # A remainder of 0 means the last chunk is actually a full chunk.
    chunk_lengths[chunk_lengths == 0] = self.n_window * 2

    # Cut the packed frames into chunks and pad them to a common length.
    chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
    padded_feature = nn.utils.rnn.pad_sequence(
        chunk_list, batch_first=True
    ).transpose(1, 2)
    # Boolean mask marking the valid positions of each chunk after the CNN.
    feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
    padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
        [
            torch.ones(length, dtype=torch.bool, device=padded_feature.device)
            for length in feature_lens_after_cnn
        ],
        batch_first=True,
    )
    # Add the single input channel expected by the 2D convolutions.
    padded_feature = padded_feature.unsqueeze(1)
    # Split to chunk to avoid OOM during convolution
    padded_embeds = []
    for chunk in padded_feature.split(self.conv_chunksize, dim=0):
        padded_embed = F.gelu(self.conv2d1(chunk))
        padded_embed = F.gelu(self.conv2d2(padded_embed))
        padded_embed = F.gelu(self.conv2d3(padded_embed))
        padded_embeds.append(padded_embed)
    padded_embed = torch.cat(padded_embeds, dim=0)
    # Fold channel and frequency axes together and project each time step.
    b, c, f, t = padded_embed.size()
    padded_embed = self.conv_out(
        padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
    )

    # Positional embeddings are applied per chunk (positions restart at 0).
    positional_embedding = (
        self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
        .unsqueeze(0)
        .to(padded_embed.dtype)
    )
    padded_embed = padded_embed + positional_embedding
    # Drop the padded positions; the result is a flat sequence of valid steps.
    hidden_states = padded_embed[padded_mask_after_cnn]
    # Build window lengths: post-CNN length of a padded chunk times the number
    # of chunks that fit into one inference window.
    cu_chunk_lens = [0]
    window_aftercnn = padded_mask_after_cnn.shape[-1] * (
        self.n_window_infer // (self.n_window * 2)
    )
    for cnn_len in aftercnn_lens:
        cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
        remainder = cnn_len % window_aftercnn
        if remainder != 0:
            cu_chunk_lens += [remainder]
    # Cumulative window boundaries, starting at 0.
    # NOTE(review): presumably consumed by the encoder layers as attention
    # segment boundaries — confirm in Qwen3OmniMoeAudioEncoderLayer.
    cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
        -1, dtype=torch.int32
    )

    for encoder_layer in self.layers:
        layer_outputs = encoder_layer(
            hidden_states,
            cu_seqlens,
        )

        # Each layer returns a sequence; its first element is the hidden states.
        hidden_states = layer_outputs[0]

    # Final norm followed by a two-layer projection to `output_dim`.
    hidden_states = self.ln_post(hidden_states)
    hidden_states = self.proj1(hidden_states)
    hidden_states = self.act(hidden_states)
    hidden_states = self.proj2(hidden_states)
    return BaseModelOutput(last_hidden_state=hidden_states)
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||||
|
"""
|
||||||
|
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||||
|
"""
|
||||||
|
input_lengths = (input_lengths - 1) // 2 + 1
|
||||||
|
output_lengths = (input_lengths - 2) // 2 + 1
|
||||||
|
return input_lengths, output_lengths
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeVisionPatchMerger(nn.Module):
    """Merge groups of vision patches and project them to ``dim`` features.

    Groups of ``spatial_merge_size ** 2`` patch vectors of width
    ``context_dim`` are treated as one vector of width ``hidden_size`` and
    passed through a linear -> GELU -> linear stack. The RMS norm is applied
    either on the merged vector (``use_postshuffle_norm=True``) or on the
    individual patch vectors before merging.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_postshuffle_norm=False,
    ) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.use_postshuffle_norm = use_postshuffle_norm

        # Norm width depends on whether normalisation happens after merging.
        norm_width = merged_width if use_postshuffle_norm else context_dim
        self.ln_q = RMSNorm(norm_width, eps=1e-6)

        expand = ColumnParallelLinear(
            merged_width,
            merged_width,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.0", prefix),
        )
        activation = nn.GELU()
        project = RowParallelLinear(
            merged_width,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.2", prefix),
        )
        self.mlp = nn.ModuleList([expand, activation, project])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, merge and project ``x``; returns a 2-D tensor."""
        if self.use_postshuffle_norm:
            flat = x.view(-1, self.hidden_size)
        else:
            flat = x.view(-1, x.shape[-1])

        out = self.ln_q(flat).view(-1, self.hidden_size)
        for stage in self.mlp:
            # A stage may hand back a tuple; only its first element is fed on.
            out = stage(out[0] if isinstance(out, tuple) else out)

        return out[0] if isinstance(out, tuple) else out
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
    """Vision tower of Qwen3-Omni built on top of the Qwen3-VL vision model.

    The parent's mergers are replaced by ``Qwen3OmniMoeVisionPatchMerger``
    instances stored under ``merger`` / ``merger_list``. The parent's
    ``deepstack_merger_list`` module is removed and re-exposed as a read-only
    property that aliases ``merger_list``.
    """

    config: Qwen3OmniMoeVisionEncoderConfig

    def __init__(
        self,
        config: Qwen3OmniMoeVisionEncoderConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            config: vision encoder configuration.
            quant_config: optional quantization configuration forwarded to
                the parent and to the mergers.
            prefix: optional name prefix used for the merger sub-modules.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        super().__init__(
            vision_config=config,
            quant_config=quant_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
        )

        self.merger = Qwen3OmniMoeVisionPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
            quant_config=quant_config,
            use_postshuffle_norm=False,
            prefix=add_prefix("merger", prefix),
        )
        self.merger_list = nn.ModuleList(
            [
                Qwen3OmniMoeVisionPatchMerger(
                    dim=config.out_hidden_size,
                    context_dim=config.hidden_size,
                    spatial_merge_size=config.spatial_merge_size,
                    use_postshuffle_norm=True,
                    quant_config=quant_config,
                    # Index the prefix so that each merger's prefix matches
                    # its position in the ModuleList ("merger_list.<i>")
                    # instead of all mergers sharing one identical prefix.
                    prefix=add_prefix(f"merger_list.{layer_idx}", prefix),
                )
                for layer_idx in range(len(config.deepstack_visual_indexes))
            ]
        )
        # Drop the module registered by the parent constructor; lookups of
        # `deepstack_merger_list` are served by the property below.
        del self.deepstack_merger_list

    @property
    def deepstack_merger_list(self):
        """Alias for ``merger_list``."""
        return self.merger_list

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
    """"Thinker" part of Qwen3-Omni: language model plus audio and vision towers."""

    config: Qwen3OmniMoeThinkerConfig

    def __init__(
        self,
        config: Qwen3OmniMoeThinkerConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: thinker configuration; ``audio_config`` and
                ``vision_config`` are read from it.
            quant_config: optional quantization configuration.
            prefix: name prefix for sub-modules.
        """
        super().__init__(
            config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
        )
        self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
        # Install the Omni-specific vision encoder under `visual`.
        # NOTE(review): presumably this replaces a vision tower already
        # created by the parent constructor — confirm.
        self.visual = Qwen3OmniMoeVisionEncoder(
            config.vision_config,
            quant_config=quant_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            prefix=add_prefix("visual", prefix),
        )
        # Fall back to -1 when the configuration defines no padding token.
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )

    def get_audio_feature(self, items: List[MultimodalDataItem]):
        """Run the audio tower on the audio items and return its embeddings.

        Args:
            items: multimodal items providing ``feature`` and
                ``feature_attention_mask`` tensors.

        Returns:
            ``last_hidden_state`` of the audio tower output.
        """
        # torch.cat always produces a tensor, so the mask can never be None
        # here; the former `None` fallbacks were unreachable and were removed.
        feature_attention_mask = torch.cat(
            [item.feature_attention_mask for item in items], dim=0
        ).type(torch.long)
        input_features = (
            torch.cat([item.feature for item in items])
            .type(self.audio_tower.dtype)
            .to(next(self.audio_tower.parameters()).device)
        )

        # Per-sample count of valid frames.
        feature_lens = torch.sum(feature_attention_mask, dim=1)
        # Keep only the valid frames of every sample and pack them into one
        # 2-D tensor whose last axis runs over all kept frames.
        input_features = input_features.permute(0, 2, 1)[
            feature_attention_mask.bool()
        ].permute(1, 0)

        audio_outputs = self.audio_tower(
            input_features,
            feature_lens=feature_lens,
        )
        return audio_outputs.last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
    """Top-level Qwen3-Omni MoE model; only the "thinker" sub-model is built.

    `enable_talker` is hard-coded to False, so checkpoint weights whose name
    contains "talker" or "code2wav" are skipped by `load_weights`.
    `pad_input_ids` and `forward` are delegated to the thinker.
    """

    def __init__(
        self,
        config: Qwen3VLMoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # NOTE(review): `config` is annotated as Qwen3VLMoeConfig, but the
        # code reads `config.thinker_config`; the annotation presumably
        # should name the Omni config class — confirm.
        super().__init__(config)
        self.config = config

        self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
            config.thinker_config, quant_config=quant_config, prefix=prefix
        )
        self.enable_talker = False
        # Expose the thinker's entry points directly on this wrapper.
        self.pad_input_ids = self.thinker.pad_input_ids
        self.forward = self.thinker.forward

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is routed, in order, through:
        1. the stacked-parameter mapping (q/k/v -> qkv_proj, gate/up ->
           gate_up_proj) for non-visual, non-expert weights;
        2. the expert mapping (per-expert or fused expert tensors);
        3. a direct lookup by (possibly renamed) parameter name.
        Names that match no parameter are skipped or logged as a warning.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        # Becomes True (and stays True for all following weights) once a
        # fused expert tensor name is seen; the expert mapping is then
        # switched to the fused variant below.
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]

        num_experts = self.config.num_experts

        # Cache params_dict to avoid repeated expensive traversal of model parameters
        if not hasattr(self, "_cached_params_dict"):
            self._cached_params_dict = dict(self.named_parameters())
        params_dict = self._cached_params_dict

        for name, loaded_weight in weights:
            # Normalise checkpoint naming to this implementation's layout.
            name = name.replace(r"model.language_model.", r"model.")

            # Talker / code2wav weights are not instantiated unless enabled.
            if ("talker" in name or "code2wav" in name) and not self.enable_talker:
                continue

            name = name.replace(".self_attn.out_proj", ".self_attn.proj")

            # Stage 1: stacked (fused) projections.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                    is_fused_expert = True
                    expert_params_mapping = fused_expert_params_mapping

                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                if "visual" in name:
                    continue

                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
                # if is_pp_missing_parameter(name, self):
                #     continue

                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Stage 2: reached only when no stacked mapping loaded the weight.
                # Track if this is an expert weight to enable early skipping
                is_expert_weight = False

                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    if "visual" in name or "audio_tower" in name:
                        continue
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
                        if "experts.gate_up_proj" in name:
                            # The fused gate/up tensor is split in two halves
                            # that are loaded as shards "w1" and "w3".
                            loaded_weight = loaded_weight.chunk(2, dim=-2)
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[0],
                                "w1",
                                num_experts,
                            )
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[1],
                                "w3",
                                num_experts,
                            )
                        else:
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight,
                                shard_id,
                                num_experts,
                            )
                    else:
                        # Skip loading extra parameters for GPTQ/modelopt models.
                        if (
                            name_mapped.endswith(ignore_suffixes)
                            and name_mapped not in params_dict
                        ):
                            continue
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    name = name_mapped
                    break
                else:
                    # Stage 3: neither a stacked nor a loaded expert weight.
                    if is_expert_weight:
                        # This is an expert weight but not mapped to this rank, skip all remaining processing
                        continue
                    if "visual" in name or "audio_tower" in name:
                        # adapt to VisionAttention
                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                        name = name.replace(r"model.visual.", r"visual.")
                        name = name.replace(r"attn.out_proj.", r"attn.proj.")

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue

                    if name in params_dict.keys():
                        param = params_dict[name]
                        # Fall back to the default loader for plain parameters.
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                    else:
                        logger.warning(
                            f"Loaded weight with {name=} not found in params_dict"
                        )
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level alias naming the model class implemented in this file.
# NOTE(review): presumably read by sglang's model loader to discover the
# implementation for this architecture — confirm.
EntryClass = Qwen3OmniMoeForConditionalGeneration
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||||
import logging
|
import logging
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|||||||
Qwen2_5_VisionRotaryEmbedding,
|
Qwen2_5_VisionRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
|
from sglang.srt.configs.qwen3_vl import (
|
||||||
|
Qwen3VLConfig,
|
||||||
|
Qwen3VLTextConfig,
|
||||||
|
Qwen3VLVisionConfig,
|
||||||
|
)
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
from sglang.srt.managers.schedule_batch import (
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
|
|
||||||
from sglang.srt.models.qwen3 import Qwen3Model
|
from sglang.srt.models.qwen3 import Qwen3Model
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix
|
||||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
from sglang.srt.utils.hf_transformers_utils import get_processor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# === Vision Encoder === #
|
# === Vision Encoder === #
|
||||||
|
|
||||||
|
|
||||||
@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Qwen3_VisionPatchMerger(nn.Module):
|
class Qwen3VLMoeVisionPatchMerger(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Qwen3_VisionTransformer(nn.Module):
|
class Qwen3VLMoeVisionModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||||
self.temporal_patch_size = vision_config.temporal_patch_size
|
self.temporal_patch_size = vision_config.temporal_patch_size
|
||||||
|
# layer indexes of which layer's output should be deep-stacked
|
||||||
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
||||||
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
|
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
|
||||||
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
|
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
|
||||||
|
|
||||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||||
head_dim = self.hidden_size // self.num_heads
|
head_dim = self.hidden_size // self.num_heads
|
||||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||||
@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
for layer_idx in range(vision_config.depth)
|
for layer_idx in range(vision_config.depth)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.merger = Qwen3_VisionPatchMerger(
|
self.merger = Qwen3VLMoeVisionPatchMerger(
|
||||||
dim=vision_config.out_hidden_size,
|
dim=vision_config.out_hidden_size,
|
||||||
context_dim=self.hidden_size,
|
context_dim=self.hidden_size,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
self.deepstack_merger_list = nn.ModuleList(
|
self.deepstack_merger_list = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen3_VisionPatchMerger(
|
Qwen3VLMoeVisionPatchMerger(
|
||||||
dim=vision_config.out_hidden_size,
|
dim=vision_config.out_hidden_size,
|
||||||
context_dim=self.hidden_size,
|
context_dim=self.hidden_size,
|
||||||
spatial_merge_size=self.spatial_merge_size,
|
spatial_merge_size=self.spatial_merge_size,
|
||||||
@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|
||||||
deepstack_feature_lists = []
|
deepstack_feature_lists = []
|
||||||
@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
|||||||
config: Qwen3VLConfig,
|
config: Qwen3VLConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
language_model_cls=Qwen3LLMModel,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.visual = Qwen3VLMoeVisionModel(
|
||||||
self.visual = Qwen3_VisionTransformer(
|
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
|
||||||
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
prefix=add_prefix("visual", prefix),
|
prefix=add_prefix("visual", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = Qwen3LLMModel(
|
# TODO: make it more elegant
|
||||||
config=config,
|
if language_model_cls is Qwen3LLMModel:
|
||||||
|
self.config: Qwen3VLConfig = config # for qwen3-vl
|
||||||
|
else:
|
||||||
|
self.config = config.text_config # for qwen3-omni
|
||||||
|
|
||||||
|
self.model = language_model_cls(
|
||||||
|
config=self.config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("model", prefix),
|
prefix=add_prefix("model", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head = self.model.embed_tokens
|
self.lm_head = self.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
self.config.vocab_size,
|
||||||
config.hidden_size,
|
self.config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("lm_head", prefix),
|
prefix=add_prefix("lm_head", prefix),
|
||||||
)
|
)
|
||||||
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(self.config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
|
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
|
||||||
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
|
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
|
||||||
@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
|||||||
# deepstack
|
# deepstack
|
||||||
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
||||||
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
||||||
|
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
|
||||||
@property
|
|
||||||
def use_deepstack(self) -> bool:
|
|
||||||
return hasattr(self, "deepstack_visual_indexes")
|
|
||||||
|
|
||||||
def separate_deepstack_embeds(self, embedding):
|
def separate_deepstack_embeds(self, embedding):
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -14,29 +14,19 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||||
import logging
|
import logging
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache
|
||||||
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
from typing import Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
from transformers import BatchFeature
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
||||||
Qwen2_5_VisionRotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
|
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
get_pp_group,
|
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from sglang.srt.managers.mm_utils import general_mm_embed_routine
|
from sglang.srt.managers.mm_utils import general_mm_embed_routine
|
||||||
@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
|
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
|
||||||
from sglang.srt.models.qwen3_vl import (
|
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
||||||
Qwen3_VisionTransformer,
|
|
||||||
Qwen3VLForConditionalGeneration,
|
|
||||||
)
|
|
||||||
from sglang.srt.utils import add_prefix
|
|
||||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
from sglang.srt.utils.hf_transformers_utils import get_processor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
config: Qwen3VLMoeConfig,
|
config: Qwen3VLMoeTextConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.embed_tokens
|
return self.embed_tokens
|
||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
|
||||||
# in qwen-vl, last dim is the same
|
|
||||||
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
|
||||||
self.visual.dtype
|
|
||||||
)
|
|
||||||
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
|
||||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
|
||||||
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
||||||
return image_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# process deepstack
|
# process deepstack
|
||||||
if input_deepstack_embeds is not None and layer_idx in range(3):
|
if input_deepstack_embeds is not None and layer_idx < 3:
|
||||||
sep = self.hidden_size * layer_idx
|
sep = self.hidden_size * layer_idx
|
||||||
hidden_states.add_(
|
hidden_states.add_(
|
||||||
input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
||||||
@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|||||||
return hidden_states, aux_hidden_states
|
return hidden_states, aux_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def load_fused_expert_weights(
    name: str,
    params_dict: dict,
    loaded_weight: torch.Tensor,
    shard_id: str,
    num_experts: int,
):
    """Feed a fused expert tensor to one parameter, one expert slice at a time.

    ``loaded_weight`` is indexed along its first dimension by global expert
    id; each selected slice is handed to the destination parameter's own
    ``weight_loader``.

    Args:
        name: Key of the destination parameter in ``params_dict``.
        params_dict: Mapping from parameter name to parameter object.
        loaded_weight: Fused checkpoint tensor; ``loaded_weight[i]`` is the
            slice belonging to expert ``i``.
        shard_id: Shard tag forwarded unchanged to the weight loader.
        num_experts: Total number of experts stored in ``loaded_weight``.

    Returns:
        Always ``True``.
    """
    target = params_dict[name]
    load_one = target.weight_loader
    # NOTE(review): the rank is read from the tensor-parallel group while the
    # size is read from the expert-parallel group -- confirm they coincide.
    rank = get_tensor_model_parallel_rank()
    world = get_moe_expert_parallel_world_size()

    if world == 1:
        # No expert parallelism: load every expert under its global id.
        first, last, offset = 0, num_experts, 0
    else:
        # Contiguous partition; the last rank also takes the remainder.
        per_rank = num_experts // world
        first = rank * per_rank
        last = num_experts if rank == world - 1 else (rank + 1) * per_rank
        # Experts are re-indexed from zero within this rank's partition.
        offset = first

    for global_id in range(first, last):
        load_one(
            target,
            loaded_weight[global_id],
            name,
            shard_id,
            global_id - offset,
        )
    return True
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
|
||||||
config: Qwen3VLMoeConfig,
|
config: Qwen3VLMoeConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
language_model_cls=Qwen3MoeLLMModel,
|
||||||
):
|
):
|
||||||
super(Qwen3VLForConditionalGeneration, self).__init__()
|
super().__init__(config, quant_config, prefix, language_model_cls)
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.visual = Qwen3_VisionTransformer(
|
|
||||||
config.vision_config,
|
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
|
||||||
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
|
||||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("visual", prefix),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = Qwen3MoeLLMModel(
|
|
||||||
config=config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("model", prefix),
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.tie_word_embeddings:
|
|
||||||
self.lm_head = self.model.embed_tokens
|
|
||||||
else:
|
|
||||||
self.lm_head = ParallelLMHead(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("lm_head", prefix),
|
|
||||||
)
|
|
||||||
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
|
||||||
|
|
||||||
# deepstack
|
|
||||||
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
|
||||||
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
|
||||||
|
|
||||||
@property
def use_deepstack(self) -> bool:
    """Report whether ``deepstack_visual_indexes`` exists on this instance."""
    try:
        getattr(self, "deepstack_visual_indexes")
    except AttributeError:
        return False
    return True
|
|
||||||
|
|
||||||
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    get_embedding: bool = False,
):
    """Run a forward pass for Qwen3-VL.

    Args:
        input_ids: Flattened (concatenated) input ids of the batch.
        positions: Flattened (concatenated) position ids of the batch.
            When ``self.is_mrope_enabled`` is set this argument is replaced
            by ``forward_batch.mrope_positions``, which must have shape
            ``(3, seq_len)`` for non-decode batches with image inputs.
        forward_batch: Batch state; supplies the forward mode, the M-RoPE
            positions and the multimodal inputs.
        get_embedding: When true, return the pooled hidden states instead
            of the logits-processor output.

    Returns:
        ``self.pooler(...)`` if ``get_embedding`` is true, otherwise
        ``self.logits_processor(...)``.
    """
    mrope_enabled = self.is_mrope_enabled
    if mrope_enabled:
        positions = forward_batch.mrope_positions

    # Only non-decode batches that actually carry images need the shape check.
    has_mm_prefill = (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_image_inputs()
    )
    if has_mm_prefill and mrope_enabled:
        assert positions.ndim == 2 and positions.size(0) == 3, (
            "multimodal section rotary embedding requires "
            f"(3, seq_len) positions, but got {positions.size()}"
        )

    hidden_states = general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.model,
        multimodal_model=self,
        positions=positions,
        use_deepstack=self.use_deepstack,
    )

    if get_embedding:
        return self.pooler(hidden_states, forward_batch)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head, forward_batch
    )
|
|
||||||
|
|
||||||
def load_fused_expert_weights(
    self,
    name: str,
    params_dict: dict,
    loaded_weight: torch.Tensor,
    shard_id: str,
    num_experts: int,
):
    """Load the per-expert slices of a fused expert tensor into one parameter.

    Args:
        name: Key of the destination parameter in ``params_dict``.
        params_dict: Mapping from parameter name to parameter object.
        loaded_weight: Fused tensor indexed by global expert id on dim 0.
        shard_id: Shard tag passed through to the parameter's weight loader.
        num_experts: Total number of experts stored in ``loaded_weight``.

    Returns:
        Always ``True``.
    """
    param = params_dict[name]
    loader = param.weight_loader
    ep_rank = get_tensor_model_parallel_rank()
    ep_size = get_moe_expert_parallel_world_size()

    if ep_size == 1:
        # Every expert is loaded and keeps its global id.
        local_to_global = enumerate(range(num_experts))
    else:
        # Each rank takes a contiguous chunk; the last rank absorbs the
        # remainder. Loaded experts are numbered from zero within the chunk.
        chunk = num_experts // ep_size
        lo = ep_rank * chunk
        hi = num_experts if ep_rank == ep_size - 1 else lo + chunk
        local_to_global = enumerate(range(lo, hi))

    for local_id, global_id in local_to_global:
        loader(param, loaded_weight[global_id], name, shard_id, local_id)
    return True
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|||||||
self._cached_params_dict = dict(self.named_parameters())
|
self._cached_params_dict = dict(self.named_parameters())
|
||||||
params_dict = self._cached_params_dict
|
params_dict = self._cached_params_dict
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "language_model" in name:
|
name = name.replace(r"model.language_model.", r"model.")
|
||||||
name = name.replace(r"model.language_model.", r"model.")
|
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||||
@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|||||||
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
|
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
|
||||||
if "experts.gate_up_proj" in name:
|
if "experts.gate_up_proj" in name:
|
||||||
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||||
self.load_fused_expert_weights(
|
load_fused_expert_weights(
|
||||||
name_mapped,
|
name_mapped,
|
||||||
params_dict,
|
params_dict,
|
||||||
loaded_weight[0],
|
loaded_weight[0],
|
||||||
"w1",
|
"w1",
|
||||||
num_experts,
|
num_experts,
|
||||||
)
|
)
|
||||||
self.load_fused_expert_weights(
|
load_fused_expert_weights(
|
||||||
name_mapped,
|
name_mapped,
|
||||||
params_dict,
|
params_dict,
|
||||||
loaded_weight[1],
|
loaded_weight[1],
|
||||||
@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|||||||
num_experts,
|
num_experts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.load_fused_expert_weights(
|
load_fused_expert_weights(
|
||||||
name_mapped,
|
name_mapped,
|
||||||
params_dict,
|
params_dict,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
|
|||||||
@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
):
|
):
|
||||||
self.hf_config = hf_config
|
self.hf_config = hf_config
|
||||||
self._processor = _processor
|
self._processor = _processor
|
||||||
self.arch = hf_config.architectures[0]
|
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.transport_mode = transport_mode
|
self.transport_mode = transport_mode
|
||||||
|
|
||||||
@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
"input_features": Modality.AUDIO,
|
"input_features": Modality.AUDIO,
|
||||||
"input_features_mask": Modality.AUDIO,
|
"input_features_mask": Modality.AUDIO,
|
||||||
"audio_attention_mask": Modality.AUDIO,
|
"audio_attention_mask": Modality.AUDIO,
|
||||||
|
"feature_attention_mask": Modality.AUDIO,
|
||||||
# Video-related attributes
|
# Video-related attributes
|
||||||
"pixel_values_videos": Modality.VIDEO,
|
"pixel_values_videos": Modality.VIDEO,
|
||||||
"second_per_grid_ts": Modality.VIDEO,
|
"second_per_grid_ts": Modality.VIDEO,
|
||||||
@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
if self._processor.__class__.__name__ in {
|
if self._processor.__class__.__name__ in {
|
||||||
"Gemma3nProcessor",
|
"Gemma3nProcessor",
|
||||||
"Qwen2AudioProcessor",
|
"Qwen2AudioProcessor",
|
||||||
|
"Qwen3OmniMoeProcessor",
|
||||||
}:
|
}:
|
||||||
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
||||||
kwargs["audio"] = audios
|
kwargs["audio"] = audios
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
|
|||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||||
|
from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
|
||||||
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
||||||
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
||||||
from sglang.srt.multimodal.processors.base_processor import (
|
from sglang.srt.multimodal.processors.base_processor import (
|
||||||
@@ -209,22 +210,31 @@ async def preprocess_video(
|
|||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
# Compatible with Qwen2VL and Qwen2_5VL
|
# Compatible with Qwen-VL & Qwen-Omni Series
|
||||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
class QwenVLImageProcessor(SGLangBaseProcessor):
|
||||||
models = [
|
models = [
|
||||||
Qwen2VLForConditionalGeneration,
|
Qwen2VLForConditionalGeneration,
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
Qwen3VLForConditionalGeneration,
|
Qwen3VLForConditionalGeneration,
|
||||||
Qwen3VLMoeForConditionalGeneration,
|
Qwen3VLMoeForConditionalGeneration,
|
||||||
|
Qwen3OmniMoeForConditionalGeneration,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
|
self.model_type = hf_config.model_type
|
||||||
|
if hf_config.model_type == "qwen3_omni_moe":
|
||||||
|
hf_config = hf_config.thinker_config
|
||||||
|
|
||||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
# The regex that matches expanded image tokens.
|
|
||||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||||
self.vision_start_token_id = hf_config.vision_start_token_id
|
self.vision_start_token_id = hf_config.vision_start_token_id
|
||||||
self.vision_end_token_id = hf_config.vision_end_token_id
|
self.vision_end_token_id = hf_config.vision_end_token_id
|
||||||
|
|
||||||
|
self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
|
||||||
|
self.audio_token_id = getattr(hf_config, "audio_token_id", None)
|
||||||
|
|
||||||
self.NUM_TOKEN_PER_FRAME = 770
|
self.NUM_TOKEN_PER_FRAME = 770
|
||||||
self.IMAGE_FACTOR = 28
|
self.IMAGE_FACTOR = 28
|
||||||
self.MIN_PIXELS = 4 * 28 * 28
|
self.MIN_PIXELS = 4 * 28 * 28
|
||||||
@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
self.mm_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||||
image_token_id=hf_config.image_token_id,
|
image_token_id=hf_config.image_token_id,
|
||||||
|
# The regex that matches expanded image tokens.
|
||||||
image_token_regex=re.compile(
|
image_token_regex=re.compile(
|
||||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||||
),
|
),
|
||||||
video_token_id=hf_config.video_token_id,
|
video_token_id=hf_config.video_token_id,
|
||||||
|
audio_token_id=self.audio_token_id,
|
||||||
).build(_processor)
|
).build(_processor)
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
video_data=request_obj.video_data,
|
video_data=request_obj.video_data,
|
||||||
|
audio_data=request_obj.audio_data,
|
||||||
multimodal_tokens=self.mm_tokens,
|
multimodal_tokens=self.mm_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
base_output, self.mm_tokens
|
base_output, self.mm_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
audio_feature_lengths = None
|
||||||
|
|
||||||
|
if self.model_type == "qwen3_omni_moe":
|
||||||
|
audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
|
||||||
|
if audio_item:
|
||||||
|
audio_feature_lengths = torch.sum(
|
||||||
|
audio_item.feature_attention_mask, dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
|
||||||
|
ret, "video_second_per_grid", None
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = input_ids.flatten()
|
input_ids = input_ids.flatten()
|
||||||
|
|
||||||
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
||||||
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
||||||
image_token_id=self.mm_tokens.image_token_id,
|
image_token_id=self.mm_tokens.image_token_id,
|
||||||
video_token_id=self.mm_tokens.video_token_id,
|
video_token_id=self.mm_tokens.video_token_id,
|
||||||
vision_start_token_id=self.vision_start_token_id,
|
vision_start_token_id=self.vision_start_token_id,
|
||||||
model_type=self.hf_config.model_type,
|
model_type=self.model_type,
|
||||||
tokens_per_second=getattr(
|
tokens_per_second=getattr(
|
||||||
self.hf_config.vision_config, "tokens_per_second", None
|
self.hf_config.vision_config, "tokens_per_second", None
|
||||||
),
|
),
|
||||||
input_ids=input_ids.unsqueeze(0),
|
input_ids=input_ids.unsqueeze(0),
|
||||||
image_grid_thw=getattr(ret, "image_grid_thw", None),
|
image_grid_thw=getattr(ret, "image_grid_thw", None),
|
||||||
video_grid_thw=getattr(ret, "video_grid_thw", None),
|
video_grid_thw=getattr(ret, "video_grid_thw", None),
|
||||||
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
use_audio_in_video=False,
|
||||||
|
audio_seqlens=audio_feature_lengths,
|
||||||
|
audio_token_id=getattr(self.hf_config, "audio_token_id", None),
|
||||||
|
audio_start_token_id=self.audio_start_token_id,
|
||||||
|
position_id_per_seconds=getattr(
|
||||||
|
self.hf_config, "position_id_per_seconds", None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
mrope_positions = mrope_positions.squeeze(1)
|
mrope_positions = mrope_positions.squeeze(1)
|
||||||
|
|
||||||
@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
"im_end_id": self.IM_END_TOKEN_ID,
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
"im_token_id": self.mm_tokens.image_token_id,
|
"im_token_id": self.mm_tokens.image_token_id,
|
||||||
"video_token_id": self.mm_tokens.video_token_id,
|
"video_token_id": self.mm_tokens.video_token_id,
|
||||||
|
"audio_token_id": self.mm_tokens.audio_token_id,
|
||||||
"mrope_positions": mrope_positions,
|
"mrope_positions": mrope_positions,
|
||||||
"mrope_position_delta": mrope_position_delta,
|
"mrope_position_delta": mrope_position_delta,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
del (
|
del (
|
||||||
TestOpenAIOmniServerBase,
|
TestOpenAIMLLMServerBase,
|
||||||
ImageOpenAITestMixin,
|
ImageOpenAITestMixin,
|
||||||
VideoOpenAITestMixin,
|
VideoOpenAITestMixin,
|
||||||
AudioOpenAITestMixin,
|
AudioOpenAITestMixin,
|
||||||
|
OmniOpenAITestMixin,
|
||||||
)
|
)
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3OmniServer(OmniOpenAITestMixin):
    """Runs the omni test mixin against a Qwen3-Omni-30B-A3B-Instruct server."""

    @classmethod
    def setUpClass(cls):
        """Launch the server and point ``base_url`` at its ``/v1`` prefix."""
        # workaround to fit into H100
        launch_flags = [
            "--trust-remote-code",
            "--mem-fraction-static",
            "0.90",
            "--disable-cuda-graph",
            "--disable-fast-image-processor",
            "--grammar-backend",
            "none",
        ]

        cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )
        # Clients are created with the versioned prefix, not the bare URL.
        cls.base_url = cls.base_url + "/v1"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
del (
|
del (
|
||||||
TestOpenAIOmniServerBase,
|
TestOpenAIMLLMServerBase,
|
||||||
ImageOpenAITestMixin,
|
ImageOpenAITestMixin,
|
||||||
VideoOpenAITestMixin,
|
VideoOpenAITestMixin,
|
||||||
AudioOpenAITestMixin,
|
AudioOpenAITestMixin,
|
||||||
|
OmniOpenAITestMixin,
|
||||||
)
|
)
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
|
|||||||
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
|
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIOmniServerBase(CustomTestCase):
|
class TestOpenAIMLLMServerBase(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = ""
|
cls.model = ""
|
||||||
@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
|
|||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
|
class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
|
||||||
|
def verify_speech_recognition_response(self, text):
    """Assert that ``text`` contains every expected phrase of the speech clip.

    Matching is case-insensitive: the phrases are lowercase and ``text`` is
    lowercased before the comparison.

    Args:
        text: Transcription returned by the model.

    Raises:
        AssertionError: On the first expected phrase missing from ``text``.
    """
    expected_phrases = (
        "thank you",
        "it's a privilege to be here",
        "leader",
        "science",
        "art",
    )
    lowered = text.lower()
    for check_word in expected_phrases:
        found = check_word in lowered
        assert found, f"audio_response: |{text}| should contain |{check_word}|"
|
||||||
|
|
||||||
def prepare_audio_messages(self, prompt, audio_file_name):
|
def prepare_audio_messages(self, prompt, audio_file_name):
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
"Listen to this audio and write down the audio transcription in English.",
|
"Listen to this audio and write down the audio transcription in English.",
|
||||||
category="speech",
|
category="speech",
|
||||||
)
|
)
|
||||||
check_list = [
|
self.verify_speech_recognition_response(audio_response)
|
||||||
"thank you",
|
|
||||||
"it's a privilege to be here",
|
|
||||||
"leader",
|
|
||||||
"science",
|
|
||||||
"art",
|
|
||||||
]
|
|
||||||
for check_word in check_list:
|
|
||||||
assert (
|
|
||||||
check_word in audio_response
|
|
||||||
), f"audio_response: |{audio_response}| should contain |{check_word}|"
|
|
||||||
|
|
||||||
def test_audio_ambient_completion(self):
|
def test_audio_ambient_completion(self):
|
||||||
# bird song
|
# bird song
|
||||||
@@ -138,7 +142,79 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
assert "bird" in audio_response
|
assert "bird" in audio_response
|
||||||
|
|
||||||
|
|
||||||
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
|
||||||
|
def run_decode_with_image(self, image_id):
    """Send one chat completion, optionally with an image, and check its shape.

    Args:
        image_id: ``0`` attaches the man-ironing image, ``1`` attaches the
            SGL logo image; any other value sends a text-only request.

    Raises:
        AssertionError: If the reply role is not ``"assistant"`` or the
            reply content is not a string.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    parts = []
    if image_id in (0, 1):
        chosen_url = IMAGE_MAN_IRONING_URL if image_id == 0 else IMAGE_SGL_LOGO_URL
        parts.append({"type": "image_url", "image_url": {"url": chosen_url}})
    # The text prompt is always sent, with or without an image before it.
    parts.append({"type": "text", "text": "Describe this image in a sentence."})

    response = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": parts}],
        temperature=0,
        **(self.get_vision_request_kwargs()),
    )

    reply = response.choices[0].message
    assert reply.role == "assistant"
    text = response.choices[0].message.content
    assert isinstance(text, str)
|
||||||
|
|
||||||
|
def test_mixed_batch(self):
    """Issue twelve image/text requests concurrently from four threads.

    The ids cycle through 0, 1 and 2, so image and text-only requests are
    in flight together.
    """
    workload = [0, 1, 2] * 4
    with ThreadPoolExecutor(4) as pool:
        # Draining the result iterator re-raises any worker exception.
        for _ in pool.map(self.run_decode_with_image, workload):
            pass
|
||||||
|
|
||||||
|
def verify_single_image_response(self, response):
    """Validate a chat completion that describes the man-ironing test image.

    Checks the reply role and type, then requires one keyword from each of
    three groups (the person, the vehicle, the activity) in the reply text,
    and finally that the id, creation time and token usage are populated.

    Args:
        response: Chat completion object returned by the OpenAI client.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert response.choices[0].message.role == "assistant"
    text = response.choices[0].message.content
    assert isinstance(text, str)

    # `driver` is for gemma-3-it
    # Every alternative needs its own `in text`: a bare string literal is
    # always truthy and would make the whole assertion unconditionally pass.
    assert (
        "man" in text or "person" in text or "driver" in text
    ), f"text: {text}, should contain man, person or driver"
    assert (
        "cab" in text
        or "taxi" in text
        or "SUV" in text
        or "vehicle" in text
        or "car" in text
    ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
    # MiniCPMO fails to recognize `iron`, but `hanging`
    assert (
        "iron" in text or "hang" in text or "cloth" in text or "holding" in text
    ), f"text: {text}, should contain iron, hang, cloth or holding"
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
def test_single_image_chat_completion(self):
|
def test_single_image_chat_completion(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
@@ -163,34 +239,11 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
**(self.get_vision_request_kwargs()),
|
**(self.get_vision_request_kwargs()),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.choices[0].message.role == "assistant"
|
print("-" * 30)
|
||||||
text = response.choices[0].message.content
|
print(f"Single image response:\n{response.choices[0].message.content}")
|
||||||
assert isinstance(text, str)
|
print("-" * 30)
|
||||||
# `driver` is for gemma-3-it
|
|
||||||
assert (
|
self.verify_single_image_response(response)
|
||||||
"man" in text or "person" or "driver" in text
|
|
||||||
), f"text: {text}, should contain man, person or driver"
|
|
||||||
assert (
|
|
||||||
"cab" in text
|
|
||||||
or "taxi" in text
|
|
||||||
or "SUV" in text
|
|
||||||
or "vehicle" in text
|
|
||||||
or "car" in text
|
|
||||||
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
|
|
||||||
# MiniCPMO fails to recognize `iron`, but `hanging`
|
|
||||||
assert (
|
|
||||||
"iron" in text
|
|
||||||
or "hang" in text
|
|
||||||
or "cloth" in text
|
|
||||||
or "coat" in text
|
|
||||||
or "holding" in text
|
|
||||||
or "outfit" in text
|
|
||||||
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
|
|
||||||
assert response.id
|
|
||||||
assert response.created
|
|
||||||
assert response.usage.prompt_tokens > 0
|
|
||||||
assert response.usage.completion_tokens > 0
|
|
||||||
assert response.usage.total_tokens > 0
|
|
||||||
|
|
||||||
def test_multi_turn_chat_completion(self):
|
def test_multi_turn_chat_completion(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "I have two very different images. They are not related at all. "
|
"text": "I have two very different images. Please describe them.",
|
||||||
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
assert response.usage.completion_tokens > 0
|
assert response.usage.completion_tokens > 0
|
||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
def _test_mixed_image_audio_chat_completion(self):
    """Send one image plus one audio clip and check both are reflected.

    The reply must mention the image subject and contain every expected
    phrase of the speech clip. The leading underscore keeps the test
    runner from collecting this method.

    Raises:
        AssertionError: If the reply role/type, the image keywords, the
            transcript phrases or the response metadata do not match.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    response = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                    },
                    {
                        "type": "audio_url",
                        "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
                    },
                    {
                        "type": "text",
                        "text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
                    },
                ],
            },
        ],
        temperature=0,
        **(self.get_vision_request_kwargs()),
    )

    assert response.choices[0].message.role == "assistant"
    text = response.choices[0].message.content
    assert isinstance(text, str)
    print("-" * 30)
    print(f"Mixed image & audio response:\n{text}")
    print("-" * 30)
    assert (
        "man" in text
        or "cab" in text
        or "SUV" in text
        or "taxi" in text
        or "car" in text
    ), f"text: {text}, should contain man, cab, SUV, taxi or car"
    check_list = [
        "thank you",
        "it's a privilege to be here",
        "leader",
        "science",
        "art",
    ]
    # The expected phrases are lowercase; compare against a lowercased copy
    # so a capitalised transcript ("Thank you ...") is not rejected.
    lowered = text.lower()
    for check_word in check_list:
        assert (
            check_word in lowered
        ), f"text: |{text}| should contain |{check_word}|"
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
|
|
||||||
|
|
||||||
def prepare_video_images_messages(self, video_path):
|
def prepare_video_images_messages(self, video_path):
|
||||||
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
|
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
|
||||||
# the size of the video embeds differs from the `modality` argument when preprocessed
|
# the size of the video embeds differs from the `modality` argument when preprocessed
|
||||||
@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
self.assertGreater(len(video_response), 0)
|
self.assertGreater(len(video_response), 0)
|
||||||
|
|
||||||
|
|
||||||
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
|
class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
|
||||||
def prepare_video_messages(self, video_path):
|
def prepare_video_messages(self, video_path):
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
|
|||||||
), f"video_response: {video_response}, should contain 'black' or 'dark'"
|
), f"video_response: {video_response}, should contain 'black' or 'dark'"
|
||||||
self.assertIsNotNone(video_response)
|
self.assertIsNotNone(video_response)
|
||||||
self.assertGreater(len(video_response), 0)
|
self.assertGreater(len(video_response), 0)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniOpenAITestMixin(
    ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin
):
    """Mixin that stacks the image, video and audio test mixins.

    On top of the inherited tests it adds one chat-completion test whose
    single user turn carries an image, an audio clip and a text instruction.
    """

    def test_mixed_modality_chat_completion(self):
        # One user message with three content parts (image, audio, text) is
        # sent through the OpenAI-compatible chat endpoint; the reply is then
        # handed to the image and speech verification helpers inherited from
        # the parent mixins.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        image_part = {
            "type": "image_url",
            "image_url": {"url": IMAGE_MAN_IRONING_URL},
        }
        audio_part = {
            "type": "audio_url",
            "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
        }
        instruction_part = {
            "type": "text",
            "text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
        }
        user_turn = {
            "role": "user",
            "content": [image_part, audio_part, instruction_part],
        }

        # temperature=0 keeps the sampled reply as repeatable as the server
        # allows, which the keyword-style checks below rely on.
        request_kwargs = {
            "model": "default",
            "messages": [user_turn],
            "temperature": 0,
            "max_tokens": 128,
            "stream": False,
        }
        response = client.chat.completions.create(**request_kwargs)

        text = response.choices[0].message.content

        # Echo the raw reply so a failing run shows what the model produced.
        separator = "-" * 30
        print(separator)
        print(f"Mixed modality response:\n{text}")
        print(separator)

        # The image check receives the whole response object, the audio
        # check only the reply text.
        self.verify_single_image_response(response=response)
        self.verify_speech_recognition_response(text=text)
|
||||||
|
|||||||
Reference in New Issue
Block a user