[Refact.]: Refactor some leftover implementations of 300I DUO in the main branch. (#6425)

### What this PR does / why we need it?
- Replace the RoPE operator implementation.
- Refactor some leftover implementations of 300I DUO in the main branch.

### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
Shaoxu Cheng
2026-02-02 16:12:04 +08:00
committed by GitHub
parent eeedf7c503
commit 460ea88276
7 changed files with 94 additions and 23 deletions

View File

@@ -18,13 +18,13 @@
import torch
import torch.nn.functional as F
from vllm_ascend.ops.activation import AscendSiluAndMul as _Base
from vllm_ascend.ops.activation import AscendSiluAndMul
class AscendSiluAndMul310(_Base):
class AscendSiluAndMul310(AscendSiluAndMul):
def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
h = x.shape[-1] // 2
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
out = F.silu(x[..., :h]) * x[..., h:]
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out

View File

@@ -11,17 +11,16 @@ class AscendRMSNorm310(AscendRMSNorm):
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is not None:
orig_dtype = residual.dtype
if x is None or x.numel() == 0 or x.shape[-1] == 0:
x = residual.to(dtype=residual.dtype)
x = residual
else:
x = x + residual.to(x.dtype)
x = x + residual
residual = x.to(orig_dtype)
residual = x
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x

View File

@@ -19,10 +19,10 @@ import einops
import torch
import torch_npu
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
class AscendMMEncoderAttention310(_Base):
class AscendMMEncoderAttention310(AscendMMEncoderAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -15,9 +15,85 @@
# This file is a part of the vllm-ascend project.
#
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
import torch
import torch_npu
from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding, get_cos_and_sin_slice
class AscendMRotaryEmbedding310(MRotaryEmbedding):
def forward_oot(self, positions, query, key):
return super().forward_oot(positions, query, key)
def _rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
is_neox_style: bool,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
if offsets is not None:
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
rotary_mode = "half" if is_neox_style else "interleave"
if self.head_size == 128 and self.cos_sin_cache.shape[-1] == 128:
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin, rotary_mode=rotary_mode)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., : self.rotary_dim]
q_pass = query[..., self.rotary_dim :]
k_rot = key[..., : self.rotary_dim]
k_pass = key[..., self.rotary_dim :]
if self.rotary_dim == 64:
q_rot = q_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
k_rot = k_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
q_rot, k_rot = torch_npu.npu_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, rotary_mode=rotary_mode)
else:
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.rotary_dim,
self.cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
else:
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding310(AscendRotaryEmbedding):
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: torch.Tensor | None = None,
is_neox_style_override: bool | None = None,
):
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)

View File

@@ -26,7 +26,8 @@ class NPUWorker310(NPUWorker):
def init_device(self):
self.device = self._init_device()
torch_npu.npu.set_compile_mode(jit_compile=False)
# TODO: There is accuracy issue when jit_compile is disabled currently.
torch_npu.npu.set_compile_mode(jit_compile=True)
init_workspace_manager(self.device, num_ubatches=1)

View File

@@ -628,13 +628,13 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
from vllm_ascend._310p.ops.rotary_embedding import AscendMRotaryEmbedding310
from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
REGISTERED_ASCEND_OPS.update(
{
"SiluAndMul": AscendSiluAndMul310,
"MMEncoderAttention": AscendMMEncoderAttention310,
"MRotaryEmbedding": AscendMRotaryEmbedding310,
"RotaryEmbedding": AscendRotaryEmbedding310,
"RMSNorm": AscendRMSNorm310,
"GemmaRMSNorm": AscendGemmaRMSNorm310,
}

View File

@@ -149,9 +149,6 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
if get_ascend_device_type() == AscendDeviceType._310P:
torch_npu.npu.set_compile_mode(jit_compile=False)
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
@@ -2527,9 +2524,7 @@ class NPUModelRunner(GPUModelRunner):
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
if get_ascend_device_type() == AscendDeviceType._310P:
k_cache = maybe_trans_nz(k_cache)
v_cache = maybe_trans_nz(v_cache)
if self.use_sparse and raw_dsa_k_tensor is not None:
dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize