diff --git a/vllm_ascend/_310p/ops/activation.py b/vllm_ascend/_310p/ops/activation.py
index 73d409cf..241a955f 100644
--- a/vllm_ascend/_310p/ops/activation.py
+++ b/vllm_ascend/_310p/ops/activation.py
@@ -18,13 +18,13 @@
 import torch
 import torch.nn.functional as F
 
-from vllm_ascend.ops.activation import AscendSiluAndMul as _Base
+from vllm_ascend.ops.activation import AscendSiluAndMul
 
 
-class AscendSiluAndMul310(_Base):
+class AscendSiluAndMul310(AscendSiluAndMul):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
         h = x.shape[-1] // 2
-        out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
+        out = F.silu(x[..., :h]) * x[..., h:]
         torch.ops.vllm.maybe_wait_prefetch_done(out)
         return out
diff --git a/vllm_ascend/_310p/ops/layernorm.py b/vllm_ascend/_310p/ops/layernorm.py
index d1b4978c..f8220d65 100644
--- a/vllm_ascend/_310p/ops/layernorm.py
+++ b/vllm_ascend/_310p/ops/layernorm.py
@@ -11,17 +11,16 @@ class AscendRMSNorm310(AscendRMSNorm):
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is not None:
-            orig_dtype = residual.dtype
             if x is None or x.numel() == 0 or x.shape[-1] == 0:
-                x = residual.to(dtype=residual.dtype)
+                x = residual
             else:
-                x = x + residual.to(x.dtype)
+                x = x + residual
-            residual = x.to(orig_dtype)
+            residual = x
             x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
             return x, residual
 
-        x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
         if self.bias is not None:
             x.add_(self.bias)
         return x
diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py
index 97481879..27a9cfcf 100644
--- a/vllm_ascend/_310p/ops/mm_encoder_attention.py
+++ b/vllm_ascend/_310p/ops/mm_encoder_attention.py
@@ -19,10 +19,10 @@
 import einops
 import torch
 import torch_npu
 
-from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
+from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
 
 
-class AscendMMEncoderAttention310(_Base):
+class AscendMMEncoderAttention310(AscendMMEncoderAttention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
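The activation change above drops the float32 round-trip and computes SiluAndMul in the input dtype. A quick CPU sketch comparing the two paths (tensor sizes and tolerance are illustrative, not from this PR):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 256, dtype=torch.float16)
    h = x.shape[-1] // 2
    # old path: upcast to float32, multiply, downcast to float16
    old = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
    # new path: stay in the input dtype throughout
    new = F.silu(x[..., :h]) * x[..., h:]
    print(torch.allclose(old, new, atol=1e-3))  # agree up to float16 rounding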
diff --git a/vllm_ascend/_310p/ops/rotary_embedding.py b/vllm_ascend/_310p/ops/rotary_embedding.py
index f51d27fd..9ea6f9c1 100644
--- a/vllm_ascend/_310p/ops/rotary_embedding.py
+++ b/vllm_ascend/_310p/ops/rotary_embedding.py
@@ -15,9 +15,85 @@
 # This file is a part of the vllm-ascend project.
 #
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+
+import torch
+import torch_npu
+
+from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding, get_cos_and_sin_slice
 
 
-class AscendMRotaryEmbedding310(MRotaryEmbedding):
-    def forward_oot(self, positions, query, key):
-        return super().forward_oot(positions, query, key)
+def _rope_forward_oot(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    is_neox_style: bool,
+    offsets: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    query_shape, key_shape = query.shape, key.shape
+    if self.cos_sin_cache.device != query.device:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+    if self.cos_sin_cache.dtype != query.dtype:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+    cos, sin = get_cos_and_sin_slice()
+    if offsets is not None:
+        raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
+    rotary_mode = "half" if is_neox_style else "interleave"
+    if self.head_size == 128 and self.cos_sin_cache.shape[-1] == 128:
+        query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
+        key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
+        query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin, rotary_mode=rotary_mode)
+    elif self.rotary_dim < self.head_size:
+        num_tokens = query.shape[0]
+        query = query.view(num_tokens, -1, self.head_size)
+        key = key.view(num_tokens, -1, self.head_size)
+        q_rot = query[..., : self.rotary_dim]
+        q_pass = query[..., self.rotary_dim :]
+        k_rot = key[..., : self.rotary_dim]
+        k_pass = key[..., self.rotary_dim :]
+        if self.rotary_dim == 64:
+            q_rot = q_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
+            k_rot = k_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
+            q_rot, k_rot = torch_npu.npu_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, rotary_mode=rotary_mode)
+        else:
+            q_rot = q_rot.contiguous().view(num_tokens, -1)
+            k_rot = k_rot.contiguous().view(num_tokens, -1)
+            torch_npu._npu_rotary_embedding(
+                positions,
+                q_rot,
+                k_rot,
+                self.rotary_dim,
+                self.cos_sin_cache,
+                is_neox_style,
+            )
+        q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+        k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+        query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+        key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+    else:
+        query = query.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(key.shape[0], -1)
+        torch_npu._npu_rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            is_neox_style,
+        )
+    return query.view(query_shape), key.view(key_shape)
+
+
+class AscendRotaryEmbedding310(AscendRotaryEmbedding):
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: torch.Tensor | None = None,
+        is_neox_style_override: bool | None = None,
+    ):
+        is_neox_style = self.is_neox_style
+        if is_neox_style_override is not None:
+            is_neox_style = is_neox_style_override
+        return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
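The rotary_mode flag above selects between the two standard RoPE layouts: "half" (neox style, where a pair is formed by the first and second halves of the rotary dims) and "interleave" (GPT-J style, where a pair is two adjacent even/odd elements). A pure-PyTorch sketch of the reference semantics, not the NPU kernel, assuming cos and sin carry rotary_dim // 2 channels:

    import torch

    def rotate_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # "half": pair element i with element i + rotary_dim // 2
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def rotate_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # "interleave": pair element 2i with element 2i + 1
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)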
diff --git a/vllm_ascend/_310p/worker_310p.py b/vllm_ascend/_310p/worker_310p.py
index e8615b34..8ced752b 100644
--- a/vllm_ascend/_310p/worker_310p.py
+++ b/vllm_ascend/_310p/worker_310p.py
@@ -26,7 +26,8 @@ class NPUWorker310(NPUWorker):
     def init_device(self):
         self.device = self._init_device()
-        torch_npu.npu.set_compile_mode(jit_compile=False)
+        # TODO: There is currently an accuracy issue when jit_compile is disabled.
+        torch_npu.npu.set_compile_mode(jit_compile=True)
         init_workspace_manager(self.device, num_ubatches=1)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f45d1440..9aadfb66 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -628,13 +628,13 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
         from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
         from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
         from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
-        from vllm_ascend._310p.ops.rotary_embedding import AscendMRotaryEmbedding310
+        from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
 
         REGISTERED_ASCEND_OPS.update(
             {
                 "SiluAndMul": AscendSiluAndMul310,
                 "MMEncoderAttention": AscendMMEncoderAttention310,
-                "MRotaryEmbedding": AscendMRotaryEmbedding310,
+                "RotaryEmbedding": AscendRotaryEmbedding310,
                 "RMSNorm": AscendRMSNorm310,
                 "GemmaRMSNorm": AscendGemmaRMSNorm310,
             }
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f7718ba8..b4aacea3 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -149,9 +149,6 @@
 AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
 # list when ubatching is enabled
 PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
 
-if get_ascend_device_type() == AscendDeviceType._310P:
-    torch_npu.npu.set_compile_mode(jit_compile=False)
-
 SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
 
@@ -2527,9 +2524,7 @@ class NPUModelRunner(GPUModelRunner):
         ]
         k_cache = raw_k_tensor.view(dtype).view(k_shape)
         v_cache = raw_v_tensor.view(dtype).view(v_shape)
-        if get_ascend_device_type() == AscendDeviceType._310P:
-            k_cache = maybe_trans_nz(k_cache)
-            v_cache = maybe_trans_nz(v_cache)
+
         if self.use_sparse and raw_dsa_k_tensor is not None:
             dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
             dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
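For the kv-cache sizing in the final hunk, the byte count is simply the product of the cache shape and the element size. A quick sanity-check sketch; num_blocks, block_size, and dtype here are illustrative assumptions, not values from this PR:

    import torch

    num_blocks, block_size = 4, 128
    dtype = torch.float16
    dsa_k_cache_shape = (num_blocks, block_size, 1, 128)
    # bytes: block count x tokens per block x 128 channels x bytes per element
    dsa_k_cache_size = num_blocks * block_size * 128 * dtype.itemsize
    assert dsa_k_cache_size == torch.empty(dsa_k_cache_shape, dtype=dtype).numel() * dtype.itemsize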