model: qwen3-omni (thinker-only) (#10911)
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -853,6 +853,7 @@ multimodal_model_archs = [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"Qwen3OmniMoeForConditionalGeneration",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"InternVLChatModel",
|
||||
"InternS1ForConditionalGeneration",
|
||||
|
||||
613
python/sglang/srt/configs/qwen3_omni.py
Normal file
613
python/sglang/srt/configs/qwen3_omni.py
Normal file
@@ -0,0 +1,613 @@
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.configuration_utils import layer_type_validation
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):
|
||||
model_type = "qwen3_omni_moe_audio_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_mel_bins=128,
|
||||
encoder_layers=32,
|
||||
encoder_attention_heads=20,
|
||||
encoder_ffn_dim=5120,
|
||||
d_model=1280,
|
||||
dropout=0,
|
||||
attention_dropout=0,
|
||||
activation_function="gelu",
|
||||
activation_dropout=0,
|
||||
scale_embedding=False,
|
||||
initializer_range=0.02,
|
||||
max_source_positions=1500,
|
||||
n_window=100,
|
||||
output_dim=3584,
|
||||
n_window_infer=400,
|
||||
conv_chunksize=500,
|
||||
downsample_hidden_size=480,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.d_model = d_model
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_function = activation_function
|
||||
self.activation_dropout = activation_dropout
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_embedding = (
|
||||
scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
)
|
||||
self.max_source_positions = max_source_positions
|
||||
self.n_window = n_window
|
||||
self.output_dim = output_dim
|
||||
self.n_window_infer = n_window_infer
|
||||
self.conv_chunksize = conv_chunksize
|
||||
self.downsample_hidden_size = downsample_hidden_size
|
||||
|
||||
|
||||
class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):
|
||||
model_type = "qwen3_omni_moe_vision_encoder"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=27,
|
||||
hidden_size=1152,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
intermediate_size=4304,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=16,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
out_hidden_size=3584,
|
||||
num_position_embeddings=2304,
|
||||
deepstack_visual_indexes=[8, 16, 24],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.num_position_embeddings = num_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.deepstack_visual_indexes = deepstack_visual_indexes
|
||||
|
||||
|
||||
class Qwen3OmniMoeTextConfig(PretrainedConfig):
|
||||
model_type = "qwen3_omni_moe_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3OmniMoeText`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=3584,
|
||||
hidden_size=2048,
|
||||
intermediate_size=18944,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=28,
|
||||
num_key_value_heads=4,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
sliding_window=None,
|
||||
attention_dropout=0,
|
||||
decoder_sparse_step=1,
|
||||
moe_intermediate_size=768,
|
||||
num_experts_per_tok=8,
|
||||
num_experts=128,
|
||||
norm_topk_prob=True,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
mlp_only_layers=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
# MoE arguments
|
||||
self.decoder_sparse_step = decoder_sparse_step
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
|
||||
|
||||
class Qwen3OmniMoeThinkerConfig(PretrainedConfig):
|
||||
model_type = "qwen3_omni_moe_thinker"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"audio_config": Qwen3OmniMoeAudioEncoderConfig,
|
||||
"vision_config": Qwen3OmniMoeVisionEncoderConfig,
|
||||
"text_config": Qwen3OmniMoeTextConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_config=None,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
audio_token_id=151646,
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
position_id_per_seconds=25,
|
||||
audio_start_token_id=151647,
|
||||
user_token_id=872,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.user_token_id = user_token_id
|
||||
self.position_id_per_seconds = position_id_per_seconds
|
||||
self.audio_start_token_id = audio_start_token_id
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
vision_config = Qwen3OmniMoeVisionEncoderConfig()
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(audio_config, dict):
|
||||
audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
|
||||
elif audio_config is None:
|
||||
audio_config = Qwen3OmniMoeAudioEncoderConfig()
|
||||
self.audio_config = audio_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config = Qwen3OmniMoeTextConfig(**text_config)
|
||||
elif text_config is None:
|
||||
text_config = Qwen3OmniMoeTextConfig()
|
||||
self.text_config = text_config
|
||||
self.audio_token_id = audio_token_id
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
|
||||
|
||||
class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):
|
||||
|
||||
model_type = "qwen3_omni_moe_talker_code_predictor"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=2048,
|
||||
hidden_size=1024,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=0.000001,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
sliding_window=None,
|
||||
layer_types=None,
|
||||
attention_dropout=0,
|
||||
num_code_groups=32,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
(
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
)
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types, self.num_hidden_layers)
|
||||
self.num_code_groups = num_code_groups
|
||||
|
||||
|
||||
class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):
|
||||
|
||||
model_type = "qwen3_omni_moe_talker_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=3072,
|
||||
hidden_size=1024,
|
||||
intermediate_size=2048,
|
||||
num_hidden_layers=20,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=2,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=0.000001,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
sliding_window=None,
|
||||
attention_dropout=0,
|
||||
decoder_sparse_step=1,
|
||||
moe_intermediate_size=384,
|
||||
num_experts_per_tok=8,
|
||||
num_experts=128,
|
||||
norm_topk_prob=False,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
mlp_only_layers=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
# MoE arguments
|
||||
self.decoder_sparse_step = decoder_sparse_step
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
|
||||
|
||||
class Qwen3OmniMoeTalkerConfig(PretrainedConfig):
|
||||
|
||||
sub_configs = {
|
||||
"code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig,
|
||||
"text_config": Qwen3OmniMoeTalkerTextConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code_predictor_config=None,
|
||||
text_config=None,
|
||||
num_code_groups=32,
|
||||
thinker_hidden_size=2048,
|
||||
codec_eos_token_id=4198,
|
||||
accept_hidden_layer=18,
|
||||
codec_nothink_id=4203,
|
||||
codec_think_bos_id=4204,
|
||||
codec_think_eos_id=4205,
|
||||
codec_pad_id=4196,
|
||||
codec_bos_id=4197,
|
||||
audio_token_id=151646,
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
position_id_per_seconds=25,
|
||||
audio_start_token_id=151669,
|
||||
speaker_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if code_predictor_config is None:
|
||||
code_predictor_config = {}
|
||||
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()
|
||||
logger.info(
|
||||
"code_predictor_config is None. Initializing code_predictor_config model with default values"
|
||||
)
|
||||
elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):
|
||||
self.code_predictor_config = code_predictor_config
|
||||
else:
|
||||
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(
|
||||
**code_predictor_config
|
||||
)
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
self.text_config = Qwen3OmniMoeTalkerTextConfig()
|
||||
logger.info(
|
||||
"talker text_config is None. Initializing talker text model with default values"
|
||||
)
|
||||
elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):
|
||||
self.text_config = text_config
|
||||
else:
|
||||
self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)
|
||||
self.num_code_groups = num_code_groups
|
||||
self.thinker_hidden_size = thinker_hidden_size
|
||||
self.codec_eos_token_id = codec_eos_token_id
|
||||
self.accept_hidden_layer = accept_hidden_layer
|
||||
self.codec_nothink_id = codec_nothink_id
|
||||
self.codec_think_bos_id = codec_think_bos_id
|
||||
self.codec_think_eos_id = codec_think_eos_id
|
||||
self.codec_pad_id = codec_pad_id
|
||||
self.codec_bos_id = codec_bos_id
|
||||
self.audio_token_id = audio_token_id
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.position_id_per_seconds = position_id_per_seconds
|
||||
self.audio_start_token_id = audio_start_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.speaker_id = speaker_id
|
||||
|
||||
|
||||
class Qwen3OmniMoeCode2WavConfig(PretrainedConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
codebook_size=2048,
|
||||
hidden_size=1024,
|
||||
max_position_embeddings=8000,
|
||||
rope_theta=10000,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
attention_bias=False,
|
||||
sliding_window=72,
|
||||
intermediate_size=3072,
|
||||
hidden_act="silu",
|
||||
layer_scale_initial_scale=0.01,
|
||||
rms_norm_eps=1e-5,
|
||||
num_hidden_layers=8,
|
||||
num_quantizers=16,
|
||||
upsample_rates=(8, 5, 4, 3),
|
||||
upsampling_ratios=(2, 2),
|
||||
decoder_dim=1536,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.codebook_size = codebook_size
|
||||
self.hidden_size = hidden_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.attention_bias = attention_bias
|
||||
self.sliding_window = sliding_window
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_scale_initial_scale = layer_scale_initial_scale
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_quantizers = num_quantizers
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsampling_ratios = upsampling_ratios
|
||||
self.decoder_dim = decoder_dim
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
@property
|
||||
def layer_types(self):
|
||||
"""
|
||||
All layer in code2wav should be sliding attention
|
||||
"""
|
||||
return ["sliding_attention"] * self.num_hidden_layers
|
||||
|
||||
|
||||
class Qwen3OmniMoeConfig(PretrainedConfig):
|
||||
|
||||
model_type = "qwen3_omni_moe"
|
||||
sub_configs = {
|
||||
"thinker_config": Qwen3OmniMoeThinkerConfig,
|
||||
"talker_config": Qwen3OmniMoeTalkerConfig,
|
||||
"code2wav_config": Qwen3OmniMoeCode2WavConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
thinker_config=None,
|
||||
talker_config=None,
|
||||
code2wav_config=None,
|
||||
enable_audio_output=True,
|
||||
im_start_token_id=151644,
|
||||
im_end_token_id=151645,
|
||||
tts_pad_token_id=151671,
|
||||
tts_bos_token_id=151672,
|
||||
tts_eos_token_id=151673,
|
||||
system_token_id=8948,
|
||||
user_token_id=872,
|
||||
assistant_token_id=77091,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if thinker_config is None:
|
||||
thinker_config = {}
|
||||
logger.info(
|
||||
"thinker_config is None. Initializing thinker model with default values"
|
||||
)
|
||||
|
||||
if talker_config is None:
|
||||
talker_config = {}
|
||||
logger.info(
|
||||
"talker_config is None. Initializing talker model with default values"
|
||||
)
|
||||
|
||||
if code2wav_config is None:
|
||||
code2wav_config = {}
|
||||
logger.info(
|
||||
"code2wav_config is None. Initializing code2wav model with default values"
|
||||
)
|
||||
|
||||
self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
|
||||
self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
|
||||
self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
|
||||
self.enable_audio_output = enable_audio_output
|
||||
self.im_start_token_id = im_start_token_id
|
||||
self.im_end_token_id = im_end_token_id
|
||||
self.tts_pad_token_id = tts_pad_token_id
|
||||
self.tts_bos_token_id = tts_bos_token_id
|
||||
self.tts_eos_token_id = tts_eos_token_id
|
||||
self.system_token_id = system_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.assistant_token_id = assistant_token_id
|
||||
|
||||
def get_text_config(self, decoder=False) -> "PretrainedConfig":
|
||||
"""
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
|
||||
Args:
|
||||
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||
If set to `True`, then only search for decoder config names.
|
||||
"""
|
||||
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
|
||||
# except for Qwen yet. This has to be generalized if more deeply nested configs are
|
||||
# added. NOTE: currently method used only by vLLM
|
||||
return self.thinker_config.get_text_config()
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Qwen3VLMoeConfig",
|
||||
"Qwen3VLMoeVisionConfig",
|
||||
"Qwen3VLConfig",
|
||||
"Qwen3VLVisionConfig",
|
||||
]
|
||||
|
||||
@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if model_type == "qwen3_omni_moe":
|
||||
# For qwen3-omni
|
||||
return MRotaryEmbedding.get_rope_index_qwen3_omni(
|
||||
spatial_merge_size,
|
||||
image_token_id,
|
||||
video_token_id,
|
||||
vision_start_token_id,
|
||||
tokens_per_second,
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
**kwargs,
|
||||
)
|
||||
if (
|
||||
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
|
||||
) and video_grid_thw is not None:
|
||||
@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
video_grid_thw, video_grid_thw[:, 0], dim=0
|
||||
)
|
||||
video_grid_thw[:, 0] = 1
|
||||
|
||||
mrope_position_deltas = []
|
||||
if input_ids is not None and (
|
||||
image_grid_thw is not None or video_grid_thw is not None
|
||||
@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
time_tensor_long = time_tensor.long()
|
||||
t_index = time_tensor_long.flatten()
|
||||
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
|
||||
elif model_type in (
|
||||
"qwen2_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
):
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
.flatten()
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unimplemented")
|
||||
raise RuntimeError(f"Unimplemented model type: {model_type}")
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
mrope_position_deltas = max_position_ids + 1 - s
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
@staticmethod
|
||||
def get_rope_index_qwen3_omni(
|
||||
spatial_merge_size: int,
|
||||
image_token_id: int,
|
||||
video_token_id: int,
|
||||
vision_start_token_id: int,
|
||||
tokens_per_second: Optional[int] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# For qwen3-omni
|
||||
audio_token_id = kwargs["audio_token_id"]
|
||||
audio_start_token_id = kwargs["audio_start_token_id"]
|
||||
position_id_per_seconds = kwargs["position_id_per_seconds"]
|
||||
use_audio_in_video = kwargs.get("use_audio_in_video", False)
|
||||
audio_seqlens = kwargs.get("audio_seqlens", None)
|
||||
second_per_grids = second_per_grid_ts
|
||||
|
||||
mrope_position_deltas = []
|
||||
if input_ids is not None and (
|
||||
image_grid_thw is not None or video_grid_thw is not None
|
||||
):
|
||||
total_input_ids = input_ids
|
||||
position_ids = torch.zeros(
|
||||
3,
|
||||
input_ids.shape[0],
|
||||
input_ids.shape[1],
|
||||
dtype=torch.float,
|
||||
device=input_ids.device,
|
||||
)
|
||||
image_idx, video_idx, audio_idx = 0, 0, 0
|
||||
for i, current_input_ids in enumerate(total_input_ids):
|
||||
image_nums, video_nums, audio_nums = 0, 0, 0
|
||||
vision_start_indices = torch.argwhere(
|
||||
current_input_ids == vision_start_token_id
|
||||
).squeeze(1)
|
||||
if vision_start_indices.numel() > 0:
|
||||
vision_tokens = current_input_ids[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (
|
||||
(vision_tokens == audio_start_token_id).sum()
|
||||
if use_audio_in_video
|
||||
else (vision_tokens == video_token_id).sum()
|
||||
)
|
||||
audio_nums = torch.sum(current_input_ids == audio_start_token_id)
|
||||
input_tokens = current_input_ids.tolist()
|
||||
llm_pos_ids_list: list = []
|
||||
st = 0
|
||||
remain_images, remain_videos, remain_audios = (
|
||||
image_nums,
|
||||
video_nums,
|
||||
audio_nums,
|
||||
)
|
||||
multimodal_nums = (
|
||||
image_nums + audio_nums
|
||||
if use_audio_in_video
|
||||
else image_nums + video_nums + audio_nums
|
||||
)
|
||||
for _ in range(multimodal_nums):
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1
|
||||
if len(llm_pos_ids_list) > 0
|
||||
else 0
|
||||
)
|
||||
ed_vision_start = (
|
||||
input_tokens.index(vision_start_token_id, st)
|
||||
if (
|
||||
(
|
||||
image_token_id in input_tokens
|
||||
or video_token_id in input_tokens
|
||||
)
|
||||
and (remain_videos > 0 or remain_images > 0)
|
||||
)
|
||||
else len(input_tokens) + 1
|
||||
)
|
||||
ed_audio_start = (
|
||||
input_tokens.index(audio_start_token_id, st)
|
||||
if (audio_token_id in input_tokens and remain_audios > 0)
|
||||
else len(input_tokens) + 1
|
||||
)
|
||||
min_ed = min(ed_vision_start, ed_audio_start)
|
||||
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
st_idx += text_len
|
||||
# Audio in Video
|
||||
if (
|
||||
min_ed == ed_vision_start
|
||||
and ed_vision_start + 1 == ed_audio_start
|
||||
):
|
||||
bos_len, eos_len = 2, 2
|
||||
else:
|
||||
bos_len, eos_len = 1, 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
st_idx += bos_len
|
||||
# Audio Only
|
||||
if min_ed == ed_audio_start:
|
||||
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
|
||||
audio_seqlens[audio_idx]
|
||||
)
|
||||
llm_pos_ids = (
|
||||
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
|
||||
st += int(text_len + bos_len + audio_len + eos_len)
|
||||
audio_idx += 1
|
||||
remain_audios -= 1
|
||||
|
||||
# Image Only
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and current_input_ids[ed_vision_start + 1] == image_token_id
|
||||
):
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t) * 1 * position_id_per_seconds
|
||||
).float()
|
||||
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
|
||||
st_idx,
|
||||
image_idx,
|
||||
spatial_merge_size,
|
||||
t_index,
|
||||
grid_hs,
|
||||
grid_ws,
|
||||
input_ids.device,
|
||||
)
|
||||
image_len = image_grid_thw[image_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
|
||||
st += int(text_len + bos_len + image_len + eos_len)
|
||||
image_idx += 1
|
||||
remain_images -= 1
|
||||
|
||||
# Video Only
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and current_input_ids[ed_vision_start + 1] == video_token_id
|
||||
):
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grids[video_idx].cpu().float()
|
||||
* position_id_per_seconds
|
||||
).float()
|
||||
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
|
||||
st_idx,
|
||||
video_idx,
|
||||
spatial_merge_size,
|
||||
t_index,
|
||||
grid_hs,
|
||||
grid_ws,
|
||||
input_ids.device,
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
|
||||
st += int(text_len + bos_len + video_len + eos_len)
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
|
||||
# Audio in Video
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and ed_vision_start + 1 == ed_audio_start
|
||||
):
|
||||
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
|
||||
audio_seqlens[audio_idx]
|
||||
)
|
||||
audio_llm_pos_ids = (
|
||||
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grids[video_idx].cpu().float()
|
||||
* position_id_per_seconds
|
||||
).float()
|
||||
video_llm_pos_ids = (
|
||||
MRotaryEmbedding._get_llm_pos_ids_for_vision(
|
||||
st_idx,
|
||||
video_idx,
|
||||
spatial_merge_size,
|
||||
t_index,
|
||||
grid_hs,
|
||||
grid_ws,
|
||||
input_ids.device,
|
||||
)
|
||||
)
|
||||
video_data_index, audio_data_index = 0, 0
|
||||
while (
|
||||
video_data_index < video_llm_pos_ids.shape[-1]
|
||||
and audio_data_index < audio_llm_pos_ids.shape[-1]
|
||||
):
|
||||
if (
|
||||
video_llm_pos_ids[0][video_data_index]
|
||||
<= audio_llm_pos_ids[0][audio_data_index]
|
||||
):
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_data_index + 1
|
||||
]
|
||||
)
|
||||
video_data_index += 1
|
||||
else:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_data_index + 1
|
||||
]
|
||||
)
|
||||
audio_data_index += 1
|
||||
if video_data_index < video_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
if audio_data_index < audio_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
|
||||
st += int(text_len + bos_len + audio_len + video_len + eos_len)
|
||||
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
remain_audios -= 1
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1
|
||||
if len(llm_pos_ids_list) > 0
|
||||
else 0
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1
|
||||
if len(llm_pos_ids_list) > 0
|
||||
else 0
|
||||
)
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(
|
||||
[item.float() for item in llm_pos_ids_list], dim=1
|
||||
).reshape(3, -1)
|
||||
|
||||
position_ids[..., i, :] = llm_positions.to(position_ids.device)
|
||||
mrope_position_deltas.append(
|
||||
llm_positions.max() + 1 - len(current_input_ids)
|
||||
)
|
||||
mrope_position_deltas = torch.tensor(
|
||||
mrope_position_deltas, device=input_ids.device
|
||||
).unsqueeze(1)
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
else:
|
||||
s = input_ids.shape[1]
|
||||
position_ids = torch.arange(s)
|
||||
position_ids = (
|
||||
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
|
||||
)
|
||||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
|
||||
-1, keepdim=True
|
||||
)[0]
|
||||
mrope_position_deltas = max_position_ids + 1 - s
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
|
||||
@staticmethod
|
||||
def get_rope_index_glm4v(
|
||||
@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
# For qwen3-omni
|
||||
@staticmethod
|
||||
def _get_feat_extract_output_lengths(input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||
"""
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
# For qwen3-omni
|
||||
@staticmethod
|
||||
def _get_llm_pos_ids_for_vision(
|
||||
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
|
||||
):
|
||||
grid_h = grid_hs[vision_idx] // spatial_merge_size
|
||||
grid_w = grid_ws[vision_idx] // spatial_merge_size
|
||||
|
||||
h_index = (
|
||||
torch.arange(grid_h, device=device)
|
||||
.view(1, -1, 1)
|
||||
.expand(len(t_index), -1, grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(grid_w, device=device)
|
||||
.view(1, 1, -1)
|
||||
.expand(len(t_index), grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
|
||||
|
||||
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
|
||||
return llm_pos_ids
|
||||
|
||||
|
||||
class DualChunkRotaryEmbedding(CustomOp):
|
||||
"""Rotary positional embedding for Dual Chunk Attention."""
|
||||
|
||||
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
||||
input_ids_tensor[input_ids_tensor == token_id] = pad_value
|
||||
|
||||
ret_input_ids = input_ids_tensor.tolist()
|
||||
|
||||
return ret_input_ids
|
||||
|
||||
|
||||
@@ -507,7 +506,7 @@ def embed_mm_inputs(
|
||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||
use_deepstack: bool = False,
|
||||
use_deepstack: Dict[Modality, bool] = {},
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Embed multimodal inputs and integrate them with text token embeddings.
|
||||
@@ -533,7 +532,9 @@ def embed_mm_inputs(
|
||||
for mm_inputs in mm_inputs_list:
|
||||
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
||||
|
||||
embeddings, masks, deepstack_embeddings = [], [], []
|
||||
# deepstack_embeddings: per-modality
|
||||
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
|
||||
|
||||
# 2. Get multimodal embedding separately
|
||||
# Try get mm embedding if any
|
||||
for modality in Modality.all():
|
||||
@@ -549,7 +550,8 @@ def embed_mm_inputs(
|
||||
# "image", "video", etc
|
||||
modality_id = modality.name.lower()
|
||||
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
|
||||
if len(items) != 0 and embedder is not None:
|
||||
if len(items) != 0:
|
||||
assert embedder is not None, f"no embedding method found for {modality}"
|
||||
placeholder_tensor = torch.as_tensor(
|
||||
[item.pad_value for item in items],
|
||||
device=input_ids.device,
|
||||
@@ -580,11 +582,12 @@ def embed_mm_inputs(
|
||||
items_offset_list=items_offsets,
|
||||
)
|
||||
|
||||
if use_deepstack and embedding is not None:
|
||||
if use_deepstack.get(modality, None) and embedding is not None:
|
||||
embedding, deepstack_embedding = (
|
||||
multimodal_model.separate_deepstack_embeds(embedding)
|
||||
)
|
||||
deepstack_embeddings += [deepstack_embedding]
|
||||
modalities += [modality]
|
||||
embeddings += [embedding]
|
||||
masks += [mask]
|
||||
|
||||
@@ -597,17 +600,14 @@ def embed_mm_inputs(
|
||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
|
||||
# 4. scatter embeddings into input embedding
|
||||
|
||||
# deepstack embedding
|
||||
if use_deepstack:
|
||||
num_deepstack_embeddings = (
|
||||
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
|
||||
)
|
||||
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
|
||||
|
||||
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
|
||||
inputs_embeds.shape[-1] * num_deepstack_embeddings,
|
||||
)
|
||||
|
||||
# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
|
||||
input_deepstack_embeds = torch.zeros(
|
||||
deepstack_embedding_shape,
|
||||
device=inputs_embeds.device,
|
||||
@@ -616,14 +616,16 @@ def embed_mm_inputs(
|
||||
|
||||
other_info["input_deepstack_embeds"] = input_deepstack_embeds
|
||||
|
||||
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
|
||||
# 4. scatter embeddings into input embedding
|
||||
for i, modality, embedding, mask in zip(
|
||||
range(len(embeddings)), modalities, embeddings, masks
|
||||
):
|
||||
if embedding is None or mask is None:
|
||||
continue
|
||||
# in-place update
|
||||
indices = torch.where(mask.squeeze(dim=-1))[0]
|
||||
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
|
||||
if use_deepstack:
|
||||
if use_deepstack.get(modality, None):
|
||||
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
|
||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||
use_deepstack: bool = False,
|
||||
use_deepstack: Dict[Modality, bool] = {},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
|
||||
language_model: Base language model to use
|
||||
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
||||
placeholder_tokens: Token IDs for multimodal placeholders
|
||||
use_deepstack: Whether to use deepstack embeddings
|
||||
use_deepstack: Whether to use deepstack embeddings for each modality, default False
|
||||
**kwargs: Additional arguments passed to language model
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
)
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
if not isinstance(obj.image_data, list) and obj.image_data:
|
||||
if obj.image_data is not None and not isinstance(obj.image_data, list):
|
||||
obj.image_data = [obj.image_data]
|
||||
if not isinstance(obj.audio_data, list) and obj.audio_data:
|
||||
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
|
||||
obj.audio_data = [obj.audio_data]
|
||||
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
|
||||
@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.pp_group = get_pp_group()
|
||||
|
||||
@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
config: Qwen3MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
decoder_layer_type=Qwen3MoeDecoderLayer,
|
||||
) -> None:
|
||||
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||
super().__init__(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=Qwen3MoeDecoderLayer,
|
||||
decoder_layer_type=decoder_layer_type,
|
||||
alt_stream=alt_stream,
|
||||
)
|
||||
|
||||
|
||||
661
python/sglang/srt/models/qwen3_omni_moe.py
Normal file
661
python/sglang/srt/models/qwen3_omni_moe.py
Normal file
@@ -0,0 +1,661 @@
|
||||
# Copyright 2025 Qwen Team
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
||||
from sglang.srt.configs.qwen3_omni import (
|
||||
Qwen3OmniMoeAudioEncoderConfig,
|
||||
Qwen3OmniMoeThinkerConfig,
|
||||
Qwen3OmniMoeVisionEncoderConfig,
|
||||
)
|
||||
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
|
||||
from sglang.srt.models.qwen3_vl_moe import (
|
||||
Qwen3MoeLLMModel,
|
||||
Qwen3VLMoeForConditionalGeneration,
|
||||
load_fused_expert_weights,
|
||||
)
|
||||
from sglang.srt.utils import add_prefix, logger
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeAudioEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
embed_dim = config.d_model
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = VisionAttention(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
projection_size=embed_dim,
|
||||
use_qkv_parallel=True,
|
||||
rotary_embed="normal",
|
||||
proj_bias=True,
|
||||
qkv_backend="fa3",
|
||||
softmax_in_single_precision=False,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
x=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(
|
||||
hidden_states, min=-clamp_value, max=clamp_value
|
||||
)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SinusoidsPositionEmbedding(nn.Module):
|
||||
def __init__(self, length, channels, max_timescale=10000):
|
||||
super().__init__()
|
||||
if channels % 2 != 0:
|
||||
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(
|
||||
-log_timescale_increment * torch.arange(channels // 2).float()
|
||||
)
|
||||
scaled_time = (
|
||||
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
)
|
||||
self.register_buffer(
|
||||
"positional_embedding",
|
||||
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, seqlen: int):
|
||||
return self.positional_embedding[:seqlen, :]
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||
"""
|
||||
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
|
||||
config: Qwen3OmniMoeAudioEncoderConfig
|
||||
|
||||
def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
|
||||
embed_dim = config.d_model
|
||||
self.num_mel_bins = config.num_mel_bins
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.n_window = config.n_window
|
||||
self.positional_embedding = SinusoidsPositionEmbedding(
|
||||
self.max_source_positions, embed_dim
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen3OmniMoeAudioEncoderLayer(config)
|
||||
for _ in range(config.encoder_layers)
|
||||
]
|
||||
)
|
||||
self.ln_post = nn.LayerNorm(config.d_model)
|
||||
self.gradient_checkpointing = False
|
||||
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
|
||||
self.conv2d2 = nn.Conv2d(
|
||||
config.downsample_hidden_size,
|
||||
config.downsample_hidden_size,
|
||||
3,
|
||||
2,
|
||||
padding=1,
|
||||
)
|
||||
self.conv2d3 = nn.Conv2d(
|
||||
config.downsample_hidden_size,
|
||||
config.downsample_hidden_size,
|
||||
3,
|
||||
2,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_out = nn.Linear(
|
||||
config.downsample_hidden_size
|
||||
* ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
|
||||
config.d_model,
|
||||
bias=False,
|
||||
)
|
||||
self.proj1 = nn.Linear(config.d_model, config.d_model)
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.proj2 = nn.Linear(config.d_model, config.output_dim)
|
||||
self.n_window_infer = self.config.n_window_infer
|
||||
self.conv_chunksize = self.config.conv_chunksize
|
||||
|
||||
def _freeze_parameters(self):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self._requires_grad = False
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.conv1
|
||||
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
self.conv1 = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features,
|
||||
feature_lens=None,
|
||||
aftercnn_lens=None,
|
||||
):
|
||||
r"""
|
||||
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
|
||||
mel length
|
||||
aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
|
||||
mel length after cnn
|
||||
"""
|
||||
aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
|
||||
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
|
||||
|
||||
chunk_lengths = torch.tensor(
|
||||
[self.n_window * 2] * chunk_num.sum(),
|
||||
dtype=torch.long,
|
||||
device=feature_lens.device,
|
||||
)
|
||||
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
|
||||
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
|
||||
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
|
||||
|
||||
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
|
||||
padded_feature = nn.utils.rnn.pad_sequence(
|
||||
chunk_list, batch_first=True
|
||||
).transpose(1, 2)
|
||||
feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
|
||||
padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
|
||||
[
|
||||
torch.ones(length, dtype=torch.bool, device=padded_feature.device)
|
||||
for length in feature_lens_after_cnn
|
||||
],
|
||||
batch_first=True,
|
||||
)
|
||||
padded_feature = padded_feature.unsqueeze(1)
|
||||
# Split to chunk to avoid OOM during convolution
|
||||
padded_embeds = []
|
||||
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
|
||||
padded_embed = F.gelu(self.conv2d1(chunk))
|
||||
padded_embed = F.gelu(self.conv2d2(padded_embed))
|
||||
padded_embed = F.gelu(self.conv2d3(padded_embed))
|
||||
padded_embeds.append(padded_embed)
|
||||
padded_embed = torch.cat(padded_embeds, dim=0)
|
||||
b, c, f, t = padded_embed.size()
|
||||
padded_embed = self.conv_out(
|
||||
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
|
||||
)
|
||||
|
||||
positional_embedding = (
|
||||
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
|
||||
.unsqueeze(0)
|
||||
.to(padded_embed.dtype)
|
||||
)
|
||||
padded_embed = padded_embed + positional_embedding
|
||||
hidden_states = padded_embed[padded_mask_after_cnn]
|
||||
cu_chunk_lens = [0]
|
||||
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
|
||||
self.n_window_infer // (self.n_window * 2)
|
||||
)
|
||||
for cnn_len in aftercnn_lens:
|
||||
cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
|
||||
remainder = cnn_len % window_aftercnn
|
||||
if remainder != 0:
|
||||
cu_chunk_lens += [remainder]
|
||||
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
|
||||
-1, dtype=torch.int32
|
||||
)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states = self.ln_post(hidden_states)
|
||||
hidden_states = self.proj1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.proj2(hidden_states)
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
# Ignore copy
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||
"""
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
return input_lengths, output_lengths
|
||||
|
||||
|
||||
class Qwen3OmniMoeVisionPatchMerger(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
context_dim: int,
|
||||
spatial_merge_size: int = 2,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_postshuffle_norm=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
self.use_postshuffle_norm = use_postshuffle_norm
|
||||
self.ln_q = RMSNorm(
|
||||
self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
|
||||
)
|
||||
self.mlp = nn.ModuleList(
|
||||
[
|
||||
ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.0", prefix),
|
||||
),
|
||||
nn.GELU(),
|
||||
RowParallelLinear(
|
||||
self.hidden_size,
|
||||
dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.2", prefix),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = (
|
||||
x.view(-1, self.hidden_size)
|
||||
if self.use_postshuffle_norm
|
||||
else x.view(-1, x.shape[-1])
|
||||
)
|
||||
hidden = self.ln_q(x).view(-1, self.hidden_size)
|
||||
for layer in self.mlp:
|
||||
if isinstance(hidden, tuple):
|
||||
hidden = hidden[0]
|
||||
hidden = layer(hidden)
|
||||
|
||||
if isinstance(hidden, tuple):
|
||||
hidden = hidden[0]
|
||||
|
||||
return hidden
|
||||
|
||||
|
||||
class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
|
||||
config: Qwen3OmniMoeVisionEncoderConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeVisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vision_config=config,
|
||||
quant_config=quant_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
)
|
||||
|
||||
self.merger = Qwen3OmniMoeVisionPatchMerger(
|
||||
dim=config.out_hidden_size,
|
||||
context_dim=config.hidden_size,
|
||||
spatial_merge_size=config.spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
use_postshuffle_norm=False,
|
||||
prefix=add_prefix("merger", prefix),
|
||||
)
|
||||
self.merger_list = nn.ModuleList(
|
||||
[
|
||||
Qwen3OmniMoeVisionPatchMerger(
|
||||
dim=config.out_hidden_size,
|
||||
context_dim=config.hidden_size,
|
||||
spatial_merge_size=config.spatial_merge_size,
|
||||
use_postshuffle_norm=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("merger_list", prefix),
|
||||
)
|
||||
for _ in range(len(config.deepstack_visual_indexes))
|
||||
]
|
||||
)
|
||||
del self.deepstack_merger_list
|
||||
|
||||
@property
|
||||
def deepstack_merger_list(self):
|
||||
return self.merger_list
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.patch_embed.proj.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.patch_embed.proj.weight.device
|
||||
|
||||
|
||||
class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
|
||||
config: Qwen3OmniMoeThinkerConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeThinkerConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(
|
||||
config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
|
||||
)
|
||||
self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
|
||||
self.visual = Qwen3OmniMoeVisionEncoder(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
self.pad_token_id = (
|
||||
self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def get_audio_feature(self, items: List[MultimodalDataItem]):
|
||||
feature_attention_mask = torch.cat(
|
||||
[item.feature_attention_mask for item in items], dim=0
|
||||
).type(torch.long)
|
||||
input_features = (
|
||||
torch.cat([item.feature for item in items])
|
||||
.type(self.audio_tower.dtype)
|
||||
.to(next(self.audio_tower.parameters()).device)
|
||||
)
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
input_features = input_features.permute(0, 2, 1)[
|
||||
feature_attention_mask.bool()
|
||||
].permute(1, 0)
|
||||
else:
|
||||
audio_feature_lengths = None
|
||||
|
||||
feature_lens = (
|
||||
audio_feature_lengths
|
||||
if audio_feature_lengths is not None
|
||||
else feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_outputs = self.audio_tower(
|
||||
input_features,
|
||||
feature_lens=feature_lens,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
|
||||
return audio_features
|
||||
|
||||
|
||||
class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3VLMoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
config.thinker_config, quant_config=quant_config, prefix=prefix
|
||||
)
|
||||
self.enable_talker = False
|
||||
self.pad_input_ids = self.thinker.pad_input_ids
|
||||
self.forward = self.thinker.forward
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
]
|
||||
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
ignore_suffixes = (
|
||||
".bias",
|
||||
"_bias",
|
||||
".k_scale",
|
||||
"_k_scale",
|
||||
".v_scale",
|
||||
"_v_scale",
|
||||
".weight_scale",
|
||||
"_weight_scale",
|
||||
".input_scale",
|
||||
"_input_scale",
|
||||
)
|
||||
|
||||
is_fused_expert = False
|
||||
fused_expert_params_mapping = [
|
||||
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
|
||||
("experts.w2_weight", "experts.down_proj", 0, "w2"),
|
||||
]
|
||||
|
||||
num_experts = self.config.num_experts
|
||||
|
||||
# Cache params_dict to avoid repeated expensive traversal of model parameters
|
||||
if not hasattr(self, "_cached_params_dict"):
|
||||
self._cached_params_dict = dict(self.named_parameters())
|
||||
params_dict = self._cached_params_dict
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
name = name.replace(r"model.language_model.", r"model.")
|
||||
|
||||
if ("talker" in name or "code2wav" in name) and not self.enable_talker:
|
||||
continue
|
||||
|
||||
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||
is_fused_expert = True
|
||||
expert_params_mapping = fused_expert_params_mapping
|
||||
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if "visual" in name:
|
||||
continue
|
||||
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
if name.endswith(ignore_suffixes) and name not in params_dict:
|
||||
continue
|
||||
# [TODO] Skip layers that are on other devices (check if sglang has a similar function)
|
||||
# if is_pp_missing_parameter(name, self):
|
||||
# continue
|
||||
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Track if this is an expert weight to enable early skipping
|
||||
is_expert_weight = False
|
||||
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if "visual" in name or "audio_tower" in name:
|
||||
continue
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
if is_fused_expert:
|
||||
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
|
||||
if "experts.gate_up_proj" in name:
|
||||
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight[0],
|
||||
"w1",
|
||||
num_experts,
|
||||
)
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight[1],
|
||||
"w3",
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight,
|
||||
shard_id,
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
if (
|
||||
name_mapped.endswith(ignore_suffixes)
|
||||
and name_mapped not in params_dict
|
||||
):
|
||||
continue
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# # other available replicas.
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# This is an expert weight but not mapped to this rank, skip all remaining processing
|
||||
continue
|
||||
if "visual" in name or "audio_tower" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||
name = name.replace(r"model.visual.", r"visual.")
|
||||
name = name.replace(r"attn.out_proj.", r"attn.proj.")
|
||||
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
if name.endswith(ignore_suffixes) and name not in params_dict:
|
||||
continue
|
||||
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Loaded weight with {name=} not found in params_dict"
|
||||
)
|
||||
|
||||
|
||||
EntryClass = Qwen3OmniMoeForConditionalGeneration
|
||||
@@ -15,7 +15,7 @@
|
||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||
import logging
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VisionRotaryEmbedding,
|
||||
)
|
||||
|
||||
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
|
||||
from sglang.srt.configs.qwen3_vl import (
|
||||
Qwen3VLConfig,
|
||||
Qwen3VLTextConfig,
|
||||
Qwen3VLVisionConfig,
|
||||
)
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
|
||||
from sglang.srt.models.qwen3 import Qwen3Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# === Vision Encoder === #
|
||||
|
||||
|
||||
@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Qwen3_VisionPatchMerger(nn.Module):
|
||||
class Qwen3VLMoeVisionPatchMerger(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class Qwen3_VisionTransformer(nn.Module):
|
||||
class Qwen3VLMoeVisionModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||
self.temporal_patch_size = vision_config.temporal_patch_size
|
||||
# layer indexes of which layer's output should be deep-stacked
|
||||
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
||||
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
|
||||
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
|
||||
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||
@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
for layer_idx in range(vision_config.depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen3_VisionPatchMerger(
|
||||
self.merger = Qwen3VLMoeVisionPatchMerger(
|
||||
dim=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
norm_layer=norm_layer,
|
||||
@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
|
||||
self.deepstack_merger_list = nn.ModuleList(
|
||||
[
|
||||
Qwen3_VisionPatchMerger(
|
||||
Qwen3VLMoeVisionPatchMerger(
|
||||
dim=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
||||
config: Qwen3VLConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
language_model_cls=Qwen3LLMModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
self.visual = Qwen3VLMoeVisionModel(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||
quant_config=quant_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
self.model = Qwen3LLMModel(
|
||||
config=config,
|
||||
# TODO: make it more elegant
|
||||
if language_model_cls is Qwen3LLMModel:
|
||||
self.config: Qwen3VLConfig = config # for qwen3-vl
|
||||
else:
|
||||
self.config = config.text_config # for qwen3-omni
|
||||
|
||||
self.model = language_model_cls(
|
||||
config=self.config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.logits_processor = LogitsProcessor(self.config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
|
||||
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
|
||||
@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
|
||||
# deepstack
|
||||
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
||||
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
||||
|
||||
@property
|
||||
def use_deepstack(self) -> bool:
|
||||
return hasattr(self, "deepstack_visual_indexes")
|
||||
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
|
||||
|
||||
def separate_deepstack_embeds(self, embedding):
|
||||
assert (
|
||||
|
||||
@@ -14,29 +14,19 @@
|
||||
# ==============================================================================
|
||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||
import logging
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from functools import lru_cache
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VisionRotaryEmbedding,
|
||||
)
|
||||
|
||||
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
|
||||
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.mm_utils import general_mm_embed_routine
|
||||
@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
|
||||
from sglang.srt.models.qwen3_vl import (
|
||||
Qwen3_VisionTransformer,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
)
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Qwen3VLMoeConfig,
|
||||
config: Qwen3VLMoeTextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.embed_tokens
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
# in qwen-vl, last dim is the same
|
||||
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
|
||||
self.visual.dtype
|
||||
)
|
||||
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
)
|
||||
|
||||
# process deepstack
|
||||
if input_deepstack_embeds is not None and layer_idx in range(3):
|
||||
if input_deepstack_embeds is not None and layer_idx < 3:
|
||||
sep = self.hidden_size * layer_idx
|
||||
hidden_states.add_(
|
||||
input_deepstack_embeds[:, sep : sep + self.hidden_size]
|
||||
@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
return hidden_states, aux_hidden_states
|
||||
|
||||
|
||||
def load_fused_expert_weights(
|
||||
name: str,
|
||||
params_dict: dict,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_id: str,
|
||||
num_experts: int,
|
||||
):
|
||||
param = params_dict[name]
|
||||
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
|
||||
weight_loader = param.weight_loader
|
||||
ep_rank = get_tensor_model_parallel_rank()
|
||||
ep_size = get_moe_expert_parallel_world_size()
|
||||
if ep_size == 1:
|
||||
for expert_id in range(num_experts):
|
||||
curr_expert_weight = loaded_weight[expert_id]
|
||||
weight_loader(
|
||||
param,
|
||||
curr_expert_weight,
|
||||
name,
|
||||
shard_id,
|
||||
expert_id,
|
||||
)
|
||||
else:
|
||||
experts_per_ep = num_experts // ep_size
|
||||
start_expert = ep_rank * experts_per_ep
|
||||
end_expert = (
|
||||
(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
|
||||
)
|
||||
|
||||
for idx, expert_id in enumerate(range(start_expert, end_expert)):
|
||||
curr_expert_weight = loaded_weight[expert_id]
|
||||
weight_loader(
|
||||
param,
|
||||
curr_expert_weight,
|
||||
name,
|
||||
shard_id,
|
||||
idx,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Qwen3VLMoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
language_model_cls=Qwen3MoeLLMModel,
|
||||
):
|
||||
super(Qwen3VLForConditionalGeneration, self).__init__()
|
||||
self.config = config
|
||||
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
self.model = Qwen3MoeLLMModel(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
# deepstack
|
||||
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
|
||||
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
|
||||
|
||||
@property
|
||||
def use_deepstack(self) -> bool:
|
||||
return hasattr(self, "deepstack_visual_indexes")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
"""Run forward pass for Qwen3-VL.
|
||||
|
||||
Args:
|
||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||
batch.
|
||||
positions: Flattened (concatenated) position ids corresponding to a
|
||||
batch.
|
||||
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||
opensource models), the shape will be `(3, seq_len)`,
|
||||
otherwise it will be `(seq_len,).
|
||||
(Use input_metadata.mrope_positions to replace it)
|
||||
"""
|
||||
if self.is_mrope_enabled:
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
if not (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or not forward_batch.contains_image_inputs()
|
||||
):
|
||||
if self.is_mrope_enabled:
|
||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||
"multimodal section rotary embedding requires "
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
hidden_states = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
language_model=self.model,
|
||||
multimodal_model=self,
|
||||
positions=positions,
|
||||
use_deepstack=self.use_deepstack,
|
||||
)
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_fused_expert_weights(
|
||||
self,
|
||||
name: str,
|
||||
params_dict: dict,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_id: str,
|
||||
num_experts: int,
|
||||
):
|
||||
param = params_dict[name]
|
||||
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
|
||||
weight_loader = param.weight_loader
|
||||
ep_rank = get_tensor_model_parallel_rank()
|
||||
ep_size = get_moe_expert_parallel_world_size()
|
||||
if ep_size == 1:
|
||||
for expert_id in range(num_experts):
|
||||
curr_expert_weight = loaded_weight[expert_id]
|
||||
weight_loader(
|
||||
param,
|
||||
curr_expert_weight,
|
||||
name,
|
||||
shard_id,
|
||||
expert_id,
|
||||
)
|
||||
else:
|
||||
experts_per_ep = num_experts // ep_size
|
||||
start_expert = ep_rank * experts_per_ep
|
||||
end_expert = (
|
||||
(ep_rank + 1) * experts_per_ep
|
||||
if ep_rank != ep_size - 1
|
||||
else num_experts
|
||||
)
|
||||
|
||||
for idx, expert_id in enumerate(range(start_expert, end_expert)):
|
||||
curr_expert_weight = loaded_weight[expert_id]
|
||||
weight_loader(
|
||||
param,
|
||||
curr_expert_weight,
|
||||
name,
|
||||
shard_id,
|
||||
idx,
|
||||
)
|
||||
return True
|
||||
super().__init__(config, quant_config, prefix, language_model_cls)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
self._cached_params_dict = dict(self.named_parameters())
|
||||
params_dict = self._cached_params_dict
|
||||
for name, loaded_weight in weights:
|
||||
if "language_model" in name:
|
||||
name = name.replace(r"model.language_model.", r"model.")
|
||||
name = name.replace(r"model.language_model.", r"model.")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||
@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
|
||||
if "experts.gate_up_proj" in name:
|
||||
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||
self.load_fused_expert_weights(
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight[0],
|
||||
"w1",
|
||||
num_experts,
|
||||
)
|
||||
self.load_fused_expert_weights(
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight[1],
|
||||
@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
self.load_fused_expert_weights(
|
||||
load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
loaded_weight,
|
||||
|
||||
@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.arch = hf_config.architectures[0]
|
||||
self.server_args = server_args
|
||||
self.transport_mode = transport_mode
|
||||
|
||||
@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
"input_features": Modality.AUDIO,
|
||||
"input_features_mask": Modality.AUDIO,
|
||||
"audio_attention_mask": Modality.AUDIO,
|
||||
"feature_attention_mask": Modality.AUDIO,
|
||||
# Video-related attributes
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"second_per_grid_ts": Modality.VIDEO,
|
||||
@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
if self._processor.__class__.__name__ in {
|
||||
"Gemma3nProcessor",
|
||||
"Qwen2AudioProcessor",
|
||||
"Qwen3OmniMoeProcessor",
|
||||
}:
|
||||
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
||||
kwargs["audio"] = audios
|
||||
|
||||
@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
|
||||
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
@@ -209,22 +210,31 @@ async def preprocess_video(
|
||||
return video
|
||||
|
||||
|
||||
# Compatible with Qwen2VL and Qwen2_5VL
|
||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
# Compatible with Qwen-VL & Qwen-Omni Series
|
||||
class QwenVLImageProcessor(SGLangBaseProcessor):
|
||||
models = [
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMoeForConditionalGeneration,
|
||||
Qwen3OmniMoeForConditionalGeneration,
|
||||
]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||
self.model_type = hf_config.model_type
|
||||
if hf_config.model_type == "qwen3_omni_moe":
|
||||
hf_config = hf_config.thinker_config
|
||||
|
||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||
# The regex that matches expanded image tokens.
|
||||
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.vision_start_token_id = hf_config.vision_start_token_id
|
||||
self.vision_end_token_id = hf_config.vision_end_token_id
|
||||
|
||||
self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
|
||||
self.audio_token_id = getattr(hf_config, "audio_token_id", None)
|
||||
|
||||
self.NUM_TOKEN_PER_FRAME = 770
|
||||
self.IMAGE_FACTOR = 28
|
||||
self.MIN_PIXELS = 4 * 28 * 28
|
||||
@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||
image_token_id=hf_config.image_token_id,
|
||||
# The regex that matches expanded image tokens.
|
||||
image_token_regex=re.compile(
|
||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||
),
|
||||
video_token_id=hf_config.video_token_id,
|
||||
audio_token_id=self.audio_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
video_data=request_obj.video_data,
|
||||
audio_data=request_obj.audio_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
audio_feature_lengths = None
|
||||
|
||||
if self.model_type == "qwen3_omni_moe":
|
||||
audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
|
||||
if audio_item:
|
||||
audio_feature_lengths = torch.sum(
|
||||
audio_item.feature_attention_mask, dim=1
|
||||
)
|
||||
|
||||
second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
|
||||
ret, "video_second_per_grid", None
|
||||
)
|
||||
|
||||
input_ids = input_ids.flatten()
|
||||
|
||||
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
||||
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
||||
image_token_id=self.mm_tokens.image_token_id,
|
||||
video_token_id=self.mm_tokens.video_token_id,
|
||||
vision_start_token_id=self.vision_start_token_id,
|
||||
model_type=self.hf_config.model_type,
|
||||
model_type=self.model_type,
|
||||
tokens_per_second=getattr(
|
||||
self.hf_config.vision_config, "tokens_per_second", None
|
||||
),
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
image_grid_thw=getattr(ret, "image_grid_thw", None),
|
||||
video_grid_thw=getattr(ret, "video_grid_thw", None),
|
||||
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
use_audio_in_video=False,
|
||||
audio_seqlens=audio_feature_lengths,
|
||||
audio_token_id=getattr(self.hf_config, "audio_token_id", None),
|
||||
audio_start_token_id=self.audio_start_token_id,
|
||||
position_id_per_seconds=getattr(
|
||||
self.hf_config, "position_id_per_seconds", None
|
||||
),
|
||||
)
|
||||
mrope_positions = mrope_positions.squeeze(1)
|
||||
|
||||
@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"video_token_id": self.mm_tokens.video_token_id,
|
||||
"audio_token_id": self.mm_tokens.audio_token_id,
|
||||
"mrope_positions": mrope_positions,
|
||||
"mrope_position_delta": mrope_position_delta,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user