From 2c685e3b61b0dae7ab26a95913cee1ba37bdd7c6 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Fri, 9 May 2025 12:55:57 +0800 Subject: [PATCH] [Bugfix] Correct method call for _set_cos_sin_cache (#774) This change ensures proper functionality for longer sequences by correctly invoking the _set_cos_sin_cache method with self as the first argument. For example, with DeepSeek R1, if this change isn't made, the program will crash when the input sequence exceeds 4096. Signed-off-by: Jade Zheng --- vllm_ascend/ops/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 0dbe940..0c2a00a 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -82,7 +82,7 @@ def native_rope_deepseek_forward(self, offsets: Optional[torch.Tensor] = None, max_seq_len: Optional[int] = None): if max_seq_len is not None and max_seq_len > self.max_seq_len: - self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style