[CustomOp] Register AscendApplyRotaryEmb CustomOp and remove related patch (#4667)
### What this PR does / why we need it?
Following https://github.com/vllm-project/vllm/pull/29873, register
`AscendApplyRotaryEmb` CustomOp and remove related patch.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
#### ✅ Test Qwen2.5-VL
Run:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-b02c1ff3415d2462","object":"chat.completion","created":1766129265,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-In struct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is writ ten in blue, and \"Qwen\" is written in gray. The text appears to be part of a logo or branding design.","refusal":null,"annotations":null,"audio": null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"tok en_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":129,"completion_tokens":51,"prompt_tokens_d
```
#### ✅ Test Qwen3-VL
Run:
```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```
Output:
```
{"id":"chatcmpl-a3a7de5a900a9321","object":"chat.completion","created":1766129586,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -104,6 +104,6 @@ class AscendMMEncoderAttention(MMEncoderAttention):
|
||||
context_layer = context_layer[..., :origin_shape]
|
||||
|
||||
context_layer = einops.rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
"(b s) h d -> b s h d",
|
||||
b=bsz).contiguous()
|
||||
return context_layer
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding)
|
||||
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
@@ -524,3 +526,59 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
||||
rotary_mode='half')
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
class AscendApplyRotaryEmb(ApplyRotaryEmb):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enforce_enable: bool = False,
|
||||
is_neox_style: bool = True,
|
||||
enable_fp32_compute: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
enforce_enable=enforce_enable,
|
||||
is_neox_style=is_neox_style,
|
||||
enable_fp32_compute=enable_fp32_compute,
|
||||
)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
head_dim = x.shape[-1]
|
||||
|
||||
origin_dtype = x.dtype
|
||||
if self.enable_fp32_compute:
|
||||
x = x.float()
|
||||
cos = cos.float()
|
||||
sin = sin.float()
|
||||
|
||||
# cos, sin: [seq_len, head_dim // 2]
|
||||
cos = torch.cat((cos, cos), dim=-1)
|
||||
sin = torch.cat((sin, sin), dim=-1)
|
||||
# cos, sin: [1, seq_len, 1, head_dim]
|
||||
cos = cos.reshape(1, -1, 1, head_dim)
|
||||
sin = sin.reshape(1, -1, 1, head_dim)
|
||||
|
||||
if len(x.shape) == 3:
|
||||
# x: [seq_len, num_heads, head_size]
|
||||
x = x.unsqueeze(0)
|
||||
# x: [1, seq_len, num_heads, head_size]
|
||||
output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
|
||||
else:
|
||||
assert len(x.shape) == 4
|
||||
# x: [2 * b, s, head, head_dim]
|
||||
qk = einops.rearrange(
|
||||
x, "(two b) s head head_dim -> b s two head head_dim", two=2)
|
||||
# q, k: [b, s, head, head_dim]
|
||||
q, k = qk[:, :, 0], qk[:, :, 1]
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
output = torch.cat([q, k], dim=0)
|
||||
|
||||
if self.enable_fp32_compute:
|
||||
output = output.to(origin_dtype)
|
||||
return output
|
||||
|
||||
@@ -160,53 +160,7 @@
|
||||
# Future Plan:
|
||||
# Identify this pattern in torch-npu and remove this patch.
|
||||
#
|
||||
# ** 5. File: worker/patch_qwen2_5_omni.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen2_5_omni_thinker.Qwen2_5OmniThinkerForConditionalGeneration`
|
||||
# Why:
|
||||
# we have ascend forward context which doesn't work with upstream.
|
||||
# How:
|
||||
# override forward_context in the model file
|
||||
# Related PR (if no, explain why):
|
||||
# This is a bug by Ascend only. we should drop set_ascend_forward_context
|
||||
# Future Plan:
|
||||
# Remove this patch once forward_context is refactor.
|
||||
#
|
||||
# ** 6. File: worker/patch_qwen2_5_vl.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration`
|
||||
# Why:
|
||||
# we have ascend forward context which doesn't work with upstream.
|
||||
# How:
|
||||
# override forward_context in the model file
|
||||
# Related PR (if no, explain why):
|
||||
# This is a bug by Ascend only. we should drop set_ascend_forward_context
|
||||
# Future Plan:
|
||||
# Remove this patch once forward_context is refactor.
|
||||
#
|
||||
# 2. `vllm.model_executor.models.qwen2_vl.Qwen2VisionAttention.forward`
|
||||
# Why:
|
||||
# the attention is not custom ops
|
||||
# How:
|
||||
# make it to custom ops and pluggable
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/30125
|
||||
# Future Plan:
|
||||
# Remove this patch one the PR is merged into vLLM.
|
||||
#
|
||||
# ** 7. File: worker/patch_qwen3_vl.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen3_vl.Qwen3_VisionTransformer.forward`
|
||||
# Why:
|
||||
# the attention is not custom ops
|
||||
# How:
|
||||
# make it to custom ops and pluggable
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/30125
|
||||
# Future Plan:
|
||||
# Remove this patch one the PR is merged into vLLM.
|
||||
#
|
||||
# ** 8. File: worker/patch_roberta.py **
|
||||
# ** 5. File: worker/patch_roberta.py **
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.bert `
|
||||
# Why:
|
||||
@@ -218,7 +172,7 @@
|
||||
# Future Plan:
|
||||
# Revert this when CANN support shift aclnn operation
|
||||
#
|
||||
# ** 9. File: worker/patch_triton.py**
|
||||
# ** 6. File: worker/patch_triton.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.mamba.ops`, `vllm.model_executor.layers.fla.ops`
|
||||
# Why:
|
||||
@@ -230,7 +184,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM support the dispatch function.
|
||||
#
|
||||
# ** 10. File: worker/patch_weight_loader.py**
|
||||
# ** 7. File: worker/patch_weight_loader.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.linear.UnquantizedLinearMethod`
|
||||
# Why:
|
||||
@@ -242,7 +196,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when the bug is fixed.
|
||||
#
|
||||
# ** 11. File: worker/patch_qwen3_next_mtp.py**
|
||||
# ** 8. File: worker/patch_qwen3_next_mtp.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.worker.utils.bind_kv_cache`
|
||||
# Why:
|
||||
@@ -255,7 +209,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
||||
#
|
||||
# ** 12. File: worker/patch_module.py**
|
||||
# ** 9. File: worker/patch_module.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
||||
# Why:
|
||||
@@ -271,7 +225,7 @@
|
||||
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
||||
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
|
||||
#
|
||||
# ** 13. File: worker/patch_rejection_sampler.py**
|
||||
# ** 10. File: worker/patch_rejection_sampler.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.sample.rejection_sampler`
|
||||
# Why:
|
||||
@@ -287,7 +241,7 @@
|
||||
# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
|
||||
# 2. make these functions as costom op, then remove AscendRejectionSampler
|
||||
#
|
||||
# ** 14.File: worker/patch_qwen3_next.py**
|
||||
# ** 11.File: worker/patch_qwen3_next.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet.forward`
|
||||
# Why:
|
||||
@@ -299,7 +253,7 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM support these operators.
|
||||
#
|
||||
# ** 15. File: worker/patch_qwen3_next.py**
|
||||
# ** 12. File: worker/patch_qwen3_next.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
|
||||
# Why:
|
||||
|
||||
@@ -28,8 +28,6 @@ import vllm_ascend.patch.worker.patch_deepseek # noqa
|
||||
import vllm_ascend.patch.worker.patch_weight_loader # noqa
|
||||
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
|
||||
import vllm_ascend.patch.worker.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
|
||||
import vllm_ascend.patch.worker.patch_rope # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.model_executor.models.qwen2_5_omni_thinker import (
|
||||
Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs)
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
|
||||
|
||||
class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module):
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["image_embeds"].type(self.visual.dtype)
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(
|
||||
self,
|
||||
video_input: Qwen2_5_VLVideoInputs,
|
||||
video_hashes: list[str] | None = None,
|
||||
cached_video_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if video_input["type"] == "video_embeds":
|
||||
return video_input["video_embeds"].type(self.visual.dtype)
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
|
||||
return video_embeds.split(sizes.tolist())
|
||||
@@ -1,135 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionAttention,
|
||||
Qwen2_5_VLImageInputs,
|
||||
Qwen2_5_VLVideoInputs)
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
|
||||
from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
seq_len, batch_size, _ = x.shape
|
||||
|
||||
qkv = einops.rearrange(
|
||||
x,
|
||||
"s b (three head head_dim) -> b s three head head_dim",
|
||||
three=3,
|
||||
head=self.num_attention_heads_per_partition,
|
||||
)
|
||||
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
|
||||
cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
|
||||
sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
|
||||
cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head)
|
||||
sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head)
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
|
||||
context_layer = self.attn(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"]
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values,
|
||||
grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||
return image_embeds.split(sizes)
|
||||
|
||||
def _process_video_input(
|
||||
self,
|
||||
video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"]
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(pixel_values_videos,
|
||||
grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||
return video_embeds.split(sizes)
|
||||
|
||||
|
||||
# NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm.
|
||||
Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
|
||||
Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
|
||||
@@ -666,8 +666,9 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
|
||||
AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
|
||||
AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding,
|
||||
AscendMRotaryEmbedding, AscendRotaryEmbedding,
|
||||
AscendYaRNRotaryEmbedding)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
AscendLogitsProcessor, AscendParallelLMHead,
|
||||
AscendVocabParallelEmbedding)
|
||||
@@ -694,6 +695,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
"SharedFusedMoE": AscendSharedFusedMoE,
|
||||
"MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
|
||||
"MMEncoderAttention": AscendMMEncoderAttention,
|
||||
"ApplyRotaryEmb": AscendApplyRotaryEmb,
|
||||
}
|
||||
|
||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||
|
||||
Reference in New Issue
Block a user