[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)

Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Co-authored-by: HAI <hixiao@gmail.com>
Author: Hubert Lu
Committed via GitHub: 2025-07-24 23:44:28 -07:00
Commit: af4b9bae95 (parent: 7ad6b766c5)
17 changed files with 1226 additions and 61 deletions
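
For reference, the four activations wired up by this commit compute the following. The block below is a host-side C++ sketch of the math only (the reference names like silu_ref are illustrative, not part of the commit); the device kernels in the diff evaluate the same expressions in float and cast back to the storage type.

#include <cmath>
#include <cstdio>

// Reference math for the four activations (mirrors the device code below).
float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }  // x * sigmoid(x)
float gelu_ref(float x) { return x * 0.5f * (1.0f + std::erf(x * 0.70710678f)); }  // exact, erf-based
float gelu_tanh_ref(float x) {  // tanh approximation of gelu
  return x * 0.5f * (1.0f + std::tanh(0.7978845608028654f * (x + 0.044715f * x * x * x)));
}
float gelu_quick_ref(float x) { return x / (1.0f + std::exp(-1.702f * x)); }  // x * sigmoid(1.702 * x)

int main() {
  for (float x : {-2.0f, -0.5f, 0.5f, 2.0f}) {
    std::printf("x=%5.2f  silu=%8.5f  gelu=%8.5f  gelu_tanh=%8.5f  gelu_quick=%8.5f\n",
                x, silu_ref(x), gelu_ref(x), gelu_tanh_ref(x), gelu_quick_ref(x));
  }
  return 0;
}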


@@ -78,13 +78,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
   m.def(
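
The schema change above is behavioral, not just cosmetic: the explicit `int cuda_stream` argument is gone, and the launchers below read the current PyTorch stream themselves. A hedged sketch of what that means for a C++ caller that wants a non-default stream (run_on_side_stream is illustrative, not part of the commit):

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Sketch: stream selection now happens through PyTorch's stream guard,
// not through an explicit handle passed to the op.
void run_on_side_stream() {
  at::cuda::CUDAStream side = at::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side);
  // Any of the activation ops invoked here launches on `side`, because the
  // new launchers call at::cuda::getCurrentCUDAStream() internally.
}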


@@ -13,70 +13,158 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+#ifndef USE_ROCM
 #include <flashinfer/activation.cuh>
 #include "pytorch_extension_utils.h"
+#include "utils.h"
-using namespace flashinfer;
+#else
+#include "hip_act_and_mul.cuh"
+#endif
-__device__ __forceinline__ float silu(const float& val) {
-  return val / (1.0f + __expf(-val));
+// Adapted from flashinfer activation
+// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
+namespace detail {
+template <typename T>
+__device__ __forceinline__ float to_f32(const T& x) {
+#if USE_ROCM
+  return castToFloat(x);
+#else
+  return static_cast<float>(x);
+#endif
 }
-__device__ __forceinline__ float gelu(const float& val) {
+template <typename T>
+__device__ __forceinline__ T from_f32(float f32) {
+#if USE_ROCM
+  return castFromFloat<T>(f32);
+#else
+  return static_cast<T>(f32);
+#endif
+}
+}  // namespace detail
+template <typename T>
+__device__ __forceinline__ T silu(const T& x) {
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val)));
+}
+template <typename T>
+__device__ __forceinline__ T gelu(const T& x) {
   constexpr float kAlpha = M_SQRT1_2;
-  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
 }
-__device__ __forceinline__ float gelu_tanh(const float& val) {
-  const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
-  return val * cdf;
+// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
+template <typename T>
+__device__ __forceinline__ T gelu_quick_act(const T& x) {
+  float f32_val = detail::to_f32(x);
+  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val * 1.702f)));
 }
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
+template <typename T>
+__device__ __forceinline__ T gelu_tanh(const T& x) {
+  constexpr float kAlpha = 0.044715f;
+  constexpr float kBeta = 0.7978845608028654f;
+  float f32_val = detail::to_f32(x);
+  const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
+  return detail::from_f32<T>(f32_val * cdf);
+}
+void silu_and_mul(at::Tensor& out, at::Tensor& input) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
     flashinfer::activation::act_and_mul_kernel<c_type, silu>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#endif
     return true;
   });
 }
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
     flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#endif
     return true;
   });
 }
+void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
+    uint32_t vec_size = 16 / sizeof(c_type);
+    dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
+    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#else
+    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
+#endif
+    return true;
+  });
+}
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
-  int d = input.size(-1) / 2;
+#if USE_ROCM
+void gelu_quick(at::Tensor& out, const at::Tensor& input) {
+  int d = input.size(-1);
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
+    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
     return true;
   });
 }
+#endif
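
The vectorized kernel bodies are not in this hunk: the NVIDIA path comes from flashinfer/activation.cuh and the ROCm path from the new hip_act_and_mul.cuh. As a minimal scalar sketch of the contract the launchers assume (one block per token, input of shape [num_tokens, 2*d], output [num_tokens, d]); act_and_mul_sketch is illustration only, not the shipped kernel:

// act_and_mul contract: out[t][i] = Act(input[t][i]) * input[t][d + i].
template <typename T, T (*Act)(const T&)>
__global__ void act_and_mul_sketch(T* __restrict__ out, const T* __restrict__ input, int d) {
  const int64_t token = blockIdx.x;     // grid = num_tokens
  const T* gate = input + token * 2 * d;  // first half: activated
  const T* up = gate + d;                 // second half: multiplier
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    out[token * d + i] = Act(gate[i]) * up[i];
  }
}

gelu_quick, by contrast, is launched through act_only_kernel with d = input.size(-1): it is purely elementwise and does not halve the last dimension.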


@@ -19,6 +19,20 @@ limitations under the License.
 #include "sgl_kernel_ops.h"
 TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+  /*
+   * From csrc/activation
+   */
+  m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+  m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+  m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
   /*
    * From csrc/allreduce
    */
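
Taken together with the registrations above, the shape contracts implied by the launchers look like the following hedged torch C++ sketch (shapes_example is illustrative, not part of the commit):

#include <torch/torch.h>

void shapes_example() {
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  // The *_and_mul ops halve the last dimension: [num_tokens, 2*d] -> [num_tokens, d].
  at::Tensor x = torch::randn({8, 256}, opts);
  at::Tensor out = torch::empty({8, 128}, opts);
  // From Python: torch.ops.sgl_kernel.silu_and_mul(out, x)
  // gelu_quick is elementwise, so out matches the input shape exactly.
  at::Tensor q_in = torch::randn({8, 128}, opts);
  at::Tensor q_out = torch::empty_like(q_in);
  // From Python: torch.ops.sgl_kernel.gelu_quick(q_out, q_in)
}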