[CustomOp] Register AscendMMEncoderAttention CustomOp and remove related patch (#4750)
### What this PR does / why we need it?
Following https://github.com/vllm-project/vllm/pull/30125, register
`AscendMMEncoderAttention` CustomOp and remove related patch.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
✅ Run Qwen2.5-VL:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-b4e3053f30ab2442","object":"chat.completion","created":1764922950,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the image is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" being slightly larger than \"Qwen.\" The design includes a geometric, abstract shape on the left side of the logo, which complements the text.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":162,"completion_tokens":84,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
✅ Run Qwen3-VL:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-97571fbda8267bd1","object":"chat.completion","created":1764923306,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
109
vllm_ascend/ops/mm_encoder_attention.py
Normal file
109
vllm_ascend/ops/mm_encoder_attention.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import MultiModalConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendMMEncoderAttention(MMEncoderAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float | None = None,
|
||||
num_kv_heads: int | None = None,
|
||||
prefix: str = "",
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
num_heads: number of attention heads per partition.
|
||||
head_size: hidden_size per attention head.
|
||||
scale: scale factor.
|
||||
num_kv_heads: number of kv heads.
|
||||
prefix: This has no effect, it is only here to make it easier to
|
||||
swap between Attention and MMEncoderAttention.
|
||||
multimodal_config: configs for multi-modal.
|
||||
"""
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
prefix=prefix,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor
|
||||
| None = None, # Only used for Flash Attention
|
||||
):
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
# q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
|
||||
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
|
||||
and self.head_size > MIN_PAD_SIZE
|
||||
and self.head_size < MAX_PAD_SIZE)
|
||||
|
||||
if enable_pad:
|
||||
origin_shape = q.shape[-1]
|
||||
pad_len = MAX_PAD_SIZE - origin_shape
|
||||
# q, k, v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
|
||||
q = F.pad(q, (0, pad_len), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad_len), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad_len), mode="constant", value=0)
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
cu_seqlens = torch.diff(cu_seqlens).to("cpu")
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
scale_value=self.head_size**-0.5,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=context_layer,
|
||||
)
|
||||
|
||||
if enable_pad:
|
||||
context_layer = context_layer[..., :origin_shape]
|
||||
|
||||
context_layer = einops.rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=bsz).contiguous()
|
||||
return context_layer
|
||||
@@ -18,7 +18,6 @@
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionAttention,
|
||||
Qwen2_5_VLImageInputs,
|
||||
@@ -26,7 +25,6 @@ from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionAttention,
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
|
||||
from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
@@ -47,7 +45,6 @@ class AscendQwen2_5_VisionAttention(nn.Module):
|
||||
x, _ = self.qkv(x)
|
||||
seq_len, batch_size, _ = x.shape
|
||||
|
||||
# Split q k v.
|
||||
qkv = einops.rearrange(
|
||||
x,
|
||||
"s b (three head head_dim) -> b s three head head_dim",
|
||||
@@ -55,10 +52,6 @@ class AscendQwen2_5_VisionAttention(nn.Module):
|
||||
head=self.num_attention_heads_per_partition,
|
||||
)
|
||||
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
origin_shape = q.shape[-1]
|
||||
|
||||
# Convert cumulative tensor to intervals and move it to cpu.
|
||||
cu_seqlens = torch.diff(cu_seqlens).to("cpu")
|
||||
|
||||
cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
|
||||
sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
|
||||
@@ -67,43 +60,14 @@ class AscendQwen2_5_VisionAttention(nn.Module):
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
|
||||
q, k, v = [
|
||||
einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
|
||||
and self.hidden_size_per_attention_head > MIN_PAD_SIZE
|
||||
and self.hidden_size_per_attention_head < MAX_PAD_SIZE)
|
||||
|
||||
if enable_pad:
|
||||
pad_len = MAX_PAD_SIZE - origin_shape
|
||||
# q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
|
||||
q = F.pad(q, (0, pad_len), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad_len), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad_len), mode="constant", value=0)
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
context_layer = self.attn(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
scale_value=self.hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
if enable_pad:
|
||||
context_layer = context_layer[..., :origin_shape]
|
||||
|
||||
context_layer = einops.rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
@@ -651,6 +651,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
AscendReplicatedLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
|
||||
AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
|
||||
@@ -679,6 +680,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
"FusedMoE": AscendFusedMoE,
|
||||
"SharedFusedMoE": AscendSharedFusedMoE,
|
||||
"MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
|
||||
"MMEncoderAttention": AscendMMEncoderAttention,
|
||||
}
|
||||
|
||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||
|
||||
Reference in New Issue
Block a user