[bugfix] fix deepseek rope sincoscache re-generation (#2744)
### What this PR does / why we need it?
The current implementation will result in duplicate generation of
`sin_cos_cache` in rope when `kv_seqlen` > 4k, because the
initialization length of the `sin_cos_cache` is only 4k.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
After this PR merged, sin_cos_cache will not increase in forward func,
so `test_native_rope_deepseek_forward_cache_handling` is not necessary.
- vLLM version: v0.10.1.1
- vLLM main:
60f0843ef8
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -168,8 +168,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
self._set_cos_sin_cache(self.max_seq_len,
|
||||
device=NPUPlatform.device_type,
|
||||
dtype=dtype)
|
||||
|
||||
@@ -275,8 +277,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
@@ -297,9 +298,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
@@ -317,10 +316,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
|
||||
Reference in New Issue
Block a user