diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index b129f02..eb48c81 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -7,7 +7,7 @@ from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled
+from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 
 class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -31,37 +31,37 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
         # Test when all conditions are True
         with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                    return_value=True):
-            result = custom_rotary_embedding_enabled(self.query, True,
-                                                     self.head_size)
+            result = _custom_rotary_embedding_enabled(self.query, True,
+                                                      self.head_size)
             self.assertTrue(result)
 
         # Test when dtype is not float16
         with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                    return_value=True):
             query = self.query.to(torch.float32)
-            result = custom_rotary_embedding_enabled(query, True,
-                                                     self.head_size)
+            result = _custom_rotary_embedding_enabled(query, True,
+                                                      self.head_size)
             self.assertFalse(result)
 
         # Test when neox_style is False
         with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                    return_value=True):
-            result = custom_rotary_embedding_enabled(self.query, False,
-                                                     self.head_size)
+            result = _custom_rotary_embedding_enabled(self.query, False,
+                                                      self.head_size)
             self.assertFalse(result)
 
         # Test when head_size is not divisible by 32
         with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                    return_value=True):
-            result = custom_rotary_embedding_enabled(self.query, True,
-                                                     self.head_size + 1)
+            result = _custom_rotary_embedding_enabled(self.query, True,
+                                                      self.head_size + 1)
             self.assertFalse(result)
 
         # Test when custom op is disabled
         with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                    return_value=False):
-            result = custom_rotary_embedding_enabled(self.query, True,
-                                                     self.head_size)
+            result = _custom_rotary_embedding_enabled(self.query, True,
+                                                      self.head_size)
             self.assertFalse(result)
@@ -90,7 +90,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
 
     @patch('torch.ops._C')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
-    @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
+    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
            return_value=True)
     @patch('torch.ops._npu_rotary_embedding')
     def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
@@ -110,7 +110,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.assertEqual(result_q.shape, self.query.shape)
         self.assertEqual(result_k.shape, self.key.shape)
 
-    @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
+    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
     @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
@@ -139,7 +139,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         with self.assertRaises(NotImplementedError):
             self.layer.forward(self.positions, self.query, self.key, offsets)
 
-    @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
+    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
@@ -198,7 +198,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
     def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
         mock_npuplatform.device_type = torch.device("cpu")
         self.layer = self._create_layer()
-        with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
+        with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
                    return_value=(self.query,
                                  self.key)) as mock_rope_forward_oot:
             q_pe, k_pe = self.layer.forward(self.positions, self.query,
@@ -207,7 +207,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
         assert q_pe.shape == self.query.shape
         assert k_pe.shape == self.key.shape
 
-    @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
     @patch("vllm.platforms.current_platform.device_type",
            new=torch.device("cpu"))
     @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
@@ -229,7 +229,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
         assert q_pe.shape == self.query.shape
         assert k_pe.shape == self.key.shape
 
-    @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
     @patch("vllm.platforms.current_platform.device_type",
            new=torch.device("cpu"))
     @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
@@ -248,7 +248,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
         assert q_pe.shape == self.query.shape
         assert k_pe.shape == key.shape
 
-    @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
+    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
     @patch("vllm.platforms.current_platform.device_type",
            new=torch.device("cpu"))
     @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 9dc472f..5b0daa3 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -19,7 +19,6 @@ import math
 from typing import Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -28,19 +27,18 @@ from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import enable_custom_op, is_310p
 
 
-def custom_rotary_embedding_enabled(query, neox_style, head_size):
+def _custom_rotary_embedding_enabled(query, neox_style, head_size):
     return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
     )
 
 
-def rope_forward_oot(
+def _rope_forward_oot(
     self,
     positions: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     offsets: Optional[torch.Tensor] = None,
     is_neox_style_override: Optional[bool] = None,
-    is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
@@ -51,8 +49,8 @@ def rope_forward_oot(
     if is_neox_style_override is not None:
         neox_style = is_neox_style_override
     # adopt custom kernel path for rotary_embedding
-    if custom_rotary_embedding_enabled(query, neox_style,
-                                       self.head_size) and not is_310p():
+    if _custom_rotary_embedding_enabled(query, neox_style,
+                                        self.head_size) and not is_310p():
         query, key = torch.ops._C.rotary_embedding(
             positions,
             query,
@@ -80,23 +78,6 @@ def rope_forward_oot(
     return query.view(query_shape), key.view(key_shape)
 
 
-def set_cos_sin_cache(self, seq_len, device, dtype):
-    inv_freq = 1.0 / (self.base**(torch.arange(
-        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
-        (1 / self.rotary_dim)))
-    self.register_buffer("inv_freq", inv_freq)
-
-    t = torch.arange(self.max_position_embeddings,
-                     device=self.inv_freq.device,
-                     dtype=torch.float32)
-    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-
-    emb = torch.cat((freqs, freqs), dim=-1)
-    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
-    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
-    self.embed = F.embedding
-
-
 class AscendRotaryEmbedding(RotaryEmbedding):
 
     def __init__(
@@ -118,13 +99,15 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
         is_neox_style_override: Optional[bool] = None,
-        max_seq_len: Optional[int] = None,
-        is_prefill: Optional[bool] = True,
-        is_qwen_torchair: Optional[bool] = False,
     ):
-        return rope_forward_oot(self, positions, query, key, offsets,
-                                is_neox_style_override,
-                                is_qwen_torchair)  # type: ignore
+        return _rope_forward_oot(
+            self,
+            positions,
+            query,
+            key,
+            offsets,
+            is_neox_style_override,
+        )
 
 
 class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -328,6 +311,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
             b, h_k, d = key.shape
             key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
 
-        q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
-                                      neox_style)
+        q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
+                                       neox_style)
         return q_pe, k_pe
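
Note (not part of the patch): a minimal standalone sketch of the predicate that the renamed
_custom_rotary_embedding_enabled helper implements, as shown in the hunk above. The
enable_custom_op stub below is a placeholder standing in for vllm_ascend.utils.enable_custom_op
(assumed here to report whether the custom-op extension is usable); everything else mirrors the
diff: the torch.ops._C.rotary_embedding kernel path is taken only for float16 queries, neox-style
rotation, a head size divisible by 32, and custom ops enabled, otherwise _rope_forward_oot falls
back to torch_npu._npu_rotary_embedding.

    import torch

    def enable_custom_op() -> bool:
        # Placeholder for vllm_ascend.utils.enable_custom_op (assumption:
        # True when the vllm-ascend custom-op extension is available).
        return True

    def _custom_rotary_embedding_enabled(query: torch.Tensor, neox_style: bool,
                                         head_size: int) -> bool:
        # All four conditions must hold before the custom kernel is used.
        return (query.dtype == torch.float16 and neox_style
                and head_size % 32 == 0 and enable_custom_op())

    q = torch.zeros(8, 128, dtype=torch.float16)
    assert _custom_rotary_embedding_enabled(q, True, 128)               # custom path
    assert not _custom_rotary_embedding_enabled(q.float(), True, 128)   # wrong dtype
    assert not _custom_rotary_embedding_enabled(q, False, 128)          # not neox style
    assert not _custom_rotary_embedding_enabled(q, True, 127)           # head_size % 32 != 0

This is also the behavior the TestCustomRotaryEmbeddingEnabled cases above exercise, one failed
condition per test.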