From 4df8df5b945da2eb29f95f46ad85836b8c71bf62 Mon Sep 17 00:00:00 2001 From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Date: Mon, 8 Sep 2025 22:03:34 +0800 Subject: [PATCH] [bugfix] fix deepseek rope sincoscache re-generation (#2744) ### What this PR does / why we need it? The current implementation will result in duplicate generation of `sin_cos_cache` in rope when `kv_seqlen` > 4k, because the initialization length of the `sin_cos_cache` is only 4k. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? After this PR merged, sin_cos_cache will not increase in forward func, so `test_native_rope_deepseek_forward_cache_handling` is not necessary. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/60f0843ef8fb4b0c4e6788acc042873a0a2ea2a1 Signed-off-by: zzzzwwjj <1183291235@qq.com> --- tests/ut/ops/test_rotary_embedding.py | 44 ++++++++-------- .../ops/test_torchair_rotary_embedding.py | 51 +++++++++---------- vllm_ascend/ops/rotary_embedding.py | 18 +++---- .../torchair/ops/torchair_rotary_embedding.py | 21 +++----- vllm_ascend/torchair/torchair_mla.py | 8 +-- 5 files changed, 63 insertions(+), 79 deletions(-) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index eb48c81..39bc76f 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -157,6 +157,28 @@ class TestAscendRotaryEmbedding(unittest.TestCase): args, kwargs = mock_npu_rotary.call_args self.assertFalse(args[-1]) + @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_rotary_dim_less_than_head_size( + self, mock_npu_rotary, mock_custom_enabled): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + # test case when rotary_dim < head_size + org_rotary_dim = self.layer.rotary_dim + self.layer.rotary_dim = self.layer.head_size // 2 + + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) + + mock_npu_rotary.assert_called_once() + self.assertEqual(result_q.shape, self.query.shape) + self.assertEqual(result_k.shape, self.key.shape) + + # restore rotary_dim + self.layer.rotary_dim = org_rotary_dim + class MockRopeModule: @@ -207,28 +229,6 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase): assert q_pe.shape == self.query.shape assert k_pe.shape == self.key.shape - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_native_rope_deepseek_forward_cache_handling( - self, mock_npuplatform, mock_rope_forward_oot): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - self.layer.max_seq_len = 1024 - # Test cache situation is true - with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache: - mock_rope_forward_oot.return_value = (self.query, self.key) - - q_pe, k_pe = self.layer.forward(self.positions, - self.query, - self.key, - max_seq_len=2048) - mock_set_cache.assert_called_once() - assert q_pe.shape == self.query.shape - assert k_pe.shape == self.key.shape - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') @patch("vllm.platforms.current_platform.device_type", new=torch.device("cpu")) diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py index e7c68f7..ce74dee 100644 --- a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py +++ b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py @@ -5,8 +5,9 @@ import torch from tests.ut.base import TestBase from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( - custom_rotary_embedding_enabled, native_rope_deepseek_forward, - rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale) + _set_cos_sin_cache, custom_rotary_embedding_enabled, + native_rope_deepseek_forward, rope_forward_oot, rotate_half, + yarn_find_correction_dim, yarn_get_mscale) class TestCustomRotaryEmbeddingEnabled(TestBase): @@ -200,6 +201,28 @@ class MockRopeModule: self.sin_cached = None self.rotary_dim = 1 self.base = 1 + self.beta_fast = 32 + self.beta_slow = 1 + self.max_position_embeddings = 4096 + self.mscale = 1.0 + self.scaling_factor = 40 + + def register_buffer(self): + pass + + +class TestSetSinCosCache(TestBase): + + def test_set_cos_sin_cache(self): + module = MockRopeModule() + + with patch.object(module, "register_buffer") as mock_register_buffer: + _set_cos_sin_cache(module, + 1024, + device="cpu", + dtype=torch.bfloat16) + + mock_register_buffer.assert_called() class TestNativeRopeDeepseekForward(TestBase): @@ -220,30 +243,6 @@ class TestNativeRopeDeepseekForward(TestBase): assert q_pe.shape == query.shape assert k_pe.shape == key.shape - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache' - ) - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_cache_handling( - self, mock_rope_forward_oot, mock_set_cache): - # Test cache situation is true - module = MockRopeModule(max_seq_len=1024) - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, - positions, - query, - key, - max_seq_len=2048) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape - @patch( 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') def test_native_rope_deepseek_forward_key_reshaping( diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 89e2bc7..b982a7e 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -168,8 +168,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - self._set_cos_sin_cache(seq_len=max_position_embeddings, + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + self._set_cos_sin_cache(self.max_seq_len, device=NPUPlatform.device_type, dtype=dtype) @@ -275,8 +277,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): return q_embed, k_embed - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len + def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -297,9 +298,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -317,10 +316,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py index 5793288..766ae5f 100644 --- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py +++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py @@ -93,10 +93,7 @@ def native_rope_deepseek_forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style @@ -211,8 +208,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len +def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -232,9 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -365,8 +359,7 @@ def deepseek_rope_init_func( super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - _set_cos_sin_cache(self, - max_position_embeddings, - dtype=dtype, - device="npu") + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 95ca3bd..80ada4d 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -1198,9 +1198,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) + decode_q_pe.contiguous(), decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ @@ -1225,9 +1223,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): else: prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) + prefill_q_pe.contiguous(), prefill_k_pe) assert len( kv_cache