[bugfix] fix deepseek rope sincoscache re-generation (#2744)
### What this PR does / why we need it?
The current implementation repeatedly re-generates the `sin_cos_cache` in RoPE
whenever `kv_seqlen` > 4k, because the `sin_cos_cache` is only initialized with
a length of 4k.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
After this PR is merged, the `sin_cos_cache` will no longer grow inside the forward
function, so `test_native_rope_deepseek_forward_cache_handling` is no longer necessary.
- vLLM version: v0.10.1.1
- vLLM main:
60f0843ef8
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -157,6 +157,28 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
|||||||
args, kwargs = mock_npu_rotary.call_args
|
args, kwargs = mock_npu_rotary.call_args
|
||||||
self.assertFalse(args[-1])
|
self.assertFalse(args[-1])
|
||||||
|
|
||||||
|
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||||
|
return_value=False)
|
||||||
|
@patch('torch_npu._npu_rotary_embedding')
|
||||||
|
def test_rope_forward_oot_rotary_dim_less_than_head_size(
|
||||||
|
self, mock_npu_rotary, mock_custom_enabled):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.torchair_graph_config.enabled = False
|
||||||
|
|
||||||
|
# test case when rotary_dim < head_size
|
||||||
|
org_rotary_dim = self.layer.rotary_dim
|
||||||
|
self.layer.rotary_dim = self.layer.head_size // 2
|
||||||
|
|
||||||
|
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||||
|
self.key)
|
||||||
|
|
||||||
|
mock_npu_rotary.assert_called_once()
|
||||||
|
self.assertEqual(result_q.shape, self.query.shape)
|
||||||
|
self.assertEqual(result_k.shape, self.key.shape)
|
||||||
|
|
||||||
|
# restore rotary_dim
|
||||||
|
self.layer.rotary_dim = org_rotary_dim
|
||||||
|
|
||||||
|
|
||||||
class MockRopeModule:
|
class MockRopeModule:
|
||||||
|
|
||||||
@@ -207,28 +229,6 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
|||||||
assert q_pe.shape == self.query.shape
|
assert q_pe.shape == self.query.shape
|
||||||
assert k_pe.shape == self.key.shape
|
assert k_pe.shape == self.key.shape
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
|
||||||
@patch("vllm.platforms.current_platform.device_type",
|
|
||||||
new=torch.device("cpu"))
|
|
||||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
|
||||||
new_callable=PropertyMock)
|
|
||||||
def test_native_rope_deepseek_forward_cache_handling(
|
|
||||||
self, mock_npuplatform, mock_rope_forward_oot):
|
|
||||||
mock_npuplatform.device_type = torch.device("cpu")
|
|
||||||
self.layer = self._create_layer()
|
|
||||||
self.layer.max_seq_len = 1024
|
|
||||||
# Test cache situation is true
|
|
||||||
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
|
|
||||||
mock_rope_forward_oot.return_value = (self.query, self.key)
|
|
||||||
|
|
||||||
q_pe, k_pe = self.layer.forward(self.positions,
|
|
||||||
self.query,
|
|
||||||
self.key,
|
|
||||||
max_seq_len=2048)
|
|
||||||
mock_set_cache.assert_called_once()
|
|
||||||
assert q_pe.shape == self.query.shape
|
|
||||||
assert k_pe.shape == self.key.shape
|
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||||
@patch("vllm.platforms.current_platform.device_type",
|
@patch("vllm.platforms.current_platform.device_type",
|
||||||
new=torch.device("cpu"))
|
new=torch.device("cpu"))
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import torch
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||||
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
|
_set_cos_sin_cache, custom_rotary_embedding_enabled,
|
||||||
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
|
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
|
||||||
|
yarn_find_correction_dim, yarn_get_mscale)
|
||||||
|
|
||||||
|
|
||||||
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
||||||
@@ -200,6 +201,28 @@ class MockRopeModule:
|
|||||||
self.sin_cached = None
|
self.sin_cached = None
|
||||||
self.rotary_dim = 1
|
self.rotary_dim = 1
|
||||||
self.base = 1
|
self.base = 1
|
||||||
|
self.beta_fast = 32
|
||||||
|
self.beta_slow = 1
|
||||||
|
self.max_position_embeddings = 4096
|
||||||
|
self.mscale = 1.0
|
||||||
|
self.scaling_factor = 40
|
||||||
|
|
||||||
|
def register_buffer(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetSinCosCache(TestBase):
|
||||||
|
|
||||||
|
def test_set_cos_sin_cache(self):
|
||||||
|
module = MockRopeModule()
|
||||||
|
|
||||||
|
with patch.object(module, "register_buffer") as mock_register_buffer:
|
||||||
|
_set_cos_sin_cache(module,
|
||||||
|
1024,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
mock_register_buffer.assert_called()
|
||||||
|
|
||||||
|
|
||||||
class TestNativeRopeDeepseekForward(TestBase):
|
class TestNativeRopeDeepseekForward(TestBase):
|
||||||
@@ -220,30 +243,6 @@ class TestNativeRopeDeepseekForward(TestBase):
|
|||||||
assert q_pe.shape == query.shape
|
assert q_pe.shape == query.shape
|
||||||
assert k_pe.shape == key.shape
|
assert k_pe.shape == key.shape
|
||||||
|
|
||||||
@patch(
|
|
||||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
|
|
||||||
)
|
|
||||||
@patch(
|
|
||||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
|
||||||
def test_native_rope_deepseek_forward_cache_handling(
|
|
||||||
self, mock_rope_forward_oot, mock_set_cache):
|
|
||||||
# Test cache situation is true
|
|
||||||
module = MockRopeModule(max_seq_len=1024)
|
|
||||||
positions = torch.tensor([1, 2, 3])
|
|
||||||
query = torch.randn(1, 8, 128)
|
|
||||||
key = torch.randn(1, 8, 128)
|
|
||||||
|
|
||||||
mock_rope_forward_oot.return_value = (query, key)
|
|
||||||
|
|
||||||
q_pe, k_pe = native_rope_deepseek_forward(module,
|
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
max_seq_len=2048)
|
|
||||||
|
|
||||||
assert q_pe.shape == query.shape
|
|
||||||
assert k_pe.shape == key.shape
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||||
def test_native_rope_deepseek_forward_key_reshaping(
|
def test_native_rope_deepseek_forward_key_reshaping(
|
||||||
|
|||||||
@@ -168,8 +168,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
super(DeepseekScalingRotaryEmbedding,
|
super(DeepseekScalingRotaryEmbedding,
|
||||||
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
||||||
base, is_neox_style, dtype)
|
base, is_neox_style, dtype)
|
||||||
self.max_seq_len = max_position_embeddings
|
|
||||||
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||||
|
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||||
|
self._set_cos_sin_cache(self.max_seq_len,
|
||||||
device=NPUPlatform.device_type,
|
device=NPUPlatform.device_type,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
@@ -275,8 +277,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
|
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
dim = self.rotary_dim
|
dim = self.rotary_dim
|
||||||
|
|
||||||
freq_extra = 1.0 / (self.base**(
|
freq_extra = 1.0 / (self.base**(
|
||||||
@@ -297,9 +298,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
inv_freq_mask) + freq_extra * inv_freq_mask
|
inv_freq_mask) + freq_extra * inv_freq_mask
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
t = torch.arange(seq_len * self.scaling_factor,
|
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||||
device=device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||||
@@ -317,10 +316,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None):
|
||||||
max_seq_len: Optional[int] = None):
|
|
||||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
|
||||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
|
||||||
if len(key.shape) == 2:
|
if len(key.shape) == 2:
|
||||||
key = key[:, None, :]
|
key = key[:, None, :]
|
||||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||||
|
|||||||
@@ -93,10 +93,7 @@ def native_rope_deepseek_forward(self,
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None):
|
||||||
max_seq_len: Optional[int] = None):
|
|
||||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
|
||||||
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
|
|
||||||
if len(key.shape) == 2:
|
if len(key.shape) == 2:
|
||||||
key = key[:, None, :]
|
key = key[:, None, :]
|
||||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||||
@@ -211,8 +208,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
|||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
dim = self.rotary_dim
|
dim = self.rotary_dim
|
||||||
|
|
||||||
freq_extra = 1.0 / (self.base**(
|
freq_extra = 1.0 / (self.base**(
|
||||||
@@ -232,9 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|||||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
t = torch.arange(seq_len * self.scaling_factor,
|
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||||
device=device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||||
@@ -365,8 +359,7 @@ def deepseek_rope_init_func(
|
|||||||
super(DeepseekScalingRotaryEmbedding,
|
super(DeepseekScalingRotaryEmbedding,
|
||||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
self.max_seq_len = max_position_embeddings
|
|
||||||
_set_cos_sin_cache(self,
|
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||||
max_position_embeddings,
|
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||||
dtype=dtype,
|
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
|
||||||
device="npu")
|
|
||||||
|
|||||||
@@ -1198,9 +1198,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
else:
|
else:
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.decode.input_positions,
|
attn_metadata.decode.input_positions,
|
||||||
decode_q_pe.contiguous(),
|
decode_q_pe.contiguous(), decode_k_pe)
|
||||||
decode_k_pe,
|
|
||||||
max_seq_len=attn_metadata.decode.max_seq_lens)
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||||
@@ -1225,9 +1223,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
else:
|
else:
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.prefill.input_positions,
|
attn_metadata.prefill.input_positions,
|
||||||
prefill_q_pe.contiguous(),
|
prefill_q_pe.contiguous(), prefill_k_pe)
|
||||||
prefill_k_pe,
|
|
||||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
|
||||||
|
|
||||||
assert len(
|
assert len(
|
||||||
kv_cache
|
kv_cache
|
||||||
|
|||||||
Reference in New Issue
Block a user