Fix RotaryEmbedding for fp32 input (#11843)

This commit is contained in:
zhangdonghao-zdh
2025-10-21 10:56:48 +08:00
committed by GitHub
parent 8374a96e49
commit fb6cc7b000

View File

@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
  if not _is_cuda:
  cache = cache.to(dtype)
- if (
+ if dtype == torch.float32 or (
  (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
  and not (_is_cpu and _is_cpu_amx_available)
  and not _is_xpu
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
  offsets: Optional[torch.Tensor] = None,
  fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+ if (
+     _is_cuda
+     and (self.head_size in [64, 128, 256, 512])
+     and self.dtype != torch.float32
+ ):
  apply_rope_with_cos_sin_cache_inplace(
  positions=positions,
  query=query,