Fix RotaryEmbedding for fp32 input (#11843)
@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if (
+        if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not _is_xpu
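
In this first hunk, dtype == torch.float32 is added in front of the existing platform checks, so an fp32 rope always enters the guarded setup branch, which previously was reached only when no fused kernel applied. Below is a minimal sketch, not the repo's code: a hypothetical helper mirroring the new predicate, with explicit flags standing in for the module-level _is_* globals.

    import torch

    def _use_fallback_rope(dtype, head_size, is_cuda, is_npu, is_cpu, cpu_amx_available, is_xpu):
        # Hypothetical helper; mirrors the __init__ condition after this commit.
        return dtype == torch.float32 or (
            (not (is_cuda or is_npu) or head_size not in [64, 128, 256, 512])
            and not (is_cpu and cpu_amx_available)
            and not is_xpu
        )

    # fp32 now always selects the fallback branch, even on CUDA with a supported head size.
    assert _use_fallback_rope(torch.float32, 128, True, False, False, False, False)
    # bf16 on CUDA with head_size 128 still bypasses it, as before.
    assert not _use_fallback_rope(torch.bfloat16, 128, True, False, False, False, False)
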
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
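
The second hunk applies the same guard at dispatch time: the fused apply_rope_with_cos_sin_cache_inplace kernel is now used only for non-fp32 dtypes, and fp32 inputs fall through to the existing non-fused path. Below is a hedged usage sketch, not part of the commit; it assumes the module path sglang.srt.layers.rotary_embedding, this constructor signature, flattened [num_tokens, num_heads * head_size] query/key layouts, and a CUDA install of SGLang.

    import torch
    from sglang.srt.layers.rotary_embedding import RotaryEmbedding

    rope = RotaryEmbedding(
        head_size=128,
        rotary_dim=128,
        max_position_embeddings=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
    ).cuda()

    positions = torch.arange(8, device="cuda")
    query = torch.randn(8, 32 * 128, device="cuda", dtype=torch.float32)
    key = torch.randn(8, 8 * 128, device="cuda", dtype=torch.float32)

    # With dtype == torch.float32, this call now takes the fallback path
    # instead of the fused in-place CUDA kernel.
    query, key = rope.forward_cuda(positions, query, key)
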