Fix RotaryEmbedding when using Triton backend for EXAONE-3.5-2.4B (#4064)
This commit is contained in:
@@ -148,7 +148,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if _is_cuda_available:
|
if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
|
||||||
apply_rope_with_cos_sin_cache_inplace(
|
apply_rope_with_cos_sin_cache_inplace(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
query=query,
|
query=query,
|
||||||
|
|||||||
Reference in New Issue
Block a user