From fb6cc7b0008123c2de158927ba8d0b39ecb37940 Mon Sep 17 00:00:00 2001
From: zhangdonghao-zdh
Date: Tue, 21 Oct 2025 10:56:48 +0800
Subject: [PATCH] Fix RotaryEmbedding for fp32 input (#11843)

---
 python/sglang/srt/layers/rotary_embedding.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index cfc626e1d..83842b0cc 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if (
+        if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not _is_xpu
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
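
Note (not part of the patch): the change routes fp32 inputs away from the fused apply_rope_with_cos_sin_cache_inplace kernel and onto the native PyTorch path that consumes the precomputed cos/sin cache, and it keeps that cache in fp32 by skipping the dtype down-cast in the same case. The sketch below is a minimal, self-contained illustration of that dispatch, not the sglang implementation; rotate_neox, apply_rope_native, and apply_rope are hypothetical stand-ins, and the fused-kernel branch is stubbed out.

    import torch


    def rotate_neox(x: torch.Tensor) -> torch.Tensor:
        # NeoX-style rotation: split the head dim in half and swap with a sign flip.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rope_native(
        positions: torch.Tensor,      # [num_tokens], int64 position ids
        query: torch.Tensor,          # [num_tokens, num_heads * head_size]
        cos_sin_cache: torch.Tensor,  # [max_position, head_size]: cos half | sin half
        head_size: int,
    ) -> torch.Tensor:
        # Plain PyTorch, so it works for any dtype, including fp32.
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(1)  # [num_tokens, 1, head_size]
        sin = sin.repeat(1, 2).unsqueeze(1)
        q = query.view(query.shape[0], -1, head_size)
        q = q * cos + rotate_neox(q) * sin
        return q.flatten(-2)


    def apply_rope(positions, query, cos_sin_cache, head_size, dtype, is_cuda):
        # Mirrors the patched condition: the fused kernel is taken only on CUDA,
        # for supported head sizes, and never for fp32 activations.
        if is_cuda and head_size in (64, 128, 256, 512) and dtype != torch.float32:
            raise NotImplementedError("the fused CUDA kernel would be called here")
        return apply_rope_native(positions, query, cos_sin_cache, head_size)


    if __name__ == "__main__":
        # fp32 example: falls through to the native path even on CUDA.
        head_size, num_heads, max_pos, num_tokens = 64, 4, 128, 3
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2).float() / head_size))
        freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, head_size]
        q = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float32)
        out = apply_rope(torch.tensor([0, 1, 2]), q, cache, head_size, torch.float32, is_cuda=True)
        print(out.shape)  # torch.Size([3, 256])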