This reverts commit 646c1db5d7.
The new op may cause accuracy problems.
### What this PR does / why we need it?
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
This commit is contained in:
@@ -22,7 +22,7 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
@@ -395,37 +395,3 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
|
||||
is_neox_style, offsets)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
class AscendMRotaryEmbedding(MRotaryEmbedding):
    """MRotaryEmbedding variant that routes one section layout to the NPU op.

    When ``mrope_section`` equals ``[16, 24, 24]`` the rotation is performed
    by ``torch_npu.npu_mrope``; every other layout is handled by the parent
    class implementation.
    """

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        """Apply the rotary embedding to ``query`` and ``key``.

        Args:
            positions: Position tensor; a 1-D tensor is passed to the op
                with an all-zero section list.
            query: Query tensor to rotate.
            key: Key tensor to rotate.

        Returns:
            The ``(query, key)`` pair produced by ``torch_npu.npu_mrope``,
            or whatever the parent ``forward_oot`` returns for layouts
            other than ``[16, 24, 24]``.
        """
        # Guard clause: only the [16, 24, 24] layout uses the NPU op.
        if self.mrope_section != [16, 24, 24]:
            return super().forward_oot(positions, query, key)

        import torch_npu

        # 1-D positions are passed to the op with zeroed sections.
        if positions.ndim == 1:
            section = [0, 0, 0]
        else:
            section = self.mrope_section

        # Bring the cached cos/sin table onto the query's device, then its
        # dtype, storing the converted tensor back on the instance.
        target_device = query.device
        if self.cos_sin_cache.device != target_device:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                target_device)  # type: ignore

        target_dtype = query.dtype
        if self.cos_sin_cache.dtype != target_dtype:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                target_dtype)  # type: ignore

        rotated_query, rotated_key = torch_npu.npu_mrope(
            positions,
            query.contiguous(),
            key.contiguous(),
            self.cos_sin_cache.contiguous(),
            self.head_size,
            mrope_section=section,
            rotary_mode='half')

        return rotated_query, rotated_key
|
||||
|
||||
Reference in New Issue
Block a user