[Misc] Clean up uesless code in rotary_embedding (#2663)
Clean up useless code which is only used for torchair in rotary_embedding
- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -19,7 +19,6 @@ import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
@@ -28,19 +27,18 @@ from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
def _rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
@@ -51,8 +49,8 @@ def rope_forward_oot(
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
if _custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
@@ -80,23 +78,6 @@ def rope_forward_oot(
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
||||
(1 / self.rotary_dim)))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
||||
self.embed = F.embedding
|
||||
|
||||
|
||||
class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
@@ -118,13 +99,15 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
is_prefill: Optional[bool] = True,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
):
|
||||
return rope_forward_oot(self, positions, query, key, offsets,
|
||||
is_neox_style_override,
|
||||
is_qwen_torchair) # type: ignore
|
||||
return _rope_forward_oot(
|
||||
self,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
is_neox_style_override,
|
||||
)
|
||||
|
||||
|
||||
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
@@ -328,6 +311,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
Reference in New Issue
Block a user