[BugFix] fix length of sin/cos cache in rope (#1266)

This PR fixes a bug where the sin/cos cache was constructed shorter than the
model's maximum positional embedding length.

Closes: https://github.com/vllm-project/vllm-ascend/issues/1038

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-06-17 23:14:25 +08:00
committed by GitHub
parent afc8edb046
commit d7e19ed57a

View File

@@ -219,7 +219,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len, device=device, dtype=torch.float32)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale