[MM][Model] Remove Qwen2-VL modeling files (#4534)

### What this PR does / why we need it?

Following https://github.com/vllm-project/vllm-ascend/pull/4349, remove
Qwen2-VL modeling files.


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-11-29 18:07:01 +08:00
committed by GitHub
parent 6664a4e5ce
commit 2a19215e5f
4 changed files with 212 additions and 590 deletions

View File

@@ -24,17 +24,27 @@ import torch.nn.functional as F
import torch_npu
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \
Qwen2_5_VLVisionConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import \
Qwen2VLVisionConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch, dispatch_rotary_emb_function)
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImageInputs,
Qwen2_5_VLVideoInputs)
from vllm.model_executor.models.qwen2_vl import (Qwen2VisionAttention,
Qwen2VisionBlock,
Qwen2VisionPatchEmbed,
Qwen2VisionPatchMerger,
Qwen2VisionTransformer)
from vllm.model_executor.models.utils import cast_overflow_tensors
from vllm.model_executor.models.vision import (
get_vit_attn_backend, run_dp_sharded_mrope_vision_model)
@@ -74,18 +84,8 @@ class AscendQwen2_5_VisionAttention(nn.Module):
# Convert cumulative tensor to intervals and move it to cpu.
cu_seqlens = torch.diff(cu_seqlens).to("cpu")
cos = rotary_pos_emb_cos
sin = rotary_pos_emb_sin
cos = einops.rearrange(
torch.stack((cos, cos), dim=-1),
"... d two -> ...(d two)",
two=2,
)
sin = einops.rearrange(
torch.stack((sin, sin), dim=-1),
"... d two -> ...(d two)",
two=2,
)
cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head)
sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head)
q = torch_npu.npu_rotary_mul(q, cos, sin)
@@ -132,6 +132,191 @@ class AscendQwen2_5_VisionAttention(nn.Module):
return output
class AscendQwen2VisionBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
return x
class AscendQwen2VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
nn.Module.__init__(self)
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
spatial_merge_size = vision_config.spatial_merge_size
in_channels = vision_config.in_channels
hidden_size = vision_config.hidden_size
embed_dim = vision_config.embed_dim
depth = vision_config.depth
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
self.embed_dim = embed_dim
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
embed_dim=embed_dim,
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList([
Qwen2VisionBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
) for layer_idx in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size,
context_dim=embed_dim,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
def rot_pos_emb(
self,
grid_thw: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
pos_ids = []
max_grid_size = 0
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = (hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten())
wpos_ids = (wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten())
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
max_grid_size = max(max_grid_size, h, w)
pos_ids = torch.cat(pos_ids, dim=0)
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
# (num_tokens, rotary_dim // 2)
cos_h = cos[pos_ids[:, 0]] # type: ignore
cos_w = cos[pos_ids[:, 1]] # type: ignore
sin_h = sin[pos_ids[:, 0]] # type: ignore
sin_w = sin[pos_ids[:, 1]] # type: ignore
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
else:
grid_thw_list = grid_thw.tolist()
# compute position embedding
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(
grid_thw_list)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# transformers
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter
x = self.merger(x)
return x
class AscendQwen2_5_VisionBlock(nn.Module):
def forward(
@@ -486,7 +671,16 @@ class AscendQwen2_5_VLForConditionalGeneration(nn.Module):
return video_embeds.split(sizes)
def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(
default=partial(apply_rotary_emb_torch, is_neox_style=True))
output = rotary_emb_function(t, cos, sin).type_as(t)
return output
# NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm.
Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
# NOTE: These will be removed after https://github.com/vllm-project/vllm/pull/29388 is merged.
@@ -494,8 +688,13 @@ Qwen2_5_VLForConditionalGeneration._process_image_input = AscendQwen2_5_VLForCon
Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input
# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main.
Qwen2VisionBlock.forward = AscendQwen2VisionBlock.forward
Qwen2VisionTransformer.__init__ = AscendQwen2VisionTransformer.__init__
Qwen2VisionTransformer.rot_pos_emb = AscendQwen2VisionTransformer.rot_pos_emb
Qwen2VisionTransformer.forward = AscendQwen2VisionTransformer.forward
Qwen2_5_VisionBlock.forward = AscendQwen2_5_VisionBlock.forward
Qwen2_5_VisionTransformer.__init__ = AscendQwen2_5_VisionTransformer.__init__
Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw
Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw
Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward
apply_rotary_pos_emb_vision = _apply_rotary_pos_emb_vision