[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com> Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -31,6 +31,10 @@ from sgl_kernel.elementwise import (
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.elementwise import gelu_quick
|
||||
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
|
||||
@@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
@@ -209,10 +209,34 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
if torch.version.hip is not None:

    def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        """Apply Quick-GELU elementwise: ``y = x * sigmoid(1.702 * x)``.

        The underlying HIP kernel performs 128-bit (16-byte) vectorized
        loads and stores, so the byte length of the last dimension must
        be divisible by 16.

        Args:
            input: Input tensor; last dim * itemsize must be a multiple of 16.
            out: Optional pre-allocated output tensor with the same shape
                as ``input``. Allocated automatically when omitted.

        Returns:
            The tensor holding the activation result (``out`` if given).

        Raises:
            ValueError: If the last dimension's byte length is not a
                multiple of 16.
        """
        # Guard the kernel's 16-byte vector-access requirement up front.
        row_bytes = input.shape[-1] * input.dtype.itemsize
        if row_bytes % 16 != 0:
            raise ValueError(
                f"The last dimension ({input.shape[-1]}) x itemsize "
                f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
            )

        if out is None:
            out = torch.empty_like(input)
        else:
            assert input.shape == out.shape, f"{input.shape} != {out.shape}"

        torch.ops.sgl_kernel.gelu_quick(out, input)
        return out
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user