Fix RotaryEmbedding for fp32 input (#11843)
@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if (
+        if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not _is_xpu
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
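
The second hunk carries the visible behavioural change: the fused apply_rope_with_cos_sin_cache_inplace path is now taken only on CUDA, for supported head sizes, and only when the dtype is not float32, so fp32 inputs fall through to the fallback implementation. A minimal standalone sketch of that dispatch predicate (the helper name and its boolean arguments are illustrative stand-ins for the module-level flags, not part of the actual module):

import torch

# Illustrative helper mirroring the dispatch condition after this patch;
# `is_cuda` stands in for the module-level `_is_cuda` flag.
def use_fused_rope(is_cuda: bool, head_size: int, dtype: torch.dtype) -> bool:
    return is_cuda and head_size in [64, 128, 256, 512] and dtype != torch.float32

# fp16/bf16 on CUDA with a supported head size still uses the fused kernel...
assert use_fused_rope(True, 128, torch.bfloat16)
# ...while fp32 now takes the non-fused fallback path.
assert not use_fused_rope(True, 128, torch.float32)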