# Source file: sglang/python/sglang/srt/models/qwen3_vl.py
# Snapshot: 2025-10-12 20:25:30 +08:00 (785 lines, 29 KiB, Python)

# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# === Vision Encoder === #
class Qwen3_VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block.

    Runs ``linear_fc1`` (``in_features -> hidden_features``), the activation
    looked up in ``ACT2FN`` under ``hidden_act``, then ``linear_fc2``
    (``hidden_features -> in_features``).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Options shared by both projections.
        shared_kwargs = dict(bias=bias, quant_config=quant_config)
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=add_prefix("linear_fc1", prefix),
            **shared_kwargs,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=add_prefix("linear_fc2", prefix),
            **shared_kwargs,
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor):
        """Return ``linear_fc2(act(linear_fc1(x)))``."""
        # Both linear layers return a pair; only the first element is used.
        expanded, _ = self.linear_fc1(x)
        activated = self.act(expanded)
        projected, _ = self.linear_fc2(activated)
        return projected
class Qwen3VLVisionPatchEmbed(nn.Module):
    """Project flattened patches to ``hidden_size`` with a strided 3D convolution.

    Reads ``patch_size``, ``temporal_patch_size``, ``in_channels`` and
    ``hidden_size`` from ``config``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        # Kernel and stride are identical, so patches never overlap.
        patch_shape = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return one ``embed_dim``-wide row per input patch."""
        # Restore the (channels, time, height, width) layout of every patch.
        patches = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # Run the convolution in the dtype of its own weights.
        patches = patches.to(dtype=self.proj.weight.dtype)
        embedded = self.proj(patches)
        return embedded.view(-1, self.embed_dim)
class Qwen3_VisionBlock(nn.Module):
    """Pre-norm transformer block of the vision encoder.

    Computes ``x += attn(norm1(x))`` followed by ``x += mlp(norm2(x))``.

    Args:
        dim: Width of the block (attention embed dim and MLP in/out width).
        num_heads: Number of attention heads.
        intermediate_dim: Hidden width of the MLP.
        hidden_act: Activation name forwarded to ``Qwen3_VisionMLP``.
        norm_layer: Factory called with ``dim`` to build both norms; defaults
            to ``nn.LayerNorm`` with ``eps=1e-6``.
        attn_implementation: One of ``"sdpa"``, ``"flash_attention_2"``,
            ``"eager"`` or ``"flash_attention_3"``.
        quant_config: Forwarded to the attention and MLP sub-modules.
        prefix: Dotted name prefix for the sub-modules.

    Raises:
        ValueError: If ``attn_implementation`` is not one of the supported
            names.
    """

    # attn_implementation -> (qkv_backend, softmax_in_single_precision)
    _ATTN_BACKENDS = {
        "sdpa": ("sdpa", False),
        "flash_attention_2": ("triton_attn", False),
        "eager": ("sdpa", True),
        "flash_attention_3": ("fa3", False),
    }

    def __init__(
        self,
        dim: int,
        num_heads: int,
        intermediate_dim: int,
        hidden_act="silu",
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fail with a clear message instead of an UnboundLocalError further down.
        if attn_implementation not in self._ATTN_BACKENDS:
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_BACKENDS)}"
            )
        qkv_backend, softmax_in_single_precision = self._ATTN_BACKENDS[
            attn_implementation
        ]

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            # Every supported implementation flattens the batch.
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            bias=True,
            quant_config=quant_config,
            # add_prefix keeps the name consistent with the other sub-modules
            # (no leading "." when ``prefix`` is empty).
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply the block to ``x`` and return it.

        ``x`` is updated in place. ``cu_seqlens`` and ``position_embeddings``
        are forwarded unchanged to the attention module; the vision
        transformer in this file passes a ``(cos, sin)`` pair for the latter.
        """
        hidden_states = self.norm1(x)
        # x is laid out sequence-first; the attention module is fed batch-first.
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x += attn
        x += self.mlp(self.norm2(x))
        return x
class Qwen3_VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` patch features into one vector.

    Features are viewed as rows of width ``context_dim * spatial_merge_size**2``
    and passed through ``linear_fc1 -> GELU -> linear_fc2`` to width ``dim``.
    The norm runs on the merged rows when ``use_postshuffle_norm`` is set,
    otherwise on the ``context_dim``-wide input before merging.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm

        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        # The norm width depends on whether it runs before or after merging.
        norm_width = self.hidden_size if use_postshuffle_norm else context_dim
        self.norm = make_norm(norm_width)

        self.linear_fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc1", prefix),
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            self.hidden_size,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("linear_fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the merged and projected features."""
        if not self.use_postshuffle_norm:
            x = self.norm(x)
        merged = x.view(-1, self.hidden_size)
        if self.use_postshuffle_norm:
            merged = self.norm(merged)

        hidden, _ = self.linear_fc1(merged)
        hidden = self.act_fn(hidden)
        projected, _ = self.linear_fc2(hidden)
        return projected
class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder.

    Pipeline: patch embedding, interpolated learned position embeddings, a
    stack of ``Qwen3_VisionBlock`` layers with rotary position embeddings, and
    a final patch merger. The outputs of the layers listed in
    ``vision_config.deepstack_visual_indexes`` are additionally passed through
    their own mergers and concatenated to the main output along the feature
    dimension.

    Args:
        vision_config: Vision tower configuration.
        norm_eps: Epsilon of every ``nn.LayerNorm`` built here.
        quant_config: Forwarded to the blocks and mergers.
        prefix: Dotted name prefix for the sub-modules.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    intermediate_dim=vision_config.intermediate_size,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="flash_attention_3",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                )
                for layer_idx in range(vision_config.depth)
            ]
        )
        # Main merger: norm is applied before the patches are merged.
        self.merger = Qwen3_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )
        # One merger per deepstack layer: norm is applied after merging.
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    dim=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding convolution weights."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding convolution weights."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw):
        """Build the rotary embedding rows for every patch of every item.

        For each ``(t, h, w)`` row of ``grid_thw`` the ``(row, column)`` ids of
        the ``h x w`` grid are emitted so that the patches of each
        ``spatial_merge_size x spatial_merge_size`` window are adjacent, and
        the result is repeated ``t`` times.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # The table only needs to cover the largest height/width in the batch.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position table to each item's grid.

        ``pos_embed`` is treated as a square table with
        ``int(num_position_embeddings**0.5)`` entries per side. For every
        ``(t, h, w)`` row of ``grid_thw`` the table is resampled to ``h x w``
        from its four neighbouring entries, repeated ``t`` times, and
        reordered so the patches of each spatial-merge window are adjacent.

        Returns:
            Tensor with one row per patch, for all items concatenated.
        """
        num_grid_per_side = int(self.num_position_embeddings**0.5)
        # Materialise the grid once as Python ints rather than converting a
        # 0-dim tensor for every scalar use below.
        grid_sizes = grid_thw.tolist()

        # Slot k holds the table indices / weights of the k-th bilinear corner.
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]
        # TODO: use torch instead of np
        for t, h, w in grid_sizes:
            h_idxs = np.linspace(0, num_grid_per_side - 1, h)
            w_idxs = np.linspace(0, num_grid_per_side - 1, w)
            h_idxs_floor = h_idxs.astype(int)
            w_idxs_floor = w_idxs.astype(int)
            h_idxs_ceil = (h_idxs_floor + 1).clip(max=num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs_floor + 1).clip(max=num_grid_per_side - 1)
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            # (row index, column index, row weight, column weight) per corner.
            corners = (
                (h_idxs_floor, w_idxs_floor, 1 - dh, 1 - dw),
                (h_idxs_floor, w_idxs_ceil, 1 - dh, dw),
                (h_idxs_ceil, w_idxs_floor, dh, 1 - dw),
                (h_idxs_ceil, w_idxs_ceil, dh, dw),
            )
            for slot, (rows, cols, row_w, col_w) in enumerate(corners):
                flat_idx = (rows[:, None] * num_grid_per_side + cols[None, :]).flatten()
                flat_weight = (row_w[:, None] * col_w[None, :]).flatten()
                # The same h x w plane is reused for each of the t frames.
                idx_list[slot].extend(flat_idx.tolist() * t)
                weight_list[slot].extend(flat_weight.tolist() * t)

        device = self.pos_embed.weight.device
        dtype = self.pos_embed.weight.dtype
        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=dtype, device=device)
        corner_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = (
            corner_embeds[0] + corner_embeds[1] + corner_embeds[2] + corner_embeds[3]
        )

        patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_sizes])
        m_size = self.spatial_merge_size
        patch_pos_embeds_permute = [
            pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
            .permute(0, 1, 3, 2, 4, 5)
            .flatten(0, 4)
            for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_sizes)
        ]
        return torch.cat(patch_pos_embeds_permute)

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches.

        Args:
            x: Flattened patches, one row per patch.
            grid_thw: One ``(t, h, w)`` row per item.

        Returns:
            The merged features concatenated along dim 1 with one merged
            feature block per captured deepstack layer.
        """
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        x += pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len, _ = x.size()
        rotary_pos_emb = rotary_pos_emb.to(x.device)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative sequence boundaries [0, n_1, n_1 + n_2, ...]; the leading
        # zero is prepended exactly once.
        # NOTE(review): each (t, h, w) item forms a single attention segment
        # spanning all of its t frames; confirm this matches the reference
        # implementation for multi-frame inputs.
        tokens_per_item = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]
        item_ends = tokens_per_item.cumsum(dim=0)
        cu_seqlens = torch.cat([item_ends.new_zeros(1), item_ends])

        # Blocks expect a sequence-first layout with a singleton batch dim.
        x = x.unsqueeze(1)

        # Merger i belongs to deepstack_visual_indexes[i]; look it up by layer
        # number so the pairing does not depend on the list being sorted.
        merger_for_layer = {
            layer: self.deepstack_merger_list[slot]
            for slot, layer in enumerate(self.deepstack_visual_indexes)
        }
        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
            deepstack_merger = merger_for_layer.get(layer_num)
            if deepstack_merger is not None:
                deepstack_feature_lists.append(deepstack_merger(x))

        x = self.merger(x)
        hidden_states = torch.cat(
            [x] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load vision-tower weights and return the names that were loaded.

        Checkpoint tensors named ``attn.q.`` / ``attn.k.`` / ``attn.v.`` are
        loaded as shards of the fused ``attn.qkv.`` parameter; every other
        tensor is loaded under its own name.

        Raises:
            KeyError: If a (possibly renamed) tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
# Memoised `get_processor`; 128 is the size `lru_cache` uses when it is
# applied to a function directly.
cached_get_processor = lru_cache(maxsize=128)(get_processor)
class Qwen3LLMModel(Qwen3Model):
    """``Qwen3Model`` that adds deepstack visual embeddings to early layers.

    The k-th ``hidden_size``-wide slice of ``input_deepstack_embeds`` is added
    to the hidden states produced by decoder layer ``k``, for ``k`` below the
    number of deepstack layers in the vision config.
    """

    def __init__(
        self,
        *,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
        # Deepstack is injected into layers [0, N); a later pipeline stage must
        # therefore start at or after layer N.
        if not self.pp_group.is_first_rank:
            assert self.start_layer >= len(
                config.vision_config.deepstack_visual_indexes
            ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"

        self.hidden_size = config.hidden_size
        num_deepstack_layers = len(config.vision_config.deepstack_visual_indexes)
        self.deepstack_embed_to_decoder_layer = range(num_deepstack_layers)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_deepstack_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this rank's slice of decoder layers.

        Returns ``PPProxyTensors`` on every pipeline rank except the last. On
        the last rank, returns the normalised hidden states, or a
        ``(hidden_states, aux_hidden_states)`` pair when any layer listed in
        ``layers_to_capture`` was captured.
        """
        if not self.pp_group.is_first_rank:
            # Later stages resume from the previous stage's tensors.
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]
        else:
            hidden_states = (
                self.embed_tokens(input_ids) if input_embeds is None else input_embeds
            )
            residual = None

        aux_hidden_states = []
        local_layers = self.layers[self.start_layer : self.end_layer]
        for layer_idx, layer in enumerate(local_layers, start=self.start_layer):
            if layer_idx in self.layers_to_capture:
                captured = (
                    hidden_states if residual is None else hidden_states + residual
                )
                aux_hidden_states.append(captured)

            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

            # process deepstack
            inject_deepstack = (
                input_deepstack_embeds is not None
                and layer_idx in self.deepstack_embed_to_decoder_layer
            )
            if inject_deepstack:
                offset = self.hidden_size * layer_idx
                hidden_states += input_deepstack_embeds[
                    :, offset : offset + self.hidden_size
                ]

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        # Skip the final norm for an empty batch.
        if hidden_states.shape[0] != 0:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)

        if not aux_hidden_states:
            return hidden_states
        return hidden_states, aux_hidden_states
class Qwen3VLForConditionalGeneration(nn.Module):
    """Qwen3-VL: vision encoder, Qwen3 language model and LM head.

    Args:
        config: Model configuration; ``config.vision_config`` configures the
            vision encoder.
        quant_config: Forwarded to the vision encoder, language model and LM
            head.
        prefix: Dotted name prefix for the sub-modules.
    """

    def __init__(
        self,
        config: Qwen3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )
        self.model = Qwen3LLMModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        # A config without rope_scaling (missing or None) means mrope is off,
        # rather than a TypeError on the membership test.
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        self.is_mrope_enabled = "mrope_section" in rope_scaling
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # deepstack: deepstack_visual_indexes lists the vision layers whose
        # outputs are captured (e.g. [8, 16, 24]); the k-th captured feature is
        # added to the hidden states after decoder layer k.
        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)

    @property
    def use_deepstack(self) -> bool:
        """Whether this model carries deepstack visual indexes."""
        return hasattr(self, "deepstack_visual_indexes")

    def separate_deepstack_embeds(self, embedding):
        """Split a vision-encoder output into input and deepstack embeddings.

        Returns:
            ``(input_embeds, input_deepstack_embeds)``: the first
            ``config.hidden_size`` columns and the remaining columns.
        """
        assert (
            embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
        ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"
        separate_index = self.config.hidden_size
        input_embeds = embedding[:, :separate_index]
        input_deepstack_embeds = embedding[:, separate_index:]
        return input_embeds, input_deepstack_embeds

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` using the multimodal-token padding pattern."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def _encode_visual_items(
        self, items: List[MultimodalDataItem], grid_attr: str
    ) -> torch.Tensor:
        """Encode ``items`` with the vision tower.

        ``grid_attr`` names the item attribute holding the ``(t, h, w)`` grid
        (``"image_grid_thw"`` or ``"video_grid_thw"``).
        """
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        grid_thw = torch.cat([getattr(item, grid_attr) for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert grid_thw.dim() == 2, grid_thw.dim()
        return self.visual(pixel_values, grid_thw=grid_thw)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Return vision-encoder embeddings for image items."""
        return self._encode_visual_items(items, "image_grid_thw")

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Return vision-encoder embeddings for video items."""
        return self._encode_visual_items(items, "video_grid_thw")

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen3-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. When mrope is enabled this argument is replaced by
                ``forward_batch.mrope_positions``, which must have shape
                `(3, seq_len)` for non-decode batches that contain image
                inputs.
            forward_batch: Batch metadata for this step.
            get_embedding: Return pooled embeddings instead of logits.

        Returns:
            The logits-processor output, or the pooler output when
            ``get_embedding`` is true.
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        needs_mrope_check = (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        )
        if needs_mrope_check and self.is_mrope_enabled:
            assert positions.ndim == 2 and positions.size(0) == 3, (
                "multimodal section rotary embedding requires "
                f"(3, seq_len) positions, but got {positions.size()}"
            )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
            use_deepstack=self.use_deepstack,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the vision and language sub-modules.

        Language-model ``q_proj`` / ``k_proj`` / ``v_proj`` and ``gate_proj`` /
        ``up_proj`` tensors are loaded as shards of the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters. Vision tensors are renamed to the names
        used by this module and loaded directly.

        Raises:
            KeyError: If a tensor has no matching parameter (the missing name
                and the known parameter names are logged first).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "language_model" in name:
                name = name.replace(r"model.language_model.", r"model.")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Vision weights are never stacked here.
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                    name = name.replace(r"model.visual.", r"visual.")
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    logger.error(
                        "Parameter %r not found in model; known parameters: %s",
                        name,
                        sorted(params_dict),
                    )
                    raise
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
# Module-level alias of the top-level model class.
# NOTE(review): presumably the name the model registry looks up to discover
# this architecture — confirm against the loader.
EntryClass = Qwen3VLForConditionalGeneration