[6/N][refactor]delete torchair in rotary ops (#2581)
### What this PR does / why we need it? After moved torchair related rope ops into torchair_ops, split the torchair from the origin rope ops to make the code clean. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vLLM version: main vLLM main:ab9f2cfd19- vLLM version: v0.10.1.1 - vLLM main:81eea3d348Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
@@ -88,36 +88,16 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = self.is_neox_style
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_torchair_enabled_base(self,
|
||||
mock_get_ascend_config):
|
||||
# Setup mock for torchair enabled
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
with patch.object(self.layer,
|
||||
"forward_native",
|
||||
return_value=(self.query,
|
||||
self.key)) as mock_forward_native:
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
|
||||
mock_forward_native.assert_called_once()
|
||||
self.assertTrue(torch.equal(result_q, self.query))
|
||||
self.assertTrue(torch.equal(result_k, self.key))
|
||||
|
||||
@patch('torch.ops._C')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled, mock_is_310p,
|
||||
mock_get_ascend_config, mock__c):
|
||||
mock__c):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
@@ -130,16 +110,13 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_custom_enabled):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
@@ -153,27 +130,22 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
|
||||
def test_rope_forward_oot_with_offsets(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.layer.forward(self.positions, self.query, self.key, offsets)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_custom_enabled):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = self.layer.forward(self.positions,
|
||||
|
||||
Reference in New Issue
Block a user