diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 0b68a2191..c5c285ca0 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
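
Below is a minimal sketch (not the patched sglang code itself) of the dispatch pattern this change introduces: the vLLM `rotary_embedding` op is no longer imported at module level, but lazily in `__init__`, and only when the fused `sgl_kernel` path cannot be used (non-CUDA builds, or head sizes outside 64/128/256/512). It is then bound as `self.vllm_rotary_embedding` for the forward fallback. The class name and the exact forward call signatures below are simplified assumptions for illustration.

```python
import torch

_FUSED_HEAD_SIZES = (64, 128, 256, 512)  # head sizes handled by the fused sgl_kernel path


class RotaryEmbeddingSketch(torch.nn.Module):
    """Simplified stand-in showing only the backend selection, not the full RoPE logic."""

    def __init__(self, head_size: int, is_cuda: bool):
        super().__init__()
        self.head_size = head_size
        self.is_cuda = is_cuda
        if not is_cuda or head_size not in _FUSED_HEAD_SIZES:
            # Lazy import: vLLM is only pulled in when the fallback will actually be
            # used, so CUDA builds with supported head sizes never touch it.
            from vllm._custom_ops import rotary_embedding  # assumes vLLM is installed

            self.vllm_rotary_embedding = rotary_embedding

    def forward(self, positions, query, key, cos_sin_cache, is_neox_style=True):
        if self.is_cuda and self.head_size in _FUSED_HEAD_SIZES:
            # Fast path: fused RoPE kernel from sgl_kernel (CUDA only).
            from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

            apply_rope_with_cos_sin_cache_inplace(
                positions=positions,
                query=query,
                key=key,
                head_size=self.head_size,
                cos_sin_cache=cos_sin_cache,
                is_neox=is_neox_style,
            )
        else:
            # Fallback path: the lazily bound vLLM op (argument order assumed here).
            self.vllm_rotary_embedding(
                positions, query, key, self.head_size, cos_sin_cache, is_neox_style
            )
        return query, key
```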