[sgl-kernel] Support float64 moe_sum_reduce cuda kernel (#11068)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
@@ -12,25 +13,36 @@
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ float to_float(T x) {
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ float to_float<half>(half x) {
|
||||
return __half2float(x);
|
||||
template <typename T>
|
||||
__device__ __forceinline__ opmath_t<T> to_acc(T x) {
|
||||
return static_cast<opmath_t<T>>(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T from_float(float x) {
|
||||
__device__ __forceinline__ T from_acc(opmath_t<T> x) {
|
||||
return static_cast<T>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ half from_float<half>(float x) {
|
||||
__device__ __forceinline__ opmath_t<at::Half> to_acc<at::Half>(at::Half x) {
|
||||
return __half2float(__nv_half(x));
|
||||
}
|
||||
template <>
|
||||
__device__ __forceinline__ at::Half from_acc<at::Half>(opmath_t<at::Half> x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ opmath_t<at::BFloat16> to_acc<at::BFloat16>(at::BFloat16 x) {
|
||||
return __bfloat162float(__nv_bfloat16(x));
|
||||
}
|
||||
template <>
|
||||
__device__ __forceinline__ at::BFloat16 from_acc<at::BFloat16>(opmath_t<at::BFloat16> x) {
|
||||
return __float2bfloat16_rn(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T ldg_cg(const T* p) {
|
||||
return __ldg(p);
|
||||
@@ -111,22 +123,22 @@ __global__ void moe_sum_reduce_kernel_warp_token_topk(
|
||||
const int64_t stride_token,
|
||||
const int64_t stride_topk,
|
||||
const int64_t out_stride_token,
|
||||
const float scale) {
|
||||
const opmath_t<scalar_t> scale) {
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
const int lane = threadIdx.x % 32;
|
||||
const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
|
||||
if (t >= token_num) return;
|
||||
|
||||
for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
|
||||
float acc = 0.f;
|
||||
opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
|
||||
const int64_t base = t * stride_token + d;
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOPK; ++k) {
|
||||
acc += to_float<scalar_t>(ldg_cg(&x[base + (int64_t)k * stride_topk]));
|
||||
acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
|
||||
}
|
||||
acc *= scale;
|
||||
y[t * out_stride_token + d] = from_float<scalar_t>(acc);
|
||||
y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,23 +151,79 @@ __global__ void moe_sum_reduce_kernel(
|
||||
const int64_t stride_token,
|
||||
const int64_t stride_topk,
|
||||
const int64_t out_stride_token,
|
||||
const float scale) {
|
||||
const opmath_t<scalar_t> scale) {
|
||||
for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
|
||||
for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
|
||||
const int64_t base = t * stride_token + d;
|
||||
float acc = 0.f;
|
||||
opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOPK; ++k) {
|
||||
acc += to_float<scalar_t>(x[base + (int64_t)k * stride_topk]);
|
||||
acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
|
||||
}
|
||||
|
||||
acc *= scale;
|
||||
y[t * out_stride_token + d] = from_float<scalar_t>(acc);
|
||||
y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------- general-topk fallback kernels --------------------
|
||||
// small-token
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_sum_reduce_kernel_general(
|
||||
const scalar_t* __restrict__ x,
|
||||
scalar_t* __restrict__ y,
|
||||
const int64_t token_num,
|
||||
const int64_t hidden_dim,
|
||||
const int64_t stride_token,
|
||||
const int64_t stride_topk,
|
||||
const int64_t out_stride_token,
|
||||
const int topk_num,
|
||||
const opmath_t<scalar_t> scale) {
|
||||
for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
|
||||
for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
|
||||
const int64_t base = t * stride_token + d;
|
||||
opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
|
||||
#pragma unroll 1
|
||||
for (int k = 0; k < topk_num; ++k) {
|
||||
acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
|
||||
}
|
||||
acc *= scale;
|
||||
y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// warp-per-token
|
||||
template <typename scalar_t, int WARPS_PER_BLOCK>
|
||||
__global__ void moe_sum_reduce_kernel_warp_token_general(
|
||||
const scalar_t* __restrict__ x,
|
||||
scalar_t* __restrict__ y,
|
||||
const int64_t token_num,
|
||||
const int64_t hidden_dim,
|
||||
const int64_t stride_token,
|
||||
const int64_t stride_topk,
|
||||
const int64_t out_stride_token,
|
||||
const int topk_num,
|
||||
const opmath_t<scalar_t> scale) {
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
const int lane = threadIdx.x % 32;
|
||||
const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
|
||||
if (t >= token_num) return;
|
||||
|
||||
for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
|
||||
opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
|
||||
const int64_t base = t * stride_token + d;
|
||||
#pragma unroll 1
|
||||
for (int k = 0; k < topk_num; ++k) {
|
||||
acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
|
||||
}
|
||||
acc *= scale;
|
||||
y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) {
|
||||
TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
|
||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
@@ -175,8 +243,6 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
|
||||
const int64_t in_stride_topk = input.stride(1);
|
||||
const int64_t out_stride_token = output.stride(0);
|
||||
|
||||
const float scale = static_cast<float>(routed_scaling_factor);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const bool fast_bf16_vec_ok = (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 8 == 0);
|
||||
@@ -198,6 +264,7 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const float scale = static_cast<float>(routed_scaling_factor);
|
||||
moe_sum_reduce_warp_per_token_vec_kernel<WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const at::BFloat16*>(input.data_ptr<at::BFloat16>()),
|
||||
reinterpret_cast<at::BFloat16*>(output.data_ptr<at::BFloat16>()),
|
||||
@@ -209,32 +276,12 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
|
||||
out_stride_token,
|
||||
scale);
|
||||
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (bf16 vec) launch failed");
|
||||
return;
|
||||
}
|
||||
|
||||
const bool per_token_use_one_warp = (token_num > 128);
|
||||
|
||||
auto dispatch_topk = [&](auto&& launch_kernel) {
|
||||
switch (topk_num) {
|
||||
case 2:
|
||||
launch_kernel(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 4:
|
||||
launch_kernel(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 8:
|
||||
launch_kernel(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
case 9:
|
||||
launch_kernel(std::integral_constant<int, 9>{});
|
||||
break;
|
||||
default:
|
||||
launch_kernel(std::integral_constant<int, -1>{});
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if (!per_token_use_one_warp) {
|
||||
// ---------- small-token ----------
|
||||
const int block_size = 256;
|
||||
@@ -245,28 +292,55 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
|
||||
dim3 block(block_size);
|
||||
dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y));
|
||||
|
||||
#define LAUNCH_SMALL_TOKEN_KERNEL(TOPK) \
|
||||
moe_sum_reduce_kernel<scalar_t_, TOPK><<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t_>(), \
|
||||
output.data_ptr<scalar_t_>(), \
|
||||
token_num, \
|
||||
hidden_dim, \
|
||||
in_stride_token, \
|
||||
in_stride_topk, \
|
||||
out_stride_token, \
|
||||
scale);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
|
||||
using scalar_t_ = scalar_t;
|
||||
using acc_t_ = opmath_t<scalar_t_>;
|
||||
const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
|
||||
|
||||
auto lauch_small_token_kernel = [&](auto topk_c) {
|
||||
constexpr int TK = decltype(topk_c)::value;
|
||||
|
||||
moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t_>(),
|
||||
output.data_ptr<scalar_t_>(),
|
||||
token_num,
|
||||
hidden_dim,
|
||||
in_stride_token,
|
||||
in_stride_topk,
|
||||
out_stride_token,
|
||||
scale);
|
||||
};
|
||||
dispatch_topk(lauch_small_token_kernel);
|
||||
switch (topk_num) {
|
||||
case 2:
|
||||
LAUNCH_SMALL_TOKEN_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SMALL_TOKEN_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SMALL_TOKEN_KERNEL(8);
|
||||
break;
|
||||
case 9:
|
||||
LAUNCH_SMALL_TOKEN_KERNEL(9);
|
||||
break;
|
||||
default: // launch general kernel
|
||||
moe_sum_reduce_kernel_general<scalar_t_><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t_>(),
|
||||
output.data_ptr<scalar_t_>(),
|
||||
token_num,
|
||||
hidden_dim,
|
||||
in_stride_token,
|
||||
in_stride_topk,
|
||||
out_stride_token,
|
||||
static_cast<int>(topk_num),
|
||||
scale);
|
||||
}
|
||||
});
|
||||
#undef LAUNCH_SMALL_TOKEN_KERNEL
|
||||
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (small-token) launch failed");
|
||||
|
||||
} else {
|
||||
// ---------- warp-token ----------
|
||||
// ---------- warp-per-token ----------
|
||||
constexpr int WARPS_PER_BLOCK = 4;
|
||||
constexpr int THREADS = WARPS_PER_BLOCK * 32;
|
||||
|
||||
@@ -279,25 +353,51 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
|
||||
dim3 block(THREADS);
|
||||
dim3 grid(static_cast<unsigned>(gx), static_cast<unsigned>(gy));
|
||||
|
||||
#define LAUNCH_WARP_PER_TOKEN_KERNEL(TOPK) \
|
||||
moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TOPK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t_>(), \
|
||||
output.data_ptr<scalar_t_>(), \
|
||||
token_num, \
|
||||
hidden_dim, \
|
||||
in_stride_token, \
|
||||
in_stride_topk, \
|
||||
out_stride_token, \
|
||||
scale);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
|
||||
using scalar_t_ = scalar_t;
|
||||
using acc_t_ = opmath_t<scalar_t_>;
|
||||
const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
|
||||
|
||||
auto launch_large_token_kernel = [&](auto topk_c) {
|
||||
constexpr int TK = decltype(topk_c)::value;
|
||||
|
||||
moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t_>(),
|
||||
output.data_ptr<scalar_t_>(),
|
||||
token_num,
|
||||
hidden_dim,
|
||||
in_stride_token,
|
||||
in_stride_topk,
|
||||
out_stride_token,
|
||||
scale);
|
||||
};
|
||||
dispatch_topk(launch_large_token_kernel);
|
||||
switch (topk_num) {
|
||||
case 2:
|
||||
LAUNCH_WARP_PER_TOKEN_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_WARP_PER_TOKEN_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_WARP_PER_TOKEN_KERNEL(8);
|
||||
break;
|
||||
case 9:
|
||||
LAUNCH_WARP_PER_TOKEN_KERNEL(9);
|
||||
break;
|
||||
default: // launch general kernel
|
||||
moe_sum_reduce_kernel_warp_token_general<scalar_t_, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t_>(),
|
||||
output.data_ptr<scalar_t_>(),
|
||||
token_num,
|
||||
hidden_dim,
|
||||
in_stride_token,
|
||||
in_stride_topk,
|
||||
out_stride_token,
|
||||
static_cast<int>(topk_num),
|
||||
scale);
|
||||
}
|
||||
});
|
||||
#undef LAUNCH_WARP_PER_TOKEN_KERNEL
|
||||
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (warp-token) launch failed");
|
||||
}
|
||||
TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user