[BugFix] fix length of sin/cos cache in rope (#1266)
This PR fixes a bug in which the sin/cos cache was constructed shorter than the model's maximum positional-embedding length. Closes: https://github.com/vllm-project/vllm-ascend/issues/1038 Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -219,7 +219,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|||||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
t = torch.arange(seq_len * self.scaling_factor,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||||
|
|||||||
Reference in New Issue
Block a user