[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com> Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -19,7 +19,20 @@ limitations under the License.
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <sstream>
|
||||
#ifdef USE_ROCM
|
||||
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
|
||||
// ROCm dispatch case for at::ScalarType::Half: aliases `c_type` to `__half` and
// returns the result of invoking the callable supplied through `__VA_ARGS__`.
// Intended to be expanded inside a `switch` over a PyTorch scalar type.
#define _DISPATCH_CASE_F16(c_type, ...) \
|
||||
case at::ScalarType::Half: { \
|
||||
using c_type = __half; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
// ROCm dispatch case for at::ScalarType::BFloat16: aliases `c_type` to
// `__hip_bfloat16` and returns the result of invoking the callable supplied
// through `__VA_ARGS__`. Intended to be expanded inside a `switch` over a
// PyTorch scalar type.
#define _DISPATCH_CASE_BF16(c_type, ...) \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using c_type = __hip_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Adapted from FlashInfer
|
||||
@@ -31,7 +44,7 @@ limitations under the License.
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_F16(c_type, ...)
|
||||
#endif
|
||||
#endif // FLASHINFER_ENABLE_F16
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_BF16
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...) \
|
||||
@@ -41,7 +54,7 @@ limitations under the License.
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...)
|
||||
#endif
|
||||
#endif // FLASHINFER_ENABLE_BF16
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E4M3
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
|
||||
@@ -51,7 +64,7 @@ limitations under the License.
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
|
||||
#endif
|
||||
#endif // FLASHINFER_ENABLE_FP8_E4M3
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E5M2
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
|
||||
@@ -61,7 +74,7 @@ limitations under the License.
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
|
||||
#endif
|
||||
#endif // FLASHINFER_ENABLE_FP8_E5M2
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
@@ -197,7 +210,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
// Reports whether `tensor` stores one of the two FP8 element types checked
// here: Float8_e4m3fn or Float8_e5m2.
inline bool is_float8_tensor(const at::Tensor& tensor) {
  const at::ScalarType dtype = tensor.scalar_type();
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    return true;
  }
  return dtype == at::ScalarType::Float8_e5m2;
}
|
||||
#endif
|
||||
#endif // USE_ROCM
|
||||
|
||||
struct cuda_error : public std::runtime_error {
|
||||
/**
|
||||
@@ -267,7 +280,6 @@ inline bool getEnvEnablePDL() {
|
||||
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
@@ -284,7 +296,6 @@ inline bool getEnvEnablePDL() {
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
@@ -297,52 +308,99 @@ inline bool getEnvEnablePDL() {
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
|
||||
// Integer division of x by y rounded up, computed as (x + y - 1) / y.
// Assumes positive integer operands. `y` is expanded twice, so avoid passing
// expressions with side effects.
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize // 64
|
||||
#endif
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#include "hip_math_def.h"
|
||||
#include "hip_vec_dtypes.h"
|
||||
|
||||
#else
|
||||
|
||||
/**
 * Converts a value of type `srcDtype` to `float` (non-HIP build path).
 *
 * @tparam srcDtype Source element type; must be convertible to `float`.
 * @param val Value to convert.
 * @return `val` converted to `float`.
 */
template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
  // Cast to the destination type. The previous `static_cast<srcDtype>(val)` was a
  // no-op that deferred the real conversion to an implicit one on return, which
  // fails for source types exposing only an explicit conversion to float.
  return static_cast<float>(val);
}
|
||||
|
||||
/**
 * Converts a `float` to the requested destination type (non-HIP build path).
 *
 * @tparam dstDtype Destination element type; must be constructible from `float`
 *                  via `static_cast`.
 * @param val Value to convert.
 * @return `val` converted to `dstDtype`.
 */
template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
  using Result = dstDtype;
  return static_cast<Result>(val);
}
|
||||
|
||||
#endif
|
||||
|
||||
// add FP8 support
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
|
||||
#else // USE_ROCM
|
||||
|
||||
#if HIP_FP8_TYPE_FNUZ
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
#endif
|
||||
#else
|
||||
#if HIP_FP8_TYPE_E4M3
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#error "fp8 is not supported in this processor (arch < gfx942)."
|
||||
#endif // HIP_FP8_TYPE_E4M3
|
||||
#endif // HIP_FP8_TYPE_FNUZ
|
||||
#endif // USE_ROCM
|
||||
|
||||
#define FULL_MASK 0xffffffff
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
#ifndef USE_ROCM
|
||||
float old;
|
||||
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
return old;
|
||||
#else
|
||||
int* addr_as_i = (int*)addr;
|
||||
int old = *addr_as_i, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
|
||||
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
|
||||
return max_value;
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float max_value) {
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
max_value = warpReduceMax(max_value);
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) max_value = warpReduceMax(max_value);
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return max_value;
|
||||
return value;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Pads to a multiple of `alignment` rows.
|
||||
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
|
||||
|
||||
Reference in New Issue
Block a user