Files
xc-llm-ascend/vllm_ascend/models/qwen2_5_vl_without_padding.py
Li Wang ad366bf908 [Bugfix] Follow vLLM Qwen-Moe/VL and KV Connector change to fix broken CI (#2181)
### What this PR does / why we need it?
This PR fixes the broken CI:
1. Adapt to the
ee2eb6ecd8
changes. In that commit, vLLM fused the gate and up projections in the
vision MLP, which improves performance by saving one matrix
multiplication. To follow it, this PR does the following:
    - Specifies that the two linear layers are fused as `mlp.gate_up_proj`
when loading the weights.
    - Uses a SiluAndMul activation function.
2. Adapt to
aefeea0fde:
update the ModelRunnerOutput parameters to match its changes.
3. Adapt to
[vllm-commit](https://github.com/vllm-project/vllm/pull/20815/files#diff-3ffb829a39ab2b3e4706aa28f5e476815f36c3a87b98d6a66514ebedc8f3ffb4R354-R356):
fix Qwen MoE.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
fed5849d3f

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2025-08-04 21:37:50 +08:00

378 lines
15 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
get_act_and_mul_fn)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
from vllm_ascend.utils import vllm_version_is
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
    """Vision attention that runs the Ascend unpadded flash-attention kernel.

    The QKV projection, QKV split and output projection are taken from the
    parent class (``self.qkv``, ``self.split_qkv``, ``self.proj``); only the
    rotary application and the attention kernel call are done here.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the parent attention layer and record per-head sizing.

        Args:
            embed_dim: Forwarded to the parent and stored on the instance.
            num_heads: Number of attention heads.
            projection_size: Total projection width; divided by ``num_heads``
                to obtain the per-head dimension.
            quant_config: Optional quantization config, forwarded to the
                parent.
            prefix: Module-name prefix, forwarded to the parent.
        """
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        # Per-head width; its inverse square root is the attention scale.
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embedding and unpadded flash attention.

        Args:
            x: Input activations, laid out as ``[s, b, c]``.
            cu_seqlens: Handed to the kernel as its ``seq_len`` argument.
                NOTE(review): despite the name this presumably holds
                per-sequence lengths rather than cumulative offsets -- confirm
                against the caller.
            cos: Cosine table passed to ``npu_rotary_mul``.
            sin: Sine table passed to ``npu_rotary_mul``.

        Returns:
            The output of ``self.proj`` applied to the attention result.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        qkv, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(qkv)
        batch_size = q.shape[1]

        # Sequence-major -> batch-major for the rotary kernel.
        q = rearrange(q, "s b ... -> b s ...").contiguous()
        k = rearrange(k, "s b ... -> b s ...").contiguous()
        v = rearrange(v, "s b ... -> b s ...").contiguous()

        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        # Fold batch and sequence into one token axis for the unpadded kernel.
        q = rearrange(q, "b s h d -> (b s) h d").contiguous()
        k = rearrange(k, "b s h d -> (b s) h d").contiguous()
        v = rearrange(v, "b s h d -> (b s) h d").contiguous()

        # The kernel writes its result into a preallocated buffer.
        attn_out = torch.empty_like(q)

        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=attn_out)

        # Back to sequence-major with the heads merged into the channel axis.
        attn_out = rearrange(attn_out,
                             "(b s) h d -> s b (h d)",
                             b=batch_size).contiguous()

        projected, _ = self.proj(attn_out)
        return projected
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
    """Vision transformer block whose attention is the unpadded Ascend variant.

    ``self.norm1``, ``self.norm2`` and ``self.mlp`` are not assigned here;
    they are expected to come from the parent constructor.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Construct the parent block, then swap in the Ascend attention.

        Args:
            dim: Block width; also used as the attention ``embed_dim`` and
                ``projection_size``.
            num_heads: Number of attention heads.
            mlp_hidden_dim: Forwarded to the parent.
            act_fn: Activation forwarded to the parent.
            norm_layer: Optional norm factory forwarded to the parent.
            quant_config: Optional quantization config.
            prefix: Module-name prefix; the attention uses ``"{prefix}.attn"``.
        """
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)

        # Replace whatever attention the parent created.
        attn_prefix = f"{prefix}.attn"
        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=attn_prefix)

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP, each with a pre-norm and a residual add.

        Args:
            x: Block input.
            cu_seqlens: Forwarded unchanged to the attention layer.
            cos: Rotary cosine table forwarded to the attention layer.
            sin: Rotary sine table forwarded to the attention layer.

        Returns:
            The block output.
        """
        attn_out = self.attn(self.norm1(x),
                             cu_seqlens=cu_seqlens,
                             cos=cos,
                             sin=sin)
        x = x + attn_out

        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed):
    """Patch embedding computed as a single matrix multiplication."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` by the flattened, transposed projection weight.

        Args:
            x: Input whose last dimension must match the flattened size of
                one row of ``self.proj.weight`` (presumably flattened
                patches -- confirm against the caller).

        Returns:
            ``x`` multiplied by the ``(-1, hidden_size)`` weight matrix.
        """
        # Collapse every non-output dimension of the weight into one axis.
        flat_weight = self.proj.weight.data.view(self.hidden_size, -1)
        return torch.matmul(x, flat_weight.transpose(0, 1))
class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
                                                      ):
    """Qwen2.5-VL vision transformer wired to the unpadded Ascend modules.

    After the parent constructor runs, the rotary embedding, the patch
    embedding and the transformer blocks are replaced by the Ascend variants.
    Attributes such as ``hidden_size``, ``num_heads``, ``spatial_merge_size``,
    ``spatial_merge_unit``, ``window_size``, ``patch_size``,
    ``fullatt_block_indexes`` and ``merger`` are read but never assigned in
    this class; presumably the parent constructor provides them.
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        """Build the parent model and substitute the Ascend submodules.

        Args:
            vision_config: Vision tower configuration.
            norm_eps: Epsilon for the RMSNorm layers of each block.
            quant_config: Optional quantization config passed to the blocks.
            prefix: Module-name prefix; blocks use ``"{prefix}.blocks.{i}"``.
            interleaved: Selects how ``cal_cos_sin`` duplicates the rotary
                tables (interleaved pairs vs. two concatenated halves).
        """
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved

        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
                                                                  2)
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        # Choose the activation once. vLLM 0.10.0 takes the plain activation
        # from the registry; other versions use the fused act-and-mul variant.
        # Previously the fused variant was always built and then discarded on
        # 0.10.0.
        if vllm_version_is("0.10.0"):
            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
        else:
            act_fn = get_act_and_mul_fn(vision_config.hidden_act)

        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock_Without_Padding(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=act_fn,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

    def cal_cos_sin(self, rotary_pos_emb):
        """Turn rotary angles into cos/sin tables shaped for the rotary kernel.

        Args:
            rotary_pos_emb: Rotary angles, ``[seqlen, rotary_dim / 2]``.

        Returns:
            ``(cos, sin)``, each reshaped to
            ``(1, -1, 1, hidden_size_per_attention_head)``.
        """

        def _duplicate(table: torch.Tensor) -> torch.Tensor:
            # Double the last dimension, either as two concatenated halves or
            # as interleaved pairs.
            if not self.interleaved:
                return torch.cat((table, table), dim=-1)
            return rearrange(torch.stack((table, table), dim=-1),
                             "... d two -> ...(d two)",
                             two=2)

        head_dim = self.hidden_size_per_attention_head
        cos_new = _duplicate(rotary_pos_emb.cos()).reshape(1, -1, 1, head_dim)
        sin_new = _duplicate(rotary_pos_emb.sin()).reshape(1, -1, 1, head_dim)
        return cos_new, sin_new

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute rotary angles for every patch described by ``grid_thw``.

        Args:
            grid_thw: One ``(t, h, w)`` row per image or video item.

        Returns:
            Rotary angles indexed by the (row, column) position of each patch,
            with the two position components flattened together.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)

            # Reorder so patches inside one spatial-merge cell are contiguous.
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()

            # The same spatial positions repeat for each of the t frames.
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)

        # Build one table large enough for the biggest height or width.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw) -> tuple[torch.Tensor, list]:
        """Compute the window-major token order and window boundaries.

        Args:
            grid_thw: One ``(t, h, w)`` row per image or video item.

        Returns:
            A tuple ``(window_index, cu_window_seqlens)``. ``window_index`` is
            a 1-D tensor of merged-token indices grouped window by window.
            ``cu_window_seqlens`` is a Python list of cumulative window
            lengths, starting at 0 and scaled by ``spatial_merge_unit``. It
            may contain repeated values, because fully padded windows
            contribute a length of zero.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (self.window_size //
                                  self.spatial_merge_size // self.patch_size)

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size
            num_tokens = grid_t * llm_grid_h * llm_grid_w
            index = torch.arange(num_tokens).reshape(grid_t, llm_grid_h,
                                                     llm_grid_w)

            # When a side is already a multiple of the window size this pads
            # by one full window, which yields empty windows.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

            # -100 marks padding so it can be filtered out again below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
            index_padded = index_padded.reshape(grid_t, num_windows_h,
                                                vit_merger_window_size,
                                                num_windows_w,
                                                vit_merger_window_size)
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
                vit_merger_window_size)

            # Number of real (non-padding) tokens in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]

            # Offset by the tokens consumed by earlier items.
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(
                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += num_tokens.item()

        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode patches with alternating windowed and full attention.

        Args:
            x: Input handed to ``self.patch_embed``.
            grid_thw: One ``(t, h, w)`` row per image or video item.

        Returns:
            The output of ``self.merger``, restored to the original
            (pre-window) token order.
        """
        # Per-frame sequence lengths for full attention: one ``h * w`` entry
        # per temporal frame. Despite the name these are lengths, not
        # cumulative offsets.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:,
                                                      0]).cpu().to(torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windows attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        # Drop repeated boundaries left by empty windows, then convert the
        # cumulative offsets into per-window lengths.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)

        # Reorder tokens window by window, moving whole merge units together.
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)

        # Apply the same permutation to the rotary angles.
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            # Full-attention layers see whole frames; the rest see windows.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)

        # Undo the window permutation.
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x

    # NOTE(review): the two helpers below read ``self.visual``, which this
    # class never assigns. They look like they were meant for
    # ``AscendQwen2_5_VLForConditionalGeneration_Without_Padding``; confirm
    # their placement before relying on them.
    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
        """Return per-image embeddings for an image input mapping.

        Args:
            image_input: Mapping with ``"type"``, ``"image_grid_thw"`` and
                either ``"image_embeds"`` or ``"pixel_values"``.

        Returns:
            One embedding tensor per row of ``image_grid_thw``.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return image_embeds.split(sizes.tolist())

    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
        """Return per-video embeddings for a video input mapping.

        Args:
            video_input: Mapping with ``"type"``, ``"video_grid_thw"`` and
                either ``"video_embeds"`` or ``"pixel_values_videos"``.

        Returns:
            One embedding tensor per row of ``video_grid_thw``.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return video_embeds.split(sizes.tolist())
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
        Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL generation model using the unpadded Ascend vision tower."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the parent model, then replace ``self.visual``.

        Args:
            vllm_config: Full vLLM configuration; the HF model config and the
                quantization config are read from it.
            prefix: Module-name prefix; the vision tower is registered under
                ``maybe_prefix(prefix, "visual")``.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        hf_config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        # The parent's helper decides which quant config the vision tower gets.
        vision_quant_config = self._maybe_ignore_quant_config(
            vllm_config.quant_config)

        self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
            vision_config=hf_config.vision_config,
            norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
            quant_config=vision_quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )