From ad13964c7121d7d80813c6f79a0b5fce9b6f66b0 Mon Sep 17 00:00:00 2001
From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com>
Date: Mon, 1 Sep 2025 09:10:15 +0800
Subject: [PATCH] [6/N][refactor]delete torchair in rotary ops (#2581)

### What this PR does / why we need it?
After moving the torchair-related rope ops into torchair_ops, split
torchair out of the original rope ops to keep the code clean. The
torchair branches in `rope_forward_oot` and `AscendRotaryEmbedding`
are removed, so `forward_oot` now always delegates to
`rope_forward_oot`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: main
vLLM main: https://github.com/vllm-project/vllm/commit/ab9f2cfd1942f7ddfee658ce86ea96b4789862af

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/81eea3d348c26fb1e6ff0185ad109aedd60a28a2

Signed-off-by: hust17yixuan <303660421@qq.com>
---
 tests/ut/ops/test_rotary_embedding.py | 36 ++----------------
 vllm_ascend/ops/rotary_embedding.py   | 54 ++-------------------------
 2 files changed, 7 insertions(+), 83 deletions(-)

diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index e8bb918..b129f02 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -88,36 +88,16 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.mock_self.cos_sin_cache = self.cos_sin_cache
         self.mock_self.is_neox_style = self.is_neox_style
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
-    def test_rope_forward_oot_torchair_enabled_base(self,
-                                                    mock_get_ascend_config):
-        # Setup mock for torchair enabled
-        mock_config = MagicMock()
-        mock_config.torchair_graph_config.enabled = True
-        mock_get_ascend_config.return_value = mock_config
-        with patch.object(self.layer,
-                          "forward_native",
-                          return_value=(self.query,
-                                        self.key)) as mock_forward_native:
-            result_q, result_k = self.layer.forward(self.positions, self.query,
-                                                    self.key)
-
-        mock_forward_native.assert_called_once()
-        self.assertTrue(torch.equal(result_q, self.query))
-        self.assertTrue(torch.equal(result_k, self.key))
-
     @patch('torch.ops._C')
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=True)
     @patch('torch.ops._npu_rotary_embedding')
     def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                             mock_custom_enabled, mock_is_310p,
-                                            mock_get_ascend_config, mock__c):
+                                            mock__c):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Setup mock for custom kernel path
@@ -130,16 +110,13 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.assertEqual(result_q.shape, self.query.shape)
         self.assertEqual(result_k.shape, self.key.shape)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=False)
     @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
-                                         mock_custom_enabled,
-                                         mock_get_ascend_config):
+                                         mock_custom_enabled):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test contiguous path when custom is disabled
         non_contig_query = self.query.transpose(0, 1)
@@ -153,27 +130,22 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.assertEqual(result_q.shape, non_contig_query.shape)
         self.assertEqual(result_k.shape, non_contig_key.shape)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
-    def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
+    def test_rope_forward_oot_with_offsets(self):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test that NotImplementedError is raised when offsets is provided
         offsets = torch.tensor([1, 2, 3])
         with self.assertRaises(NotImplementedError):
             self.layer.forward(self.positions, self.query, self.key, offsets)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=False)
     @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
-                                                  mock_custom_enabled,
-                                                  mock_get_ascend_config):
+                                                  mock_custom_enabled):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test neox_style override
         result_q, result_k = self.layer.forward(self.positions,
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 079c356..9dc472f 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -24,7 +24,6 @@ import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import enable_custom_op, is_310p
 
@@ -43,15 +42,6 @@ def rope_forward_oot(
     is_neox_style_override: Optional[bool] = None,
     is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config(
-    ).torchair_graph_config.enabled and not is_qwen_torchair:
-        return self.forward_native(
-            positions,
-            query,
-            key,
-            offsets,
-        )
-
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -120,11 +110,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
-        if get_ascend_config().torchair_graph_config.enabled:
-            set_cos_sin_cache(self,
-                              seq_len=max_position_embeddings,
-                              device="npu",
-                              dtype=dtype)
 
     def forward_oot(
         self,
@@ -137,42 +122,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_prefill: Optional[bool] = True,
         is_qwen_torchair: Optional[bool] = False,
     ):
-        if get_ascend_config().torchair_graph_config.enabled \
-                and is_qwen_torchair and not is_prefill:
-            if max_seq_len is not None and torch.gt(
-                    max_seq_len, self.max_position_embeddings):
-                set_cos_sin_cache(self,
-                                  seq_len=max_seq_len,
-                                  device=query.device,
-                                  dtype=torch.float32)
-
-            # bsnd/bnsd
-            if positions is not None:
-                cos = self.embed(positions, self.cos)
-                sin = self.embed(positions, self.sin)
-                self.cos_embed = cos
-                self.sin_embed = sin
-            else:
-                cos = self.cos_embed
-                sin = self.sin_embed
-
-            query = query.view(*query.shape[:-1], -1,
-                               self.head_size).contiguous()
-            key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
-
-            cos = cos.unsqueeze(-2).unsqueeze(-2)
-            sin = sin.unsqueeze(-2).unsqueeze(-2)
-
-            query = query.unsqueeze(1)
-            key = key.unsqueeze(1)
-
-            q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
-                query, key, cos, sin)
-            return q_embed.flatten(-2), k_embed.flatten(-2)
-        else:
-            return rope_forward_oot(self, positions, query, key, offsets,
-                                    is_neox_style_override,
-                                    is_qwen_torchair)  # type: ignore
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
 
 
 class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):