[6/N][refactor]delete torchair in rotary ops (#2581)

### What this PR does / why we need it?
After moved torchair related rope ops into torchair_ops, split the
torchair from the origin rope ops to make the code clean.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: main
vLLM main:
ab9f2cfd19


- vLLM version: v0.10.1.1
- vLLM main:
81eea3d348

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
Wang Yixuan
2025-09-01 09:10:15 +08:00
committed by GitHub
parent c2c97f3079
commit ad13964c71
2 changed files with 7 additions and 83 deletions

View File

@@ -88,36 +88,16 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = self.is_neox_style
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
mock_get_ascend_config):
# Setup mock for torchair enabled
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
with patch.object(self.layer,
"forward_native",
return_value=(self.query,
self.key)) as mock_forward_native:
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock_forward_native.assert_called_once()
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@patch('torch.ops._C')
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock_get_ascend_config, mock__c):
mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Setup mock for custom kernel path
@@ -130,16 +110,13 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
@@ -153,27 +130,22 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = self.layer.forward(self.positions,

View File

@@ -24,7 +24,6 @@ import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p
@@ -43,15 +42,6 @@ def rope_forward_oot(
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config(
).torchair_graph_config.enabled and not is_qwen_torchair:
return self.forward_native(
positions,
query,
key,
offsets,
)
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -120,11 +110,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
if get_ascend_config().torchair_graph_config.enabled:
set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)
def forward_oot(
self,
@@ -137,42 +122,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(
max_seq_len, self.max_position_embeddings):
set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)
# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed
query = query.view(*query.shape[:-1], -1,
self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
query = query.unsqueeze(1)
key = key.unsqueeze(1)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):