Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)
Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
import sgl_kernel
|
||||
from sgl_kernel import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
prepare_moe_input,
|
||||
@@ -151,8 +152,8 @@ def cutlass_fused_experts_fp8(
|
||||
k,
|
||||
)
|
||||
|
||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
||||
rep_a1_scales = a1_scale[a_map]
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
@@ -206,9 +207,9 @@ def cutlass_fused_experts_fp8(
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
return (
|
||||
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
result = torch.empty((m, k), device=device, dtype=out_dtype)
|
||||
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
|
||||
3
sgl-kernel/csrc/common_extension.cc
Normal file → Executable file
3
sgl-kernel/csrc/common_extension.cc
Normal file → Executable file
@@ -195,7 +195,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
12
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file → Executable file
12
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file → Executable file
@@ -174,10 +174,10 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
// bool use_small_config = a[0].size(0) <= 128;
|
||||
struct MmaConfig1 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using MmaTileShape = Shape<_256, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
@@ -214,7 +214,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
torch::Tensor scales_a_t = scales_a.t();
|
||||
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
|
||||
|
||||
if (a.size(0) <= 512 && a.size(1) >= 2048) {
|
||||
if (a.size(0) <= 2048 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
@@ -247,7 +247,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
expert_offsets,
|
||||
workspace);
|
||||
output = output_t.t();
|
||||
} else if (a.size(0) > 512 && a.size(1) >= 2048) {
|
||||
} else if (a.size(0) > 2048 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
|
||||
114
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file → Executable file
114
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file → Executable file
@@ -252,3 +252,117 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
|
||||
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void apply_shuffle_mul_sum_kernel(
|
||||
const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride]
|
||||
scalar_t* __restrict__ output_tensor, // [m, row_stride]
|
||||
const int32_t* __restrict__ permutation, // [m * topk]
|
||||
int m,
|
||||
int topk,
|
||||
int row_stride,
|
||||
const scalar_t* __restrict__ factors) // [m * topk] or nullptr
|
||||
{
|
||||
int i = blockIdx.x; // [0, m * topk)
|
||||
int d = threadIdx.x; // [0, row_stride)
|
||||
|
||||
if (i >= m || d >= row_stride) return;
|
||||
|
||||
scalar_t sum_val = 0.0;
|
||||
|
||||
for (int j = 0; j < topk; ++j) {
|
||||
int index_2d = i * topk + j;
|
||||
int src_row = permutation[index_2d];
|
||||
if (src_row >= m) continue;
|
||||
|
||||
scalar_t val = input_tensor[src_row * row_stride + d];
|
||||
|
||||
scalar_t factor = 1.0;
|
||||
if (factors != nullptr) {
|
||||
factor = factors[index_2d];
|
||||
}
|
||||
|
||||
sum_val += factor * val;
|
||||
}
|
||||
|
||||
output_tensor[i * row_stride + d] = sum_val;
|
||||
}
|
||||
|
||||
void get_apply_shuffle_mul_sum_caller(
|
||||
const torch::Tensor& input_tensor, // [m * topk, row_stride], bf16/f16
|
||||
torch::Tensor& output_tensor, // [m, row_stride], bf16/f16
|
||||
const torch::Tensor& permutation, // [m * topk], int32
|
||||
const std::optional<torch::Tensor>& factors_opt) // optional [m * topk], bf16/f16
|
||||
{
|
||||
TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
|
||||
TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
|
||||
TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
|
||||
|
||||
int m = output_tensor.size(0);
|
||||
int topk = int(permutation.size(0) / m);
|
||||
int row_stride = output_tensor.size(1);
|
||||
|
||||
TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
|
||||
|
||||
dim3 block(std::min(256, row_stride));
|
||||
dim3 grid(m); // blockIdx.x = j, blockIdx.y = i
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
|
||||
|
||||
const int32_t* perm_ptr = permutation.data_ptr<int32_t>();
|
||||
|
||||
void* factors_ptr = nullptr;
|
||||
if (factors_opt.has_value()) {
|
||||
TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
|
||||
TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
|
||||
factors_ptr = factors_opt->data_ptr();
|
||||
}
|
||||
|
||||
if (output_tensor.scalar_type() == at::ScalarType::Half) {
|
||||
const at::Half* factor_data = static_cast<const at::Half*>(factors_ptr);
|
||||
apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
|
||||
input_tensor.data_ptr<at::Half>(),
|
||||
output_tensor.data_ptr<at::Half>(),
|
||||
perm_ptr,
|
||||
m,
|
||||
topk,
|
||||
row_stride,
|
||||
static_cast<const at::Half*>(factors_ptr));
|
||||
} else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
|
||||
const c10::BFloat16* factor_data = static_cast<const c10::BFloat16*>(factors_ptr);
|
||||
apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
|
||||
input_tensor.data_ptr<c10::BFloat16>(),
|
||||
output_tensor.data_ptr<c10::BFloat16>(),
|
||||
perm_ptr,
|
||||
m,
|
||||
topk,
|
||||
row_stride,
|
||||
static_cast<const c10::BFloat16*>(factors_ptr));
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output dtype for cast+mul kernel: ", output_tensor.scalar_type());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
|
||||
*
|
||||
* This function performs the equivalent of the following PyTorch expression:
|
||||
*
|
||||
* (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
|
||||
*
|
||||
* Specifically:
|
||||
* - `input` is shuffled using the `permutation` tensor.
|
||||
* - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
|
||||
* - The result is summed along dimension 1 (the top-k dimension), and stored in `output`.
|
||||
*
|
||||
* @param input Input tensor of shape (m * topk, k), representing c2.
|
||||
* @param output Output tensor of shape (m, k), where the final reduced results are stored.
|
||||
* @param permutation Index tensor (e.g., c_map) that maps positions in `input` to shuffled layout.
|
||||
* @param factors Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
|
||||
*/
|
||||
void apply_shuffle_mul_sum(
|
||||
const torch::Tensor& input,
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& permutation,
|
||||
const std::optional<torch::Tensor>& factors) {
|
||||
get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
|
||||
}
|
||||
|
||||
6
sgl-kernel/include/sgl_kernel_ops.h
Normal file → Executable file
6
sgl-kernel/include/sgl_kernel_ops.h
Normal file → Executable file
@@ -276,6 +276,12 @@ void ep_moe_post_reorder(
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
|
||||
|
||||
void apply_shuffle_mul_sum(
|
||||
const torch::Tensor& input,
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& permutation,
|
||||
const std::optional<torch::Tensor>& factors);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& a,
|
||||
|
||||
@@ -48,6 +48,7 @@ from sgl_kernel.gemm import (
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
ep_moe_post_reorder,
|
||||
ep_moe_pre_reorder,
|
||||
|
||||
@@ -178,6 +178,17 @@ def prepare_moe_input(
|
||||
)
|
||||
|
||||
|
||||
def apply_shuffle_mul_sum(
|
||||
input,
|
||||
output,
|
||||
permutation,
|
||||
factors,
|
||||
):
|
||||
torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
|
||||
input, output, permutation, factors
|
||||
)
|
||||
|
||||
|
||||
def cutlass_fp4_group_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
|
||||
Reference in New Issue
Block a user