### What this PR does / why we need it?
Following https://github.com/vllm-project/vllm/pull/30125, this PR registers `AscendMMEncoderAttention` as a CustomOp and removes the related patch.
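
For context, vLLM exposes a `CustomOp` registry that out-of-tree platforms such as Ascend can hook into. The sketch below shows roughly what such a registration looks like; the op name, signatures, and method bodies are illustrative assumptions, not the actual code added by this PR:

```python
# Illustrative sketch only -- not the code added by this PR.
# Assumes vLLM's CustomOp registry and the forward_oot() hook used by
# out-of-tree platforms; the op name and signatures below are hypothetical.
import torch
from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("ascend_mm_encoder_attention")  # hypothetical op name
class AscendMMEncoderAttention(CustomOp):
    """Multimodal encoder attention dispatched to an NPU-specific path."""

    def forward_native(self, q: torch.Tensor, k: torch.Tensor,
                       v: torch.Tensor) -> torch.Tensor:
        # Reference PyTorch path, used when no platform override applies.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    def forward_oot(self, q: torch.Tensor, k: torch.Tensor,
                    v: torch.Tensor) -> torch.Tensor:
        # Out-of-tree (Ascend) path; the real op would call fused
        # torch_npu attention kernels here.
        return self.forward_native(q, k, v)
```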
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
✅ Run Qwen2.5-VL:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-b4e3053f30ab2442","object":"chat.completion","created":1764922950,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the image is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" being slightly larger than \"Qwen.\" The design includes a geometric, abstract shape on the left side of the logo, which complements the text.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":162,"completion_tokens":84,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
✅ Run Qwen3-VL:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-97571fbda8267bd1","object":"chat.completion","created":1764923306,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
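Both outputs above were returned by the OpenAI-compatible `/v1/chat/completions` endpoint exposed by `vllm serve`. A hypothetical reproduction request is sketched below; the image URL and prompt are placeholders, since the actual test inputs are not included in this PR:

```python
# Hypothetical reproduction script (not taken from the PR): queries the
# OpenAI-compatible server started by `vllm serve` above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            # Placeholder image URL; the actual test image is not shown here.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/logo.png"}},
            {"type": "text", "text": "What does the text in the image say?"},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```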
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import einops
import torch
import torch.nn as nn
import torch_npu
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionAttention,
                                                   Qwen2_5_VLImageInputs,
                                                   Qwen2_5_VLVideoInputs)
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model

from vllm_ascend.ascend_forward_context import set_ascend_forward_context

MIN_PAD_SIZE = 64  # min_size to pad weight
MAX_PAD_SIZE = 128  # max_size to pad weight


class AscendQwen2_5_VisionAttention(nn.Module):

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)
        seq_len, batch_size, _ = x.shape

        qkv = einops.rearrange(
            x,
            "s b (three head head_dim) -> b s three head head_dim",
            three=3,
            head=self.num_attention_heads_per_partition,
        )
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
        sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
        cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head)
        sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head)
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        output, _ = self.proj(context_layer)
        return output


class AscendQwen2_5_VLForConditionalGeneration(nn.Module):

    def _process_image_input(
            self,
            image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:

        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"]
            with set_ascend_forward_context(None, self.vllm_config):
                if self.use_data_parallel:
                    return run_dp_sharded_mrope_vision_model(
                        self.visual,
                        pixel_values,
                        grid_thw_list,
                        rope_type="rope_3d")
                else:
                    image_embeds = self.visual(pixel_values,
                                               grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
            self,
            video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:

        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            with set_ascend_forward_context(None, self.vllm_config):
                if self.use_data_parallel:
                    return run_dp_sharded_mrope_vision_model(
                        self.visual,
                        pixel_values_videos,
                        grid_thw_list,
                        rope_type="rope_3d",
                    )
                else:
                    video_embeds = self.visual(pixel_values_videos,
                                               grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)


# NOTE: This will be removed after MMEncoderAttention has been extracted as a
# CustomOp in vLLM.
Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
```
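
For readers without `torch_npu`, the rotary step in `AscendQwen2_5_VisionAttention.forward` follows the standard rotate-half formulation. The plain-PyTorch sketch below reflects the semantics I assume for `torch_npu.npu_rotary_mul(x, cos, sin)`; it is a reference reading, not code from this PR:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_mul_reference(x: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor) -> torch.Tensor:
    # Assumed equivalent of torch_npu.npu_rotary_mul(x, cos, sin):
    # rotary position embedding applied with broadcasted cos/sin tensors.
    return x * cos + rotate_half(x) * sin
```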