[Misc] Clean up useless code in rotary_embedding (#2663)

Clean up useless code which is only used for torchair in rotary_embedding

- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-09-02 17:25:33 +08:00
committed by GitHub
parent 253b01b9a5
commit c1e607b7b7
2 changed files with 32 additions and 49 deletions

View File

@@ -19,7 +19,6 @@ import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -28,19 +27,18 @@ from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p
def custom_rotary_embedding_enabled(query, neox_style, head_size):
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
)
def rope_forward_oot(
def _rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
@@ -51,8 +49,8 @@ def rope_forward_oot(
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
if _custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
query,
@@ -80,23 +78,6 @@ def rope_forward_oot(
return query.view(query_shape), key.view(key_shape)
def set_cos_sin_cache(self, seq_len, device, dtype):
    """Precompute the rotary cos/sin tables and register them on *self*.

    Registers ``inv_freq`` as a buffer, builds cos/sin tables of shape
    ``(max_position_embeddings, rotary_dim)`` (frequencies duplicated along
    the last dim), registers them as non-persistent buffers, and stores
    ``F.embedding`` on ``self.embed`` for later table lookups.

    NOTE(review): ``seq_len`` is unused — the table length comes from
    ``self.max_position_embeddings``; kept for signature compatibility.
    """
    # Inverse frequencies: base^(-2i/rotary_dim) for even indices i.
    # Keep the `* (1 / rotary_dim)` form so float rounding matches exactly.
    exponents = torch.arange(
        0, self.rotary_dim, 2, device=device,
        dtype=torch.float32) * (1 / self.rotary_dim)
    inv_freq = 1.0 / (self.base**exponents)
    self.register_buffer("inv_freq", inv_freq)

    positions = torch.arange(self.max_position_embeddings,
                             device=self.inv_freq.device,
                             dtype=torch.float32)
    # Outer product positions x inv_freq (same as einsum "i,j->ij").
    freqs = torch.outer(positions, self.inv_freq)
    table = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", table.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", table.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding
class AscendRotaryEmbedding(RotaryEmbedding):
def __init__(
@@ -118,13 +99,15 @@ class AscendRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -328,6 +311,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe