From f796e6280b79cc87451ceac2daa086ee80b1d572 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Mon, 25 Aug 2025 09:32:35 +0800 Subject: [PATCH] [CustomOp] Register RotaryEmbedding instead of overwrite forward (#2385) ### What this PR does / why we need it? Register RotaryEmbedding instead of overwrite forward ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with new added/existing test. - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/808d2e9aa0f302bf9667b09b9dcf297f86927dac --------- Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> --- .github/workflows/vllm_ascend_test.yaml | 2 +- tests/ut/ops/test_rotary_embedding.py | 244 ++++++----- tests/ut/test_utils.py | 4 +- vllm_ascend/ops/__init__.py | 11 +- vllm_ascend/ops/rotary_embedding.py | 539 ++++++++++++------------ vllm_ascend/utils.py | 7 + 6 files changed, 426 insertions(+), 381 deletions(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 78cfefa..f1be625 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -287,4 +287,4 @@ jobs: pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \ --ignore=tests/e2e/multicard/test_data_parallel.py \ - --ignore=tests/e2e/multicard/test_offline_inference_310p.py + --ignore=tests/e2e/multicard/test_offline_inference_310p.py \ No newline at end of file diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index 3b388e0..e8bb918 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -1,17 +1,16 @@ import math -from unittest.mock import MagicMock, patch +import unittest +from unittest.mock import MagicMock, PropertyMock, patch import torch +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) from tests.ut.base import TestBase -from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, - native_rope_deepseek_forward, - rope_forward_oot, rotate_half, - yarn_find_correction_dim, - yarn_get_mscale) +from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled -class TestCustomRotaryEmbeddingEnabled(TestBase): +class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): def setUp(self): # Common setup for tests @@ -66,22 +65,28 @@ class TestCustomRotaryEmbeddingEnabled(TestBase): self.assertFalse(result) -class TestRopeForwardOot(TestBase): +class TestAscendRotaryEmbedding(unittest.TestCase): def setUp(self): # Common setup for tests self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 4, dtype=torch.float16) - self.key = torch.randn(3, 4, dtype=torch.float16) + self.query = torch.randn(3, 1, 32, dtype=torch.float16) + self.key = torch.randn(3, 1, 32, dtype=torch.float16) self.head_size = 32 - self.cos_sin_cache = torch.randn(3, 4) + self.rotary_dim = self.head_size + self.max_position = 16 + self.rope_theta = 10000 + self.is_neox_style = True + self.cos_sin_cache = torch.randn(3, 1, 32) + self.layer = RotaryEmbedding(self.head_size, self.rotary_dim, + self.max_position, self.rope_theta, + self.is_neox_style, torch.float16) # Mock self object for rope_forward_oot self.mock_self = MagicMock() self.mock_self.head_size = self.head_size self.mock_self.cos_sin_cache = self.cos_sin_cache - self.mock_self.is_neox_style = True - self.mock_self.forward_native.return_value = (self.query, self.key) + self.mock_self.is_neox_style = self.is_neox_style @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') def test_rope_forward_oot_torchair_enabled_base(self, @@ -90,12 +95,14 @@ class TestRopeForwardOot(TestBase): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = True mock_get_ascend_config.return_value = mock_config + with patch.object(self.layer, + "forward_native", + return_value=(self.query, + self.key)) as mock_forward_native: + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - self.query, self.key) - - self.mock_self.forward_native.assert_called_once_with( - self.positions, self.query, self.key, None) + mock_forward_native.assert_called_once() self.assertTrue(torch.equal(result_q, self.query)) self.assertTrue(torch.equal(result_k, self.key)) @@ -116,9 +123,10 @@ class TestRopeForwardOot(TestBase): mock__c.rotary_embedding.return_value = self.query, self.key - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - self.query, self.key) + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) + mock__c.rotary_embedding.assert_called_once() self.assertEqual(result_q.shape, self.query.shape) self.assertEqual(result_k.shape, self.key.shape) @@ -137,8 +145,9 @@ class TestRopeForwardOot(TestBase): non_contig_query = self.query.transpose(0, 1) non_contig_key = self.key.transpose(0, 1) - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - non_contig_query, non_contig_key) + result_q, result_k = self.layer.forward(self.positions, + non_contig_query, + non_contig_key) mock_npu_rotary.assert_called_once() self.assertEqual(result_q.shape, non_contig_query.shape) @@ -153,8 +162,7 @@ class TestRopeForwardOot(TestBase): # Test that NotImplementedError is raised when offsets is provided offsets = torch.tensor([1, 2, 3]) with self.assertRaises(NotImplementedError): - rope_forward_oot(self.mock_self, self.positions, self.query, - self.key, offsets) + self.layer.forward(self.positions, self.query, self.key, offsets) @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', @@ -168,11 +176,10 @@ class TestRopeForwardOot(TestBase): mock_get_ascend_config.return_value = mock_config # Test neox_style override - result_q, result_k = rope_forward_oot(self.mock_self, - self.positions, - self.query, - self.key, - is_neox_style_override=False) + result_q, result_k = self.layer.forward(self.positions, + self.query, + self.key, + is_neox_style_override=False) # Check that neox_style=False was passed to the NPU function args, kwargs = mock_npu_rotary.call_args @@ -190,98 +197,118 @@ class MockRopeModule: self.base = 1 -class TestNativeRopeDeepseekForward(TestBase): +class TestAscendDeepseekScalingRotaryEmbedding(TestBase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 1, 32, dtype=torch.float16) + self.key = torch.randn(3, 1, 32, dtype=torch.float16) + self.head_size = 32 + self.rotary_dim = self.head_size + self.max_position = 16 + self.rope_theta = 10000 + self.is_neox_style = True + self.scaling_factor = 1 + self.layer = None + + def _create_layer(self): + self.layer = DeepseekScalingRotaryEmbedding( + self.head_size, self.rotary_dim, self.max_position, + self.rope_theta, self.is_neox_style, self.scaling_factor, + torch.float16) + return self.layer + + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) + def test_native_rope_deepseek_forward_base(self, mock_npuplatform): + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() + with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot", + return_value=(self.query, + self.key)) as mock_rope_forward_oot: + q_pe, k_pe = self.layer.forward(self.positions, self.query, + self.key) + mock_rope_forward_oot.assert_called_once() + assert q_pe.shape == self.query.shape + assert k_pe.shape == self.key.shape @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot): - module = MockRopeModule() - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape - - @patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache') - @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) def test_native_rope_deepseek_forward_cache_handling( - self, mock_rope_forward_oot, mock_set_cache): + self, mock_npuplatform, mock_rope_forward_oot): + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() + self.layer.max_seq_len = 1024 # Test cache situation is true - module = MockRopeModule(max_seq_len=1024) - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) + with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache: + mock_rope_forward_oot.return_value = (self.query, self.key) - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, - positions, - query, - key, - max_seq_len=2048) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape + q_pe, k_pe = self.layer.forward(self.positions, + self.query, + self.key, + max_seq_len=2048) + mock_set_cache.assert_called_once() + assert q_pe.shape == self.query.shape + assert k_pe.shape == self.key.shape @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) def test_native_rope_deepseek_forward_key_reshaping( - self, mock_rope_forward_oot): - module = MockRopeModule() - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 128) + self, mock_npuplatform, mock_rope_forward_oot): + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() - mock_rope_forward_oot.return_value = (query, key) + key = torch.randn(1, 32) - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) + mock_rope_forward_oot.return_value = (self.query, key) - assert q_pe.shape == query.shape - assert k_pe.shape == (1, 128) - - @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_non_neox_style( - self, mock_rope_forward_oot): - module = MockRopeModule(is_neox_style=False) - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) - - assert q_pe.shape == query.shape + q_pe, k_pe = self.layer.forward(self.positions, self.query, key) + mock_rope_forward_oot.assert_called_once() + assert q_pe.shape == self.query.shape assert k_pe.shape == key.shape + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) + def test_native_rope_deepseek_forward_non_neox_style( + self, mock_npuplatform, mock_rope_forward_oot): + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() -class TestRotateHalf(TestBase): + mock_rope_forward_oot.return_value = (self.query, self.key) - def test_rotate_half_even_dim(self): - # Test with even dimension - x = torch.tensor([1.0, 2.0, 3.0, 4.0]) - expected = torch.tensor([-3.0, -4.0, 1.0, 2.0]) - result = rotate_half(x) - self.assertTrue(torch.allclose(result, expected)) + q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key) + mock_rope_forward_oot.assert_called_once() + assert q_pe.shape == self.query.shape + assert k_pe.shape == self.key.shape -class TestYarnFindCorrectionDim(TestBase): - - def test_basic_case(self): + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) + def test_basic_case(self, mock_npuplatform): # Test with standard values + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() num_rotations = 100 dim = 512 base = 10000 max_position_embeddings = 2048 - result = yarn_find_correction_dim(num_rotations, dim, base, - max_position_embeddings) + result = self.layer._yarn_find_correction_dim(num_rotations, dim, base, + max_position_embeddings) # Calculate expected value manually expected = (dim * torch.log( @@ -291,22 +318,27 @@ class TestYarnFindCorrectionDim(TestBase): self.assertTrue(torch.allclose(result, expected)) + @patch("vllm.platforms.current_platform.device_type", + new=torch.device("cpu")) + @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", + new_callable=PropertyMock) + def test_yarn_get_mscale(self, mock_npuplatform): + mock_npuplatform.device_type = torch.device("cpu") + self.layer = self._create_layer() -class TestYarnGetMscale(TestBase): + # test_scale_less_than_or_equal_1 + self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0) + self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0) + self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0) - def test_scale_less_than_or_equal_1(self): - self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) - self.assertEqual(yarn_get_mscale(scale=1.0), 1.0) - self.assertEqual(yarn_get_mscale(scale=0.999), 1.0) - - def test_scale_greater_than_1(self): + # test_scale_greater_than_1: test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), (math.e, 1.0, 1.0 + 0.1)] for scale, mscale, expected in test_cases: - result = yarn_get_mscale(scale, mscale) + result = self.layer._yarn_get_mscale(scale, mscale) self.assertAlmostEqual( result, expected, diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index db4fcc3..46a3ca8 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -356,13 +356,13 @@ class TestUtils(TestBase): # ascend custom op is not registered utils.register_ascend_customop() # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 6) + self.assertEqual(mock_customop.register_oot.call_count, 8) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 6) + self.assertEqual(mock_customop.register_oot.call_count, 8) class TestProfileExecuteDuration(TestBase): diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index c946e8d..a1e7417 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -17,12 +17,13 @@ import torch -import vllm_ascend.ops.activation # noqa import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa -import vllm_ascend.ops.rotary_embedding # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa +from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul +from vllm_ascend.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) class dummyFusionOp: @@ -47,3 +48,9 @@ def register_dummy_fusion_op() -> None: name="fused_add_rms_norm_static_fp8_quant") torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp( name="rms_norm_dynamic_per_token_quant") + + +__all__ = [ + "AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", + "AscendDeepseekScalingRotaryEmbedding" +] diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 806a210..079c356 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import enable_custom_op, is_310p @@ -89,167 +90,7 @@ def rope_forward_oot( return query.view(query_shape), key.view(key_shape) -def native_rope_deepseek_forward(self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) - if len(key.shape) == 2: - key = key[:, None, :] - # Note: we implement the non neox_style method with shuffle the last dim and neox style - # calculation method which is also more compute friendly to the ascend machine - # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py - neox_style = True - if self.is_neox_style is False: - b, h_q, d = query.shape - query = query.view(b, h_q, d // 2, 2).transpose(3, - 2).reshape(b, h_q, d) - b, h_k, d = key.shape - key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, - neox_style) - return q_pe, k_pe - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations, - dim, - base=10000, - max_position_embeddings=2048): - # Note: use torch instead of math to solve MTP compilation error. - return (dim * torch.log( - torch.tensor(max_position_embeddings) / - (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base))) - - -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -# Find dim range bounds based on rotations -def yarn_find_correction_range(low_rot, - high_rot, - dim, - base=10000, - max_position_embeddings=2048): - # Note: use torch instead of math to solve MTP compilation error. - low = torch.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = torch.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - # Note: use torch instead of max/min to solve MTP compilation error. - return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) - - -def yarn_linear_ramp_mask(min_value, max_value, dim): - # Note: The if conditional branch is not used here - # to solve MTP compilation error. - max_value += (min_value == max_value).float() * 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32) - - min_value) / (max_value - min_value) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids] - sin = sin[position_ids] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] - - if len(q.shape) == 3: - q = q[:, :, None, :] - if len(k.shape) == 2: - k = k[:, None, None, :] - elif len(k.shape) == 3: - k = k[:, :, None, :] - - b, h_q, s, d = q.shape - q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) - - b, h_k, s, d = k.shape - k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = q_embed.view(b, h_q, d) - k_embed = k_embed.view(b, h_k, d) - - return q_embed, k_embed - - -def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - dim = self.rotary_dim - - freq_extra = 1.0 / (self.base**( - torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - freq_inter = 1.0 / (self.scaling_factor * self.base**( - torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( - device=device, dtype=torch.float32) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) - - freqs = torch.outer(t, inv_freq) - cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale - sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale - cos_cached = cos_cached.to(dtype) - sin_cached = sin_cached.to(dtype) - cache = torch.cat([freqs.cos() * self.mscale, - freqs.sin() * self.mscale], - dim=-1).to(dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - self.register_buffer("cos_cached", cos_cached, persistent=False) - self.register_buffer("sin_cached", sin_cached, persistent=False) - - -def __set_cos_sin_cache(self, seq_len, device, dtype): +def set_cos_sin_cache(self, seq_len, device, dtype): inv_freq = 1.0 / (self.base**(torch.arange( 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * (1 / self.rotary_dim))) @@ -266,117 +107,275 @@ def __set_cos_sin_cache(self, seq_len, device, dtype): self.embed = F.embedding -_original_re_init = RotaryEmbedding.__init__ +class AscendRotaryEmbedding(RotaryEmbedding): - -def qwen_rope_init_func( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, -) -> None: - _original_re_init(self, head_size, rotary_dim, max_position_embeddings, - base, is_neox_style, dtype) - if get_ascend_config().torchair_graph_config.enabled: - __set_cos_sin_cache(self, - seq_len=max_position_embeddings, - device="npu", - dtype=dtype) - - -def rope_forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None, - max_seq_len: Optional[int] = None, - is_prefill: Optional[bool] = True, - is_qwen_torchair: Optional[bool] = False, -): - if get_ascend_config().torchair_graph_config.enabled \ - and is_qwen_torchair and not is_prefill: - if max_seq_len is not None and torch.gt(max_seq_len, - self.max_position_embeddings): - __set_cos_sin_cache(self, - seq_len=max_seq_len, - device=query.device, - dtype=torch.float32) - - # bsnd/bnsd - if positions is not None: - cos = self.embed(positions, self.cos) - sin = self.embed(positions, self.sin) - self.cos_embed = cos - self.sin_embed = sin - else: - cos = self.cos_embed - sin = self.sin_embed - - query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() - key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() - - cos = cos.unsqueeze(-2).unsqueeze(-2) - sin = sin.unsqueeze(-2).unsqueeze(-2) - - query = query.unsqueeze(1) - key = key.unsqueeze(1) - - q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( - query, key, cos, sin) - return q_embed.flatten(-2), k_embed.flatten(-2) - else: - return rope_forward_oot(self, positions, query, key, offsets, - is_neox_style_override, - is_qwen_torchair) # type: ignore - - -def deepseek_rope_init_func( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - mscale: float = 1, - mscale_all_dim: float = 0, -) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation. - self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) / - yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * - attn_factor) - super(DeepseekScalingRotaryEmbedding, - self).__init__(head_size, rotary_dim, max_position_embeddings, base, + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - _set_cos_sin_cache(self, - max_position_embeddings, - dtype=dtype, - device="npu") + if get_ascend_config().torchair_graph_config.enabled: + set_cos_sin_cache(self, + seq_len=max_position_embeddings, + device="npu", + dtype=dtype) + + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + max_seq_len: Optional[int] = None, + is_prefill: Optional[bool] = True, + is_qwen_torchair: Optional[bool] = False, + ): + if get_ascend_config().torchair_graph_config.enabled \ + and is_qwen_torchair and not is_prefill: + if max_seq_len is not None and torch.gt( + max_seq_len, self.max_position_embeddings): + set_cos_sin_cache(self, + seq_len=max_seq_len, + device=query.device, + dtype=torch.float32) + + # bsnd/bnsd + if positions is not None: + cos = self.embed(positions, self.cos) + sin = self.embed(positions, self.sin) + self.cos_embed = cos + self.sin_embed = sin + else: + cos = self.cos_embed + sin = self.sin_embed + + query = query.view(*query.shape[:-1], -1, + self.head_size).contiguous() + key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() + + cos = cos.unsqueeze(-2).unsqueeze(-2) + sin = sin.unsqueeze(-2).unsqueeze(-2) + + query = query.unsqueeze(1) + key = key.unsqueeze(1) + + q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( + query, key, cos, sin) + return q_embed.flatten(-2), k_embed.flatten(-2) + else: + return rope_forward_oot(self, positions, query, key, offsets, + is_neox_style_override, + is_qwen_torchair) # type: ignore -RotaryEmbedding.__init__ = qwen_rope_init_func -RotaryEmbedding.forward_oot = rope_forward +class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): -# Note: we adopt the native huggingface deepseek rope initialization code from -# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for -# its more ascend compute friendly -DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func -DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + # Note: we adopt the native huggingface deepseek rope initialization code from + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for + # its more ascend compute friendly + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + self._yarn_get_mscale(self.scaling_factor, float(mscale)) / + self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, + base, is_neox_style, dtype) + self.max_seq_len = max_position_embeddings + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=NPUPlatform.device_type, + dtype=dtype) + + def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def _rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _yarn_linear_ramp_mask(self, min_value, max_value, dim): + # Note: The if conditional branch is not used here + # to solve MTP compilation error. + max_value += (min_value == max_value).float() * 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - + min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Inverse dim formula to find dim based on number of rotations + def _yarn_find_correction_dim(self, + num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + return (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * + torch.log(torch.tensor(base))) + + # Find dim range bounds based on rotations + def _yarn_find_correction_range(self, + low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + low = torch.floor( + self._yarn_find_correction_dim(low_rot, dim, base, + max_position_embeddings)) + high = torch.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + # Note: use torch instead of max/min to solve MTP compilation error. + return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) + + # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb + def _apply_rotary_pos_emb(self, + q, + k, + cos, + sin, + position_ids, + unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids] + sin = sin[position_ids] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + if len(q.shape) == 3: + q = q[:, :, None, :] + if len(k.shape) == 2: + k = k[:, None, None, :] + elif len(k.shape) == 3: + k = k[:, :, None, :] + + b, h_q, s, d = q.shape + q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) + + b, h_k, s, d = k.shape + k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) + + q_embed = (q * cos) + (self._rotate_half(q) * sin) + k_embed = (k * cos) + (self._rotate_half(k) * sin) + + q_embed = q_embed.view(b, h_q, d) + k_embed = k_embed.view(b, h_k, d) + + return q_embed, k_embed + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.rotary_dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask( + low, high, dim // 2).to(device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - + inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len * self.scaling_factor, + device=device, + dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale + sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale + cos_cached = cos_cached.to(dtype) + sin_cached = sin_cached.to(dtype) + cache = torch.cat( + [freqs.cos() * self.mscale, + freqs.sin() * self.mscale], dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + def forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None): + if max_seq_len is not None and max_seq_len > self.max_seq_len: + self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, + 2).transpose(3, 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, + 2).reshape(b, h_k, d) + q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, + neox_style) + return q_pe, k_pe diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 9a1647c..0e799e6 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -478,9 +478,16 @@ def register_ascend_customop(): from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear, AscendMlpMergedColumnParallelLinear, AscendMlpRowParallelLinear) + from vllm_ascend.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU") CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, name="SiluAndMul") + CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding, + name="RotaryEmbedding") + CustomOp.register_oot( + _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding, + name="DeepseekScalingRotaryEmbedding") if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE: CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear, name="ColumnParallelLinear")