[6/N][refactor]delete torchair in rotary ops (#2581)

### What this PR does / why we need it?
Now that the torchair-related rope ops have been moved into torchair_ops, this PR removes the torchair-specific code paths from the original rope ops to keep the code clean.
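
For context, here is a minimal sketch of the shape this refactor series aims for; all names in it (`native_rope_forward`, `torchair_rope_forward`, `select_rope_forward`) are hypothetical stand-ins, not the actual vllm-ascend API:

```python
# Minimal sketch only: these names are hypothetical stand-ins for the real ops.
from typing import Callable, Tuple

import torch


def native_rope_forward(positions: torch.Tensor, query: torch.Tensor,
                        key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain rope path: after this PR it no longer consults the torchair config.
    return query, key  # placeholder for the NPU rope kernel call


def torchair_rope_forward(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Torchair-specific path, now kept alongside the other torchair ops.
    return query, key  # placeholder for the graph-mode rope implementation


def select_rope_forward(torchair_enabled: bool) -> Callable:
    # The choice between the two paths is made once, at wiring time, instead of
    # re-checking the torchair config inside every rope forward call.
    return torchair_rope_forward if torchair_enabled else native_rope_forward
```

With the torchair variant kept next to the other torchair ops, the diff below can drop the per-call `get_ascend_config().torchair_graph_config.enabled` checks from the plain rope op.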

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- vLLM version: main
- vLLM main: ab9f2cfd19
- vLLM version: v0.10.1.1
- vLLM main: 81eea3d348

Signed-off-by: hust17yixuan <303660421@qq.com>
Author: Wang Yixuan
Date: 2025-09-01 09:10:15 +08:00
Committed by: GitHub
Parent: c2c97f3079
Commit: ad13964c71
2 changed files with 7 additions and 83 deletions


```diff
@@ -24,7 +24,6 @@ import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import enable_custom_op, is_310p
```
```diff
@@ -43,15 +42,6 @@ def rope_forward_oot(
     is_neox_style_override: Optional[bool] = None,
     is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config(
-    ).torchair_graph_config.enabled and not is_qwen_torchair:
-        return self.forward_native(
-            positions,
-            query,
-            key,
-            offsets,
-        )
-
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
```
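
The context lines above show the cache handling that survives the cleanup: rather than eagerly materializing a cos/sin table on the NPU in `__init__` (the torchair-only behavior removed in the next hunk), the cached table is moved to the query's device lazily on first use. A standalone illustration of that pattern, with illustrative names and the standard base-10000 RoPE frequencies assumed:

```python
import torch


class LazyCosSinCache:
    """Illustrative only: keeps a precomputed table and moves it to the
    input's device on demand, mirroring the lazy .to(query.device) above."""

    def __init__(self, max_positions: int, rotary_dim: int):
        t = torch.arange(max_positions, dtype=torch.float32)
        inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        freqs = torch.outer(t, inv_freq)
        self.cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

    def get(self, query: torch.Tensor) -> torch.Tensor:
        # Lazy device move: only pay the transfer when the cache and the
        # activations actually live on different devices.
        if self.cos_sin_cache.device != query.device:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
        return self.cos_sin_cache
```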
```diff
@@ -120,11 +110,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
-        if get_ascend_config().torchair_graph_config.enabled:
-            set_cos_sin_cache(self,
-                              seq_len=max_position_embeddings,
-                              device="npu",
-                              dtype=dtype)
 
     def forward_oot(
         self,
```
```diff
@@ -137,42 +122,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_prefill: Optional[bool] = True,
         is_qwen_torchair: Optional[bool] = False,
     ):
-        if get_ascend_config().torchair_graph_config.enabled \
-                and is_qwen_torchair and not is_prefill:
-            if max_seq_len is not None and torch.gt(
-                    max_seq_len, self.max_position_embeddings):
-                set_cos_sin_cache(self,
-                                  seq_len=max_seq_len,
-                                  device=query.device,
-                                  dtype=torch.float32)
-
-            # bsnd/bnsd
-            if positions is not None:
-                cos = self.embed(positions, self.cos)
-                sin = self.embed(positions, self.sin)
-                self.cos_embed = cos
-                self.sin_embed = sin
-            else:
-                cos = self.cos_embed
-                sin = self.sin_embed
-
-            query = query.view(*query.shape[:-1], -1,
-                               self.head_size).contiguous()
-            key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
-
-            cos = cos.unsqueeze(-2).unsqueeze(-2)
-            sin = sin.unsqueeze(-2).unsqueeze(-2)
-
-            query = query.unsqueeze(1)
-            key = key.unsqueeze(1)
-
-            q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
-                query, key, cos, sin)
-            return q_embed.flatten(-2), k_embed.flatten(-2)
-        else:
-            return rope_forward_oot(self, positions, query, key, offsets,
-                                    is_neox_style_override,
-                                    is_qwen_torchair)  # type: ignore
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
 
 
 class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
```
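
After the last hunk, `AscendRotaryEmbedding.forward_oot` is a straight delegation to the module-level helper. Below is a self-contained sketch of the resulting control flow, with the helper stubbed out (the real helper applies rope through the NPU kernels; this is illustrative, not the verbatim module):

```python
# Sketch of the post-refactor control flow; illustrative, not the verbatim module.
from typing import Optional, Tuple

import torch


def rope_forward_oot(self, positions, query, key, offsets=None,
                     is_neox_style_override=None,
                     is_qwen_torchair=False) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for the real module-level helper, which applies rope on the NPU
    # and no longer contains any torchair checks.
    return query, key


class AscendRotaryEmbedding:  # the real class subclasses vllm's RotaryEmbedding
    def forward_oot(self,
                    positions: torch.Tensor,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    offsets: Optional[torch.Tensor] = None,
                    is_neox_style_override: Optional[bool] = None,
                    max_seq_len: Optional[int] = None,
                    is_prefill: Optional[bool] = True,
                    is_qwen_torchair: Optional[bool] = False):
        # Single code path: always delegate. The torchair branch that rebuilt
        # the cos/sin cache and called torch_npu.npu_apply_rotary_pos_emb is gone.
        return rope_forward_oot(self, positions, query, key, offsets,
                                is_neox_style_override, is_qwen_torchair)
```

The `is_qwen_torchair` flag stays in the signature so existing call sites keep working, but it is now simply passed through rather than selecting a torchair branch.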