diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index ea47c04..4b76dce 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -138,8 +138,8 @@ class AscendRotaryEmbedding(RotaryEmbedding): forward_context = get_forward_context() is_first_layer = forward_context.is_first_layer # Generate cos and sin outside layers to avoid repeated calculation. - if is_neox_style and \ - self.head_size == 128: + if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128: if is_first_layer: cos_sin = self.cos_sin_cache.index_select(0, positions) last_dim = cos_sin.size()[-1]