[MM][Model] Remove Qwen3-VL modeling files (#4577)
### What this PR does / why we need it? Following https://github.com/vllm-project/vllm-ascend/pull/4349, remove Qwen3-VL modeling files. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
This commit is contained in:
@@ -2,14 +2,6 @@ from vllm import ModelRegistry
|
||||
|
||||
|
||||
def register_model():
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration")
|
||||
|
||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||
ModelRegistry.register_model(
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
|
||||
Qwen3VLConfig
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
|
||||
Qwen3VLMoeConfig
|
||||
except ImportError:
|
||||
pass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
|
||||
|
||||
try:
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
|
||||
Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||
from vllm.model_executor.models.qwen3_vl_moe import (
|
||||
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
|
||||
except ImportError:
|
||||
Qwen3_VisionBlock = object
|
||||
Qwen3_VisionPatchEmbed = object
|
||||
Qwen3_VisionTransformer = object
|
||||
Qwen3VLDummyInputsBuilder = object
|
||||
Qwen3VLForConditionalGeneration = object
|
||||
Qwen3VLMultiModalProcessor = object
|
||||
Qwen3VLProcessingInfo = object
|
||||
Qwen3VLMoeForConditionalGeneration = object
|
||||
Qwen3VLMoeProcessingInfo = object
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
|
||||
class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
x = x + self.proj.bias
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix, use_data_parallel)
|
||||
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix,
|
||||
use_data_parallel)
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
self.patch_embed = AscendQwen3_VisionPatchEmbed(
|
||||
patch_size=self.patch_size,
|
||||
temporal_patch_size=self.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
grid_thw_tensor = torch.tensor(grid_thw,
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
||||
grid_thw_tensor[:, 0]).cpu().to(torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(
|
||||
layer_num)
|
||||
deepstack_feature = self.deepstack_merger_list[
|
||||
deepstack_merger_idx](hidden_states)
|
||||
deepstack_feature_lists.append(deepstack_feature)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states] + deepstack_feature_lists,
|
||||
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLMoeProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLMoeForConditionalGeneration(
|
||||
Qwen3VLMoeForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
@@ -29,4 +29,5 @@ import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
|
||||
import vllm_ascend.patch.worker.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
|
||||
import vllm_ascend.patch.worker.patch_rope # noqa
|
||||
|
||||
@@ -65,7 +65,7 @@ class AscendQwen2_5_VisionAttention(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
seqlens: torch.Tensor,
|
||||
seqlens: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
251
vllm_ascend/patch/worker/patch_qwen3_vl.py
Normal file
251
vllm_ascend/patch/worker/patch_qwen3_vl.py
Normal file
@@ -0,0 +1,251 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
|
||||
Qwen3VLVisionConfig
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock,
|
||||
Qwen3_VisionPatchEmbed,
|
||||
Qwen3_VisionPatchMerger,
|
||||
Qwen3_VisionTransformer)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
|
||||
|
||||
class AscendQwen3_VisionBlock(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen3VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
self.num_heads = vision_config.num_heads
|
||||
self.num_position_embeddings = vision_config.num_position_embeddings
|
||||
self.patch_size = vision_config.patch_size
|
||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||
self.spatial_merge_unit = self.spatial_merge_size**2
|
||||
self.temporal_patch_size = vision_config.temporal_patch_size
|
||||
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
|
||||
|
||||
# NOTE: This is used for creating empty tensor for all_gather for
|
||||
# DP ViT. Here out_hidden_size is enlarged due to deepstack
|
||||
self.out_hidden_size = vision_config.out_hidden_size * (
|
||||
1 + len(self.deepstack_visual_indexes))
|
||||
|
||||
self.patch_embed = Qwen3_VisionPatchEmbed(
|
||||
patch_size=self.patch_size,
|
||||
temporal_patch_size=self.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
self.pos_embed = nn.Embedding(self.num_position_embeddings,
|
||||
self.hidden_size)
|
||||
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim // 2,
|
||||
max_position=8192,
|
||||
base=10000.0,
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
self.merger = Qwen3_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
norm_layer=norm_layer,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.deepstack_merger_list = nn.ModuleList([
|
||||
Qwen3_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=self.hidden_size,
|
||||
spatial_merge_size=self.spatial_merge_size,
|
||||
use_postshuffle_norm=True,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
) for layer_idx in range(len(self.deepstack_visual_indexes))
|
||||
])
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
use_upstream_fa = False
|
||||
if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now.")
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
) for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
|
||||
def rot_pos_emb(self, grid_thw: list[list[int]]):
|
||||
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
|
||||
pos_ids = [
|
||||
self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else
|
||||
self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
|
||||
for t, h, w in grid_thw
|
||||
]
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
# (num_tokens, rotary_dim // 2)
|
||||
cos_h = cos[pos_ids[:, 0]] # type: ignore
|
||||
cos_w = cos[pos_ids[:, 1]] # type: ignore
|
||||
sin_h = sin[pos_ids[:, 0]] # type: ignore
|
||||
sin_w = sin[pos_ids[:, 1]] # type: ignore
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
|
||||
return cos_combined, sin_combined
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device,
|
||||
dtype=self.dtype,
|
||||
non_blocking=True)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||
else:
|
||||
grid_thw = grid_thw.to("cpu")
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_thw = grid_thw.numpy()
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(
|
||||
grid_thw_list)
|
||||
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device,
|
||||
non_blocking=True)
|
||||
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device,
|
||||
non_blocking=True)
|
||||
|
||||
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(
|
||||
layer_num)
|
||||
deepstack_feature = self.deepstack_merger_list[
|
||||
deepstack_merger_idx](hidden_states)
|
||||
deepstack_feature_lists.append(deepstack_feature)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states] + deepstack_feature_lists,
|
||||
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
|
||||
return hidden_states
|
||||
|
||||
|
||||
# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main.
|
||||
Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward
|
||||
Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__
|
||||
Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb
|
||||
Qwen3_VisionTransformer.forward = AscendQwen3_VisionTransformer.forward
|
||||
Reference in New Issue
Block a user