[AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs (#7135)
Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Co-authored-by: HAI <hixiao@gmail.com>
sgl-kernel/include/hip_act_and_mul.cuh (Normal file, 87 lines)
@@ -0,0 +1,87 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "utils.h"

#define kBitsToLoad 128
#define kBytesToLoad (kBitsToLoad / 8)

// Adapted from
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)

namespace sgl_hip {
namespace activation {

template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * 2 * d;

#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
    y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  const int64_t remaining_offset = d - d % (stride * vec_size);
  // process the remaining elements
#pragma unroll 1
  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
    T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
    out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
  }
}

template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * d;

#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]);
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  const int64_t remaining_offset = d - d % (stride * vec_size);
  // process the remaining elements
#pragma unroll 1
  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
    T x = input[offset + remaining_offset + idx];
    out[token_idx * d + remaining_offset + idx] = Activation(x);
  }
}

}  // namespace activation
}  // namespace sgl_hip
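For context, a minimal sketch of how a host side might instantiate this template. It is not part of the commit; `silu` and `launch_silu_and_mul` are names assumed here, the real launchers live in sgl-kernel's csrc, and the 256-thread cap is an assumption. The activation uses the castToFloat/castFromFloat helpers from hip_math_def.h below.

#include <algorithm>
#include <hip/hip_runtime.h>

// Sketch only: a SiLU activation in the shape the template expects,
// computed in fp32 via the cast helpers.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  float f = castToFloat(x);
  return castFromFloat<T>(f / (1.0f + expf(-f)));
}

// Sketch only: one block per token; input is [num_tokens, 2 * d] and
// out is [num_tokens, d].
void launch_silu_and_mul(__half* out, const __half* input, int num_tokens, int d, hipStream_t stream) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(__half);  // 8 halves per 128-bit load
  dim3 grid(num_tokens);
  dim3 block(std::min<uint32_t>(d / vec_size, 256u));
  sgl_hip::activation::act_and_mul_kernel<__half, silu<__half>>
      <<<grid, block, 0, stream>>>(out, input, d);
}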
sgl-kernel/include/hip_math_def.h (Normal file, 94 lines)
@@ -0,0 +1,94 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#if defined(__HIP_PLATFORM_AMD__)

#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

namespace amdgpu {

template <typename T>
__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);

template <typename srcDtype, typename destDtype>
__forceinline__ __device__ destDtype cast(srcDtype val);

// specialization
template <>
__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) {
  return __shfl_xor(var, laneMask, width);
}

template <>
__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) {
  return __shfl_xor(var, laneMask, width);
}

template <>
__forceinline__ __device__ float cast(float val) {
  return val;
}

template <>
__forceinline__ __device__ float cast(__half val) {
  return __half2float(val);
}

template <>
__forceinline__ __device__ float cast(__hip_bfloat16 val) {
  return __bfloat162float(val);
}

template <>
__forceinline__ __device__ __half cast(float fval) {
  return __float2half(fval);
}

template <>
__forceinline__ __device__ __hip_bfloat16 cast(float fval) {
  return __float2bfloat16(fval);
}

}  // namespace amdgpu

template <typename T>
__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {
  return amdgpu::shfl_xor_sync(mask, var, laneMask, width);
}

template <typename srcDtype>
__device__ __forceinline__ float castToFloat(srcDtype val) {
  return amdgpu::cast<srcDtype, float>(val);
}

template <typename dstDtype>
__device__ __forceinline__ dstDtype castFromFloat(float val) {
  return amdgpu::cast<float, dstDtype>(val);
}

// operator overload to support flashinfer
__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) {
  __half h_x = x;
  __half h_y = y;
  return __hmul(h_x, h_y);
}

#endif
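To illustrate the helpers above: activations in these kernels widen to fp32, compute, and narrow back to the storage type. A sketch, not from the commit; `quick_gelu` is a name assumed here.

// Sketch only: x * sigmoid(1.702 * x), computed in fp32 regardless of T.
template <typename T>
__device__ __forceinline__ T quick_gelu(const T& x) {
  float f = castToFloat(x);  // __half / __hip_bfloat16 / float all widen to float
  return castFromFloat<T>(f / (1.0f + expf(-1.702f * f)));
}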
sgl-kernel/include/hip_vec_dtypes.h (Normal file, 101 lines)
@@ -0,0 +1,101 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#if USE_ROCM

#include <hip/hip_bf16.h>
#include <hip/hip_common.h>
#include <hip/hip_fp16.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__

namespace sgl_hip {

template <typename float_t, size_t vec_size>
struct vec_t;

template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src);

template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src);

template <typename float_t, size_t vec_size>
struct vec_t {
  SGL_HIP_INLINE float_t& operator[](size_t i);
  SGL_HIP_INLINE const float_t& operator[](size_t i) const;
  SGL_HIP_INLINE float_t* ptr();

  SGL_HIP_INLINE void load(const float_t* ptr);
  SGL_HIP_INLINE void store(float_t* ptr) const;

  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src);
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr);
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const;
};

}  // namespace sgl_hip

// **** impl *****

namespace sgl_hip {

template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_load_impl(vec_t<dstDtype, vec_size>& dst, const srcDtype* src_ptr) {
  if constexpr (std::is_same<srcDtype, dstDtype>::value) {
    dst.load(src_ptr);
  } else {
    vec_t<srcDtype, vec_size> tmp;
    tmp.load(src_ptr);
    dst.cast_from(tmp);
  }
}

template <typename srcDtype, typename dstDtype, size_t vec_size>
SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t<srcDtype, vec_size>& src) {
  if constexpr (std::is_same<srcDtype, dstDtype>::value) {
    src.store(dst_ptr);
  } else {
    vec_t<dstDtype, vec_size> tmp;
    tmp.cast_from(src);
    tmp.store(dst_ptr);
  }
}

template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_load(const T* ptr) {
  cast_load_impl(*this, ptr);
}

template <typename float_t, size_t vec_size>
template <typename T>
SGL_HIP_INLINE void vec_t<float_t, vec_size>::cast_store(T* ptr) const {
  cast_store_impl(ptr, *this);
}

}  // namespace sgl_hip

#include "impl/hip_vec_bf16_impl.h"
#include "impl/hip_vec_fp32_impl.h"
#include "impl/hip_vec_half_impl.h"

#endif
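A usage sketch of the vec_t interface, not from the commit: load/store move whole 128-bit packets, while cast_load/cast_store additionally convert element types, assuming the cross-dtype cast_from_impl specializations the impl headers rely on.

// Sketch only: read 8 halves in one vector load, sum them in fp32,
// then write the same lanes out as bf16 through cast_store.
__device__ float sum8_and_convert(const __half* x, __hip_bfloat16* y) {
  sgl_hip::vec_t<__half, 8> v;
  v.load(x);  // a single uint4 transaction (see hip_vec_half_impl.h)
  float acc = 0.f;
#pragma unroll
  for (size_t i = 0; i < 8; ++i) {
    acc += castToFloat(v[i]);
  }
  v.cast_store(y);  // half -> bf16, element-wise via cast_from_impl
  return acc;
}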
sgl-kernel/include/impl/hip_vec_bf16_impl.h (Normal file, 177 lines)
@@ -0,0 +1,177 @@
#pragma once

#if USE_ROCM

#include <hip/hip_bf16.h>
#include <hip/hip_common.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;

__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) {
  __hip_bfloat162 t;
  t.x = x;
  t.y = y;
  return t;
}

namespace sgl_hip {

// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
  nv_bfloat16 data;

  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
  data = *ptr;
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
  *ptr = data;
}

// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
  nv_bfloat162 data;

  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
  data = *((nv_bfloat162*)ptr);
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
  *((nv_bfloat162*)ptr) = data;
}

// nv_bfloat16 x 4
template <>
struct vec_t<nv_bfloat16, 4> {
  uint2 data;

  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr);
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
  data = *((uint2*)ptr);
}

SGL_HIP_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
  *((uint2*)ptr) = data;
}

// nv_bfloat16 x 8 or more

template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
  uint4 data[vec_size / 8];

  SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)data)[i];
  }
  SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)data)[i];
  }
  SGL_HIP_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  SGL_HIP_INLINE void load(const nv_bfloat16* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  SGL_HIP_INLINE void store(nv_bfloat16* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

}  // namespace sgl_hip

#endif
sgl-kernel/include/impl/hip_vec_fp32_impl.h (Normal file, 129 lines)
@@ -0,0 +1,129 @@
#pragma once

#if USE_ROCM

#include <hip/hip_common.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

namespace sgl_hip {

template <>
struct vec_t<float, 1> {
  float data;

  SGL_HIP_INLINE float& operator[](size_t i) {
    return ((float*)(&data))[i];
  }
  SGL_HIP_INLINE const float& operator[](size_t i) const {
    return ((const float*)(&data))[i];
  }
  SGL_HIP_INLINE float* ptr() {
    return reinterpret_cast<float*>(&data);
  }
  SGL_HIP_INLINE void load(const float* ptr);
  SGL_HIP_INLINE void store(float* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<float, 1>::load(const float* ptr) {
  data = *ptr;
}

SGL_HIP_INLINE void vec_t<float, 1>::store(float* ptr) const {
  *ptr = data;
}

// float x 2

template <>
struct vec_t<float, 2> {
  float2 data;

  SGL_HIP_INLINE float& operator[](size_t i) {
    return ((float*)(&data))[i];
  }
  SGL_HIP_INLINE const float& operator[](size_t i) const {
    return ((const float*)(&data))[i];
  }
  SGL_HIP_INLINE float* ptr() {
    return reinterpret_cast<float*>(&data);
  }
  SGL_HIP_INLINE void load(const float* ptr);
  SGL_HIP_INLINE void store(float* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<float, 2>::load(const float* ptr) {
  data = *((float2*)ptr);
}

SGL_HIP_INLINE void vec_t<float, 2>::store(float* ptr) const {
  *((float2*)ptr) = data;
}

// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
  float4 data[vec_size / 4];

  SGL_HIP_INLINE float& operator[](size_t i) {
    return ((float*)(data))[i];
  }
  SGL_HIP_INLINE const float& operator[](size_t i) const {
    return ((const float*)(data))[i];
  }
  SGL_HIP_INLINE float* ptr() {
    return reinterpret_cast<float*>(&data);
  }
  SGL_HIP_INLINE void load(const float* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = ((float4*)ptr)[i];
    }
  }
  SGL_HIP_INLINE void store(float* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

}  // namespace sgl_hip

#endif
sgl-kernel/include/impl/hip_vec_half_impl.h (Normal file, 172 lines)
@@ -0,0 +1,172 @@
#pragma once

#if USE_ROCM

#include <hip/hip_common.h>
#include <hip/hip_fp16.h>

// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)

using half = __half;
using half2 = __half2;

namespace sgl_hip {

// half x 1
template <>
struct vec_t<half, 1> {
  half data;

  SGL_HIP_INLINE half& operator[](size_t i) {
    return ((half*)(&data))[i];
  }
  SGL_HIP_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  SGL_HIP_INLINE half* ptr() {
    return reinterpret_cast<half*>(&data);
  }
  SGL_HIP_INLINE void load(const half* ptr);
  SGL_HIP_INLINE void store(half* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<half, 1>::load(const half* ptr) {
  data = *ptr;
}

SGL_HIP_INLINE void vec_t<half, 1>::store(half* ptr) const {
  *ptr = data;
}

// half x 2
template <>
struct vec_t<half, 2> {
  half2 data;

  SGL_HIP_INLINE half& operator[](size_t i) {
    return ((half*)(&data))[i];
  }
  SGL_HIP_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  SGL_HIP_INLINE half* ptr() {
    return reinterpret_cast<half*>(&data);
  }
  SGL_HIP_INLINE void load(const half* ptr);
  SGL_HIP_INLINE void store(half* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<half, 2>::load(const half* ptr) {
  data = *((half2*)ptr);
}

SGL_HIP_INLINE void vec_t<half, 2>::store(half* ptr) const {
  *((half2*)ptr) = data;
}

// half x 4

template <>
struct vec_t<half, 4> {
  uint2 data;

  SGL_HIP_INLINE half& operator[](size_t i) {
    return ((half*)(&data))[i];
  }
  SGL_HIP_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  SGL_HIP_INLINE half* ptr() {
    return reinterpret_cast<half*>(&data);
  }
  SGL_HIP_INLINE void load(const half* ptr);
  SGL_HIP_INLINE void store(half* ptr) const;
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

SGL_HIP_INLINE void vec_t<half, 4>::load(const half* ptr) {
  data = *((uint2*)ptr);
}

SGL_HIP_INLINE void vec_t<half, 4>::store(half* ptr) const {
  *((uint2*)ptr) = data;
}

// half x 8 or more

template <size_t vec_size>
struct vec_t<half, vec_size> {
  uint4 data[vec_size / 8];

  SGL_HIP_INLINE half& operator[](size_t i) {
    return ((half*)data)[i];
  }
  SGL_HIP_INLINE const half& operator[](size_t i) const {
    return ((const half*)data)[i];
  }
  SGL_HIP_INLINE half* ptr() {
    return reinterpret_cast<half*>(&data);
  }
  SGL_HIP_INLINE void load(const half* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  SGL_HIP_INLINE void store(half* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  SGL_HIP_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  SGL_HIP_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
};

}  // namespace sgl_hip
#endif
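One invariant worth noting across the three impl headers: every full-width specialization packs its lanes into a single 16-byte payload, which is what lets hip_act_and_mul.cuh fix kBitsToLoad at 128. A compile-time check of that layout, not part of the commit:

// Sketch only: each full vec_t payload below is one 128-bit packet.
static_assert(sizeof(sgl_hip::vec_t<float, 4>) == 16, "float4 payload");
static_assert(sizeof(sgl_hip::vec_t<__half, 8>) == 16, "uint4 payload");
static_assert(sizeof(sgl_hip::vec_t<__hip_bfloat16, 8>) == 16, "uint4 payload");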
@@ -138,9 +138,10 @@ void sgl_fused_add_rmsnorm(
     torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input);
+
 void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor q,
     at::Tensor k,
@@ -151,6 +152,9 @@ void apply_rope_pos_ids_cos_sin_cache(
     bool interleave,
     int64_t cuda_stream);
 
+#ifdef USE_ROCM
+void gelu_quick(at::Tensor& out, const at::Tensor& input);
+#endif
 /*
  * From csrc/gemm
  */
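A usage sketch of the re-declared ops, not from the commit: the gated activations consume a [..., 2 * d] input and produce a [..., d] output, and the ROCm build no longer threads a raw stream handle through the signatures.

// Sketch only: shapes and dtypes are illustrative.
void activation_example() {
  int64_t num_tokens = 4, d = 4096;
  at::Tensor input = at::randn({num_tokens, 2 * d}, at::device(at::kCUDA).dtype(at::kHalf));
  at::Tensor out = at::empty({num_tokens, d}, input.options());
  silu_and_mul(out, input);  // out[i, :] = silu(input[i, :d]) * input[i, d:]

#ifdef USE_ROCM
  at::Tensor act = at::randn({num_tokens, d}, input.options());
  at::Tensor act_out = at::empty_like(act);
  gelu_quick(act_out, act);  // act-only kernel: same trailing dimension in and out
#endif
}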
@@ -19,7 +19,20 @@ limitations under the License.
 #include <cuda_runtime.h>
 #include <torch/all.h>
 
 #include <sstream>
+#ifdef USE_ROCM
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+#define _DISPATCH_CASE_F16(c_type, ...) \
+  case at::ScalarType::Half: {          \
+    using c_type = __half;              \
+    return __VA_ARGS__();               \
+  }
+
+#define _DISPATCH_CASE_BF16(c_type, ...) \
+  case at::ScalarType::BFloat16: {       \
+    using c_type = __hip_bfloat16;       \
+    return __VA_ARGS__();                \
+  }
+#endif  // USE_ROCM
 
 #ifndef USE_ROCM
 // Adapt from FlashInfer
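For context, a sketch of how these case macros compose, mirroring the DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16 pattern later in this file; `run_act_kernel` is a hypothetical launcher, not part of the commit.

// Sketch only: each _DISPATCH_CASE_* expands to a `case` label that binds
// c_type to the matching HIP ctype inside an immediately-invoked lambda.
template <typename c_type>
void run_act_kernel(at::Tensor& out, at::Tensor& input);  // hypothetical

bool dispatch_example(at::Tensor& out, at::Tensor& input) {
  return [&]() -> bool {
    switch (input.scalar_type()) {
      _DISPATCH_CASE_F16(c_type, [&] {
        run_act_kernel<c_type>(out, input);
        return true;
      })
      _DISPATCH_CASE_BF16(c_type, [&] {
        run_act_kernel<c_type>(out, input);
        return true;
      })
      default:
        return false;
    }
  }();
}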
@@ -31,7 +44,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_F16(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_F16
 
 #ifdef FLASHINFER_ENABLE_BF16
 #define _DISPATCH_CASE_BF16(c_type, ...) \
@@ -41,7 +54,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_BF16(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_BF16
 
 #ifdef FLASHINFER_ENABLE_FP8_E4M3
 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
@@ -51,7 +64,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_FP8_E4M3
 
 #ifdef FLASHINFER_ENABLE_FP8_E5M2
 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
@@ -61,7 +74,7 @@ limitations under the License.
   }
 #else
 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
-#endif
+#endif  // FLASHINFER_ENABLE_FP8_E5M2
 
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool { \
@@ -197,7 +210,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 inline bool is_float8_tensor(const at::Tensor& tensor) {
   return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
 }
-#endif
+#endif  // USE_ROCM
 
 struct cuda_error : public std::runtime_error {
   /**
@@ -267,7 +280,6 @@ inline bool getEnvEnablePDL() {
 #define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
 #endif
 
-#ifndef USE_ROCM
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool { \
     switch (pytorch_dtype) { \
@@ -284,7 +296,6 @@ inline bool getEnvEnablePDL() {
       return false; \
     } \
   }()
-#endif
 
 #define DISPATCH_CASE_INTEGRAL_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
@@ -297,52 +308,99 @@ inline bool getEnvEnablePDL() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
 
 #ifndef USE_ROCM
 #define WARP_SIZE 32
 #else
 #define WARP_SIZE warpSize  // 64
 #endif
 
+#if defined(__HIP_PLATFORM_AMD__)
+
+#include "hip_math_def.h"
+#include "hip_vec_dtypes.h"
+
+#else
+
+template <typename srcDtype>
+__device__ __forceinline__ float castToFloat(srcDtype val) {
+  return static_cast<srcDtype>(val);
+}
+
+template <typename dstDtype>
+__device__ __forceinline__ dstDtype castFromFloat(float val) {
+  return static_cast<dstDtype>(val);
+}
+
+#endif
+
+// add FP8 support
+
 #ifndef USE_ROCM
 #include <c10/util/Float8_e4m3fn.h>
 using FP8_TYPE = c10::Float8_e4m3fn;
 C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#include <c10/util/Float8_e4m3fnuz.h>
+#else  // USE_ROCM
+#if HIP_FP8_TYPE_FNUZ
+#include <c10/util/Float8_e4m3fnuz.h>
 using FP8_TYPE = c10::Float8_e4m3fnuz;
 constexpr auto FP8_E4M3_MAX = 224.0f;
-#endif
+#else
+#if HIP_FP8_TYPE_E4M3
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#error "fp8 is not supported in this processor (arch < gfx942)."
+#endif  // HIP_FP8_TYPE_E4M3
+#endif  // HIP_FP8_TYPE_FNUZ
+#endif  // USE_ROCM
 
 #define FULL_MASK 0xffffffff
 
-#ifndef USE_ROCM
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+#ifndef USE_ROCM
   float old;
   old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
   return old;
+#else
+  int* addr_as_i = (int*)addr;
+  int old = *addr_as_i, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
+  } while (assumed != old);
+  return __int_as_float(old);
+#endif
 }
 
-__device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
-  return max_value;
+__device__ __forceinline__ float warpReduceMax(float value) {
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
+  return value;
 }
 
-__device__ __forceinline__ float blockReduceMax(float max_value) {
+__device__ __forceinline__ float blockReduceMax(float value) {
   static __shared__ float warpLevelMaxs[WARP_SIZE];
   const int laneId = threadIdx.x % WARP_SIZE;
   const int warpId = threadIdx.x / WARP_SIZE;
 
-  max_value = warpReduceMax(max_value);
+  value = warpReduceMax(value);
 
-  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
+  if (laneId == 0) warpLevelMaxs[warpId] = value;
   __syncthreads();
 
-  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
-  if (warpId == 0) max_value = warpReduceMax(max_value);
+  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+  if (warpId == 0) value = warpReduceMax(value);
 
-  return max_value;
+  return value;
 }
-#endif
 
 // Pads to a multiple of `alignment` rows.
 inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
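A sketch of how the reduction and atomic helpers above compose into a device-wide max. It is not from the commit; note that blockReduceMax uses 0 as its fill value, so the sketch assumes non-negative inputs, and `global_max` must be seeded by the caller.

// Sketch only: per-thread partial max, block max via warp shuffles,
// then one atomic per block to combine across the grid.
__global__ void reduce_max_kernel(const float* __restrict__ x, int n, float* global_max) {
  float v = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    v = fmaxf(v, x[i]);
  }
  v = blockReduceMax(v);  // result is valid in warp 0 after the two-stage reduce
  if (threadIdx.x == 0) {
    atomicMaxFloat(global_max, v);  // CAS loop on ROCm, int atomics on CUDA
  }
}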