diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 5b0daa3..89e2bc7 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -64,6 +64,29 @@ def _rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
+        if self.rotary_dim < self.head_size:
+            num_tokens = query.shape[0]
+            query = query.view(num_tokens, -1, self.head_size)
+            key = key.view(num_tokens, -1, self.head_size)
+            q_rot = query[..., :self.rotary_dim]
+            q_pass = query[..., self.rotary_dim:]
+            k_rot = key[..., :self.rotary_dim]
+            k_pass = key[..., self.rotary_dim:]
+            q_rot = q_rot.contiguous().view(num_tokens, -1)
+            k_rot = k_rot.contiguous().view(num_tokens, -1)
+            torch_npu._npu_rotary_embedding(
+                positions,
+                q_rot,
+                k_rot,
+                self.head_size,
+                self.cos_sin_cache,
+                neox_style,
+            )
+            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+            return q, k
         # TODO: Remove the contiguous in the future.
         query = query.contiguous().view(query.shape[0], -1)
         key = key.contiguous().view(key.shape[0], -1)
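
For context, the new branch implements partial rotary embedding: only the first `rotary_dim` channels of each head are rotated, while the remaining `head_size - rotary_dim` channels pass through unchanged, and the two parts are concatenated back before restoring the original flat shapes. Below is a minimal, NPU-free sketch of that same split-rotate-concatenate flow in plain PyTorch. It is illustrative only: the function name `apply_partial_rotary` and the assumption that `cos_sin_cache` stores the cos half followed by the sin half along the last dimension are not part of this patch, and the rotation here is the generic NeoX/GPT-J formulation rather than the `torch_npu._npu_rotary_embedding` kernel.

```python
import torch


def apply_partial_rotary(
    positions: torch.Tensor,      # [num_tokens]
    query: torch.Tensor,          # [num_tokens, num_heads * head_size]
    key: torch.Tensor,            # [num_tokens, num_kv_heads * head_size]
    cos_sin_cache: torch.Tensor,  # [max_pos, rotary_dim], assumed cos||sin layout
    head_size: int,
    rotary_dim: int,
    neox_style: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate only the first `rotary_dim` channels of each head; the rest
    are passed through untouched, mirroring the split in the patch above."""
    query_shape, key_shape = query.shape, key.shape
    num_tokens = query.shape[0]
    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)

    # Split each head into the rotated part and the pass-through part.
    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]

    cos_sin = cos_sin_cache[positions]    # [num_tokens, rotary_dim]
    cos, sin = cos_sin.chunk(2, dim=-1)   # each [num_tokens, rotary_dim // 2]
    cos = cos.unsqueeze(1)                # broadcast over the head dimension
    sin = sin.unsqueeze(1)

    if neox_style:
        # GPT-NeoX layout: rotate pairs (x_i, x_{i + rotary_dim/2}).
        x1, x2 = q_rot.chunk(2, dim=-1)
        q_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        x1, x2 = k_rot.chunk(2, dim=-1)
        k_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    else:
        # GPT-J layout: rotate interleaved pairs (x_{2i}, x_{2i+1}).
        x1, x2 = q_rot[..., ::2], q_rot[..., 1::2]
        q_rot = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin),
                            dim=-1).flatten(-2)
        x1, x2 = k_rot[..., ::2], k_rot[..., 1::2]
        k_rot = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin),
                            dim=-1).flatten(-2)

    # Re-attach the untouched channels and restore the original shapes.
    q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
    k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
    return q, k
```

The sketch follows the same steps as the patched path: view to `(num_tokens, heads, head_size)`, slice off the rotary portion, rotate it, and concatenate it with the pass-through portion before reshaping back to the caller's layout.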