[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)

Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
Hubert Lu
2025-07-24 23:44:28 -07:00
committed by GitHub
parent 7ad6b766c5
commit af4b9bae95
17 changed files with 1226 additions and 61 deletions

View File

@@ -31,6 +31,10 @@ from sgl_kernel.elementwise import (
silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,

View File

@@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
return out
@@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
return out
@@ -209,10 +209,34 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
return out
if torch.version.hip is not None:
def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
"""
Quick-GELU: y = x * sigmoid(1.702 * x)
The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
so the last-dimension byte length must be a multiple of 16 bytes.
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError(
f"The last dimension ({input.shape[-1]}) x itemsize "
f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
)
if out is not None:
assert input.shape == out.shape, f"{input.shape} != {out.shape}"
else:
out = torch.empty_like(input)
torch.ops.sgl_kernel.gelu_quick(out, input)
return out
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,