diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 7fba3b06..d699ec7d 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -18,7 +18,6 @@
 import math
 from typing import Optional, Tuple
 
-import einops
 import torch
 import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
@@ -32,7 +31,8 @@ if HAS_TRITON:
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
-                               get_ascend_device_type, has_rope, is_vl_model)
+                               get_ascend_device_type, has_rope, is_vl_model,
+                               vllm_version_is)
 
 # Currently, rope ops used on npu requires detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -622,14 +622,20 @@ class AscendApplyRotaryEmb(ApplyRotaryEmb):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ) -> torch.Tensor:
+        if vllm_version_is('0.13.0'):
+            origin_shape = x.shape
+            origin_dtype = x.dtype
+            if len(origin_shape) == 3:
+                x = x.unsqueeze(0)
+            if self.enable_fp32_compute:
+                x = x.float()
+                cos = cos.float()
+                sin = sin.float()
+        else:
+            x, cos, sin, origin_shape, origin_dtype = self._pre_process(
+                x, cos, sin)
+
         head_dim = x.shape[-1]
-
-        origin_dtype = x.dtype
-        if self.enable_fp32_compute:
-            x = x.float()
-            cos = cos.float()
-            sin = sin.float()
-
         # cos, sin: [seq_len, head_dim // 2]
         cos = torch.cat((cos, cos), dim=-1)
         sin = torch.cat((sin, sin), dim=-1)
@@ -637,22 +643,14 @@ class AscendApplyRotaryEmb(ApplyRotaryEmb):
         cos = cos.reshape(1, -1, 1, head_dim)
         sin = sin.reshape(1, -1, 1, head_dim)
 
-        if len(x.shape) == 3:
-            # x: [seq_len, num_heads, head_size]
-            x = x.unsqueeze(0)
-            # x: [1, seq_len, num_heads, head_size]
-            output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
-        else:
-            assert len(x.shape) == 4
-            # x: [2 * b, s, head, head_dim]
-            qk = einops.rearrange(
-                x, "(two b) s head head_dim -> b s two head head_dim", two=2)
-            # q, k: [b, s, head, head_dim]
-            q, k = qk[:, :, 0], qk[:, :, 1]
-            q = torch_npu.npu_rotary_mul(q, cos, sin)
-            k = torch_npu.npu_rotary_mul(k, cos, sin)
-            output = torch.cat([q, k], dim=0)
+        output = torch_npu.npu_rotary_mul(x, cos, sin)
+
+        if vllm_version_is('0.13.0'):
+            if len(origin_shape) == 3:
+                output = output.squeeze(0)
+            if self.enable_fp32_compute:
+                output = output.to(origin_dtype)
+        else:
+            output = self._post_process(output, origin_shape, origin_dtype)
 
-        if self.enable_fp32_compute:
-            output = output.to(origin_dtype)
         return output