adapt to sglang v0.5.2rc1 on DCU

maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions


@@ -0,0 +1,91 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
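// Illustrative sketch (not part of the original diff): a host-side guard that
// uses get_sm_version_num() before taking the SM90-only path below.
// `assert_sm90_supported` is a hypothetical helper name.
static inline void assert_sm90_supported() {
  int32_t version_num = get_sm_version_num();
  TORCH_CHECK(version_num >= 90, "cutlass_w4a8_moe_mm requires SM90 (Hopper) or newer, got SM", version_num);
}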
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
cutlass_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
return;
}
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
get_cutlass_w4a8_moe_mm_data_caller(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k);
return;
}


@@ -0,0 +1,92 @@
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
int32_t* expert_offsets,
ElementA** a_offsets,
ElementB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
cutlass::bfloat16_t** b_scales_offsets,
ElementA* a_base_as_int,
ElementB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
cutlass::bfloat16_t* b_scales_base_as_int,
int64_t n,
int64_t k,
bool per_act_token,
bool per_out_ch) {
int expert_id = threadIdx.x;
int32_t expert_offset = expert_offsets[expert_id];
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
}
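// Pointer math above, illustrated for n = 4096, k = 7168 (example values): B
// packs two int4 values per int8 byte, so each expert's slab advances by
// k * n / 2 = 14,680,064 elements; with per-out-channel scaling, B scales
// advance by n * k / 128 bf16 entries (one scale per 128-wide group),
// otherwise by a single per-expert scale.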
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
out_tensors.size(1), \
a_tensors.size(1), \
per_act_token, \
per_out_ch); \
}
namespace {
void run_int4_fp8_get_group_gemm_starts(
torch::Tensor const& expert_offsets,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
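// The empty `if (false)` below only seeds the dispatch chain: each
// __CALL_W4A8_GET_STARTS_KERNEL expansion begins with `else if`, so the macro
// expansions chain into one dtype switch ending in the TORCH_CHECK fallback.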
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace


@@ -0,0 +1,386 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <type_traits>
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
using namespace cute;
namespace {
enum class Sched { PP, CO };
template <int M, int N, int K, int A, int B, int C, Sched S>
struct SM90W4A8Config {
using KernelSchedule = std::conditional_t<
S == Sched::PP,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using EpilogueSchedule = std::conditional_t<
S == Sched::PP,
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <int M, int N, int K, int A, int B, int C>
using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;
template <int M, int N, int K, int A, int B, int C>
using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
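// Reading the aliases: SM90_PP<M, N, K, A, B, C> selects an MxNxK tile with an
// AxBxC cluster and the Ping-Pong kernel/epilogue schedules; SM90_CO selects
// the Cooperative variants with the same shapes.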
template <typename Config>
inline void invoke_gemm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size) {
using GemmT = typename Config::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<GemmT>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
void dispatch_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
uint32_t const m = a_tensors.size(0) / topk;
uint32_t const n = d_tensors.size(1);
uint32_t const k = a_tensors.size(1);
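// m approximates the original token count (rows / topk) and buckets the tile
// choice below; the (n, k) pairs appear to be the MoE projection shapes this
// kernel was tuned for, with Ping-Pong schedules for small m and Cooperative
// ones for larger batches. Shapes outside the tuned set fall through to the
// generic configs at the end.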
if (n == 4096 && k == 7168) {
// group gemm 1
if (m <= 4) {
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
invoke_gemm<SM90_CO<128, 16, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 2048) {
// group gemm 2
if (m <= 8) {
invoke_gemm<SM90_PP<64, 16, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 512 && k == 7168) {
// group gemm 1 for tp
if (m <= 4) {
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 256) {
// group gemm 2 for tp
if (m <= 8) {
invoke_gemm<SM90_PP<64, 16, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
invoke_gemm<SM90_PP<128, 32, 128, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else {
if (k % 512 == 0) {
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
}
}
} // namespace
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
dispatch_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
}


@@ -0,0 +1,277 @@
#pragma once
/**
* @file w4a8_grouped_mm_c3x.cuh
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
* precision
*
* This file implements a grouped GEMM operation that multiplies FP8 matrices
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
* Cores for mixed precision arithmetic.
*
* Key features:
* - Supports grouped GEMM operations with multiple experts
* - Uses FP8 (e4m3) for matrix A
* - Uses INT4 quantization for matrix B with per-block scaling
* - Implements preprocessing for INT4 encoding and scale packing
* - Optimized for Hopper architecture with Tensor Core operations
*/
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "w4a8_get_group_starts.cuh"
using namespace cute;
namespace {
// Type definitions
using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
using QuantType = cutlass::int4b_t; // 4-bit integer type
using ElementAccumulator = float; // Accumulator type
using ElementScale = cutlass::bfloat16_t; // Scale type
using ElementC = cutlass::bfloat16_t; // Output type
using ElementD = ElementC; // Output type
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;
// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Transposed layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
static constexpr int GroupSize = 128;
static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC_Transpose*,
AlignmentC,
ElementD,
LayoutD_Transpose*,
AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
ArchTag,
OperatorClass,
cute::tuple<QuantType, ElementScalePacked>,
LayoutB_Transpose*,
AlignmentB,
MmaType,
LayoutA_Transpose*,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
// Define the final kernel and GEMM operation types
using GemmKernelScaleOnly =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
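// Worked example: with GroupSize = 128 and the 512-deep K tile used by most
// dispatch configs, PackedScalesNum = 512 / 128 = 4, so scales travel as
// cutlass::Array<bfloat16_t, 4> -- matching the [E, K//512, N*4] b_scales
// layout described below. The 128-deep K tiles pack a single scale.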
/**
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
*
* This function performs multiple GEMM operations in parallel where each
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
* applying per-channel scaling factors. It's designed for efficient execution
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
* mixed precision arithmetic.
*
* The function includes preprocessing steps for both INT4 tensors and scale
* factors to ensure optimal performance and correct operation.
*
* @param d_tensors Output tensor D with shape [total_m, total_n]
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
* [total_m, K]
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
* shape [E, N, K/2]
* @param a_scales Tensor containing A matrix scale factors
* @param b_scales Tensor containing B matrix scale factors with shape [E,
* K//512, N*4]
* @param expert_offsets Tensor containing expert offsets for determining group
* boundaries (int32)
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
* 3] (N, M, K for each group) (int32)
* @param a_strides Stride information for A tensors
* @param b_strides Stride information for B tensors
* @param d_strides Stride information for D tensors
* @param s_strides Stride information for scale tensors
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
*/
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size) {
// using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
using Args = typename Gemm::GemmScaleOnly::Arguments;
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
// Check inputs
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
// Check tensor shapes
TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_tensors.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
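// Epilogue fusion setup: alpha_ptr carries the single per-tensor A scale and
// beta stays 0, so the epilogue computes D = a_scale * acc; the zero strides
// in dAlpha/dBeta broadcast one coefficient across all groups.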
fusion_args.alpha = 1.0f;
fusion_args.beta = 0;
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());
run_int4_fp8_get_group_gemm_starts(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_tensors,
b_tensors,
d_tensors,
a_scales,
b_scales);
arguments = Args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
{static_cast<const QuantType**>(b_ptrs.data_ptr()),
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
static_cast<const MmaType**>(a_ptrs.data_ptr()),
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
static_cast<int>(chunk_size)},
{fusion_args,
nullptr,
nullptr,
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
hw_info};
// Instantiate and run GEMM
typename Gemm::GemmScaleOnly gemm;
size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM implementation not supported");
}
status = gemm.initialize(arguments, workspace.data_ptr(), stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM initialization failed");
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM execution failed");
}
}
} // namespace


@@ -0,0 +1,79 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes_w4a8(
const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int topk_length,
const int n,
const int k) {
int expert_id = blockIdx.x;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
occurrences += (topk_ids[i] == expert_id);
}
atomicAdd(&atomic_buffer[expert_id], occurrences);
__syncthreads();
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
}
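// Note: the kernel above stores each group's sizes as (N, M, K). GEMM1 runs
// with N = 2 * n (presumably the fused gate/up projection) and GEMM2 with
// N = k, placing the per-expert M count in the middle slot.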
__global__ void compute_expert_offsets_w4a8(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
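// Serial exclusive prefix sum over the per-expert M counts written by
// compute_problem_sizes_w4a8; e.g. counts {3, 0, 5} yield expert_offsets
// {0, 3, 3, 8}, and atomic_buffer[i] is reset to expert i's start offset.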
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset;
}
}
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
n,
k);
compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}


@@ -0,0 +1,142 @@
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template <
typename ElementAB,
typename ElementC,
typename ElementAccumulator,
typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void get_group_gemm_starts(
int32_t* expert_offsets,
ElementAB** a_offsets,
ElementAB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
ElementAccumulator** b_scales_offsets,
ElementAB* a_base_as_int,
ElementAB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
ElementAccumulator* b_scales_base_as_int,
LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int,
int* problem_sizes,
int* problem_sizes_transpose,
bool transpose = false) {
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
int m = problem_sizes[expert_id * 3];
int n = problem_sizes[expert_id * 3 + 1];
int k = problem_sizes[expert_id * 3 + 2];
if (transpose) {
problem_sizes_transpose[expert_id * 3] = n;
problem_sizes_transpose[expert_id * 3 + 1] = m;
problem_sizes_transpose[expert_id * 3 + 2] = k;
}
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t a_stride = 0;
int64_t b_stride = 0;
int64_t a_scale_stride = 0;
int64_t b_scale_stride = 0;
if (!transpose) {
a_stride = expert_offset * k;
b_stride = expert_id * k * n;
a_scale_stride = expert_offset * k / 128;
b_scale_stride = expert_id * k * n / 128 / 128;
} else {
a_stride = expert_id * k * n;
b_stride = expert_offset * k;
a_scale_stride = expert_id * k * n / 128 / 128;
b_scale_stride = expert_offset * k / 128;
}
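// transpose == true is the swapped-operand path used for column-major output:
// A takes the weight strides (expert_id * k * n, scales per 128x128 block)
// and B takes the activation strides (expert_offset * k, scales per
// 128-element group), mirroring the non-transposed case above.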
a_offsets[expert_id] = a_base_as_int + a_stride;
b_offsets[expert_id] = b_base_as_int + b_stride;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + a_scale_stride;
b_scales_offsets[expert_id] = b_scales_base_as_int + b_scale_stride;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
if (!transpose) {
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
} else {
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(n, m, k, 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(n, m, k, 1));
}
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<int*>(problem_sizes.data_ptr()), \
static_cast<int*>(problem_sizes_transpose.data_ptr()), \
transpose); \
}
namespace {
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& layout_sfa,
torch::Tensor const& layout_sfb,
torch::Tensor const& problem_sizes,
torch::Tensor& problem_sizes_transpose,
bool transpose = false) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace


@@ -0,0 +1,181 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
template <typename scalar_t>
__global__ void ep_pre_reorder_cuda_kernel(
const scalar_t* __restrict__ input_ptr,
scalar_t* __restrict__ gateup_input_ptr,
const int* __restrict__ src2dst_ptr,
const int* __restrict__ topk_ids_ptr,
const float* __restrict__ a1_scales_ptr,
int start_expert_id,
int end_expert_id,
int topk,
int hidden_size,
bool use_per_token_if_dynamic) {
int token_idx = blockIdx.x;
int tid = threadIdx.x;
const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
const int* token_src2dst = src2dst_ptr + token_idx * topk;
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
float scale = 1.0f;
if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) {
scale = 1.0f / a1_scales_ptr[token_idx];
}
for (int k = 0; k < topk; ++k) {
int expert_id = token_topk_ids[k];
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
if (a1_scales_ptr != nullptr) {
if (!use_per_token_if_dynamic) {
scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
}
}
int dst_idx = token_src2dst[k];
scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
int vec_elements = (hidden_size / vec_size) * vec_size;
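// 16-byte vectorized main loop: vec_size is 8 for fp16/bf16 and 4 for fp32;
// the scalar loop below covers the tail when hidden_size is not a multiple
// of vec_size.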
for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
vec_t input_vec, output_vec;
input_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
float val = static_cast<float>(input_vec[i]);
output_vec[i] = static_cast<scalar_t>(val * scale);
}
output_vec.cast_store(dst_ptr + idx * vec_size);
}
for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) {
float val = static_cast<float>(src_ptr[idx]);
dst_ptr[idx] = static_cast<scalar_t>(val * scale);
}
}
}
template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
const scalar_t* __restrict__ down_output_ptr,
scalar_t* __restrict__ output_ptr,
const int* __restrict__ src2dst_ptr,
const int* __restrict__ topk_ids_ptr,
const scalar_t* __restrict__ topk_weights_ptr,
int start_expert_id,
int end_expert_id,
int topk,
int hidden_size) {
const int token_idx = blockIdx.x;
const int tid = threadIdx.x;
const int* token_src2dst = src2dst_ptr + token_idx * topk;
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;
scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
const int vec_iters = hidden_size / vec_size;
for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
float acc[vec_size] = {0};
for (int k = 0; k < topk; ++k) {
const int expert_id = token_topk_ids[k];
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
const int src_row = token_src2dst[k];
const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
const float weight = static_cast<float>(token_topk_weights[k]);
vec_t src_vec;
src_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
acc[i] += static_cast<float>(src_vec[i]) * weight;
}
}
vec_t out_vec;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i)
out_vec[i] = static_cast<scalar_t>(acc[i]);
out_vec.cast_store(dst_ptr + idx * vec_size);
}
}
void ep_moe_pre_reorder(
torch::Tensor input,
torch::Tensor gateup_input,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor a1_scales,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk,
bool use_per_token_if_dynamic) {
const int total_blocks = input.size(0);
const int block_size = 512;
dim3 grid(total_blocks);
dim3 block(block_size);
int hidden_size = input.size(1);
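// One thread block per source token: each block reads the token's row once
// per selected expert and scatters the (optionally scaled) copy to that
// expert's destination row.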
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
ep_pre_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
static_cast<scalar_t*>(input.data_ptr()),
static_cast<scalar_t*>(gateup_input.data_ptr()),
src2dst.data_ptr<int>(),
topk_ids.data_ptr<int>(),
a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
use_per_token_if_dynamic);
return true;
});
}
void ep_moe_post_reorder(
torch::Tensor down_output,
torch::Tensor output,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor topk_weights,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk) {
const int total_tokens = output.size(0);
const int block_size = 512;
dim3 grid(total_tokens);
dim3 block(block_size);
const int hidden_size = output.size(1);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
ep_post_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
static_cast<scalar_t*>(down_output.data_ptr()),
static_cast<scalar_t*>(output.data_ptr()),
src2dst.data_ptr<int>(),
topk_ids.data_ptr<int>(),
static_cast<scalar_t*>(topk_weights.data_ptr()),
static_cast<int>(start_expert_id),
static_cast<int>(end_expert_id),
static_cast<int>(topk),
hidden_size);
return true;
});
}


@@ -0,0 +1,115 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <THC/THCAtomics.cuh>
#include <algorithm>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
using namespace flashinfer;
template <typename scalar_t>
__device__ inline scalar_t silu_quantize(float x);
template <>
__device__ inline float silu_quantize<float>(float x) {
float y = x / (1.f + __expf(-x));
return y;
}
template <>
__device__ inline __half silu_quantize<__half>(float x) {
float y = x / (1.f + __expf(-x));
return __float2half_rn(y);
}
template <>
__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) {
float y = x / (1.f + __expf(-x));
return __float2bfloat16_rn(y);
}
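// Each specialization evaluates SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
// in fp32, then rounds once to the kernel's scalar type.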
template <typename scalar_t>
__global__ void ep_moe_act_and_mul_cuda_kernel(
const scalar_t* __restrict__ gateup_output,
scalar_t* __restrict__ down_input,
const int* __restrict__ reorder_topk_ids,
const float* __restrict__ scales,
int start_expert_id,
int end_expert_id,
int hidden_size) {
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int half_hidden_size = hidden_size >> 1;
const int expert_id = reorder_topk_ids[token_idx];
if (expert_id < start_expert_id || expert_id > end_expert_id) return;
const scalar_t* gate_output_ptr = gateup_output + static_cast<int64_t>(token_idx) * hidden_size;
const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size;
scalar_t* dst_ptr = down_input + static_cast<int64_t>(token_idx) * half_hidden_size;
scalar_t scale_q = static_cast<scalar_t>(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f);
const uint32_t vec_elements = half_hidden_size / vec_size;
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) {
vec_t gate_vec, up_vec, out_vec;
gate_vec.load(gate_output_ptr + idx * vec_size);
up_vec.load(up_output_ptr + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
float gate_f = static_cast<float>(gate_vec[i]);
scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
scalar_t prod = gate_q * up_vec[i] * scale_q;
out_vec[i] = prod;
}
out_vec.store(dst_ptr + idx * vec_size);
}
const int64_t scalar_start = static_cast<int64_t>(vec_elements) * vec_size + thread_idx;
#pragma unroll 1
for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) {
float gate_f = static_cast<float>(gate_output_ptr[idx]);
scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q;
}
}
void ep_moe_silu_and_mul(
torch::Tensor gateup_output,
torch::Tensor down_input,
torch::Tensor reorder_topk_ids,
torch::Tensor scales,
int64_t start_expert_id,
int64_t end_expert_id) {
const int total_tokens = gateup_output.size(0);
const int hidden_size = gateup_output.size(1);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] {
dim3 grid(total_tokens);
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
const int half_hidden_size = hidden_size >> 1;
uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size;
threads = std::max<uint32_t>(threads, 256);
threads = ((threads + 31) & ~31U);
dim3 block(std::min(threads, 1024U));
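// Example sizing: fp16 with hidden_size = 4096 gives half_hidden_size = 2048
// and vec_size = 8, so 256 threads -- exactly the floor; counts are rounded
// up to whole warps and capped at 1024.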
ep_moe_act_and_mul_cuda_kernel<scalar_t><<<grid, block>>>(
static_cast<scalar_t*>(gateup_output.data_ptr()),
static_cast<scalar_t*>(down_input.data_ptr()),
reorder_topk_ids.data_ptr<int>(),
scales.defined() ? scales.data_ptr<float>() : nullptr,
static_cast<int>(start_expert_id),
static_cast<int>(end_expert_id),
hidden_size);
return true;
});
}


@@ -0,0 +1,810 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/arch/arch.h>
#include <torch/all.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass_moe_helper.cu"
#include "utils.h"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm90_fp8_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs,
const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs,
const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = OutType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using CustomEVTIdentity = // acc
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch>;
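// A single-node EVT: the epilogue only converts the fp32 accumulator to
// ElementD. Dequantization happens in the mainloop via the blockwise scale
// pointers, so no alpha/beta work is needed here.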
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC, // Use void to avoid load Matrix C
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentC,
typename ScheduleConfig::EpilogueSchedule,
CustomEVTIdentity>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
AlignmentA,
ElementB,
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
AlignmentB,
ElementAccumulator,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
int num_experts = (int)expert_offsets.size(0);
Gemm gemm_op;
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
typename GemmKernel::EpilogueArguments epilogue_args{
{},
nullptr,
static_cast<StrideC*>(stride_c.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm100_fp8_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs,
const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs,
const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
void,
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentC,
typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
AlignmentA,
ElementB,
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
AlignmentB,
ElementAccumulator,
typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
int num_experts = (int)expert_offsets.size(0);
// Create an instance of the GEMM
Gemm gemm_op;
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
// Hard-code sm_count to 148: this launcher only targets SM100 (Blackwell) parts, so the per-device SM count is not queried.
hw_info.sm_count = 148;
typename GemmKernel::EpilogueArguments epilogue_args{
{},
nullptr,
static_cast<StrideC*>(stride_c.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void sm100_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
// Check the first matrix size to decide on the configuration
// Assuming all matrices in the group have similar size characteristics
// bool use_small_config = a[0].size(0) <= 128;
struct MmaConfig1 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_256, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfig2 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfig3 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
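// Dispatch below: MmaConfig1 (2-SM tile, swapped A/B operands) covers
// M <= 2048 with K >= 2048; MmaConfig2 (1-SM 128x128 tile) covers M > 2048
// with K >= 2048; MmaConfig3 (64x128 tile) handles the remaining thin-K cases.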
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
torch::Tensor output_t = output.t();
torch::Tensor a_t = a.t();
torch::Tensor b_t = b.transpose(1, 2);
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
if (a.size(0) <= 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
b_t,
a_t,
output_t,
scales_b_t,
scales_a_t,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose,
true);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes_transpose,
expert_offsets,
workspace);
output = output_t.t();
} else if (a.size(0) > 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
}
template <typename OutType>
void sm90_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
struct MmaConfigSmallM {
// Swap A/B
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
// TODO: Check Pingpong or Cooperative
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfigH20LargeK {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct MmaConfigHx00AndH20SmallK {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
int num_experts = (int)expert_offsets.size(0);
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
torch::Tensor output_t = output.t();
torch::Tensor a_t = a.t();
torch::Tensor b_t = b.transpose(1, 2);
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
if (a.size(0) <= 2048) {
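    // Small-M path (see MmaConfigSmallM): run the transposed problem with A/B
    // swapped, then transpose the output back after the launch.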
run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
b_t,
a_t,
output_t,
scales_b_t,
scales_a_t,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose,
true);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes_transpose,
expert_offsets,
workspace);
output = output_t.t();
} else {
if (is_h20_device && a.size(1) > 128) {
// For H20 with K > 128, use Pingpong Schedule
run_get_group_gemm_starts<
MmaConfigH20LargeK::LayoutSFA,
MmaConfigH20LargeK::LayoutSFB,
MmaConfigH20LargeK::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
run_get_group_gemm_starts<
MmaConfigHx00AndH20SmallK::LayoutSFA,
MmaConfigHx00AndH20SmallK::LayoutSFB,
MmaConfigHx00AndH20SmallK::ScaleConfig>(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
output,
scales_a,
scales_b,
layout_sfa,
layout_sfb,
problem_sizes,
problem_sizes_transpose);
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
}
}
/**
* @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs,
* with per-block scaling.
*
 * This function dispatches to hardware-specific implementations (SM90 or SM100 FP8)
* to compute:
 *     C_i = (A_i * scale_a[i]) @ (B_i * scale_b[i])
* for each expert group `i`, using input `problem_sizes` and `expert_offsets`
* to describe the individual matrix dimensions and their offsets.
*
 * Input tensors A and B must be FP8 (e4m3) quantized; they are dequantized
 * blockwise inside the kernel using the supplied scales.
 * The output tensor is written in bfloat16 or half precision.
*
* @param output Output tensor (must be of type bfloat16 or half).
* @param a Input tensor A (must be kFloat8_e4m3fn).
* @param b Input tensor B (must be kFloat8_e4m3fn).
* @param scales_a Scaling factors for tensor A, float32 per expert group.
* @param scales_b Scaling factors for tensor B, float32 per expert group.
* @param stride_a Stride information for tensor A (int32).
* @param stride_b Stride information for tensor B (int32).
* @param stride_c Stride information for output tensor C (int32).
* @param layout_sfa Layout descriptor for A (int32), e.g., row-major/column-major.
* @param layout_sfb Layout descriptor for B (int32).
* @param problem_sizes 2D int32 tensor of shape (num_experts, 3), specifying (M, N, K)
* for each grouped matrix multiplication problem.
* @param expert_offsets 1D int32 tensor of size (num_experts), used to index into
* the grouped input tensors for dispatch.
* @note Performance Optimization:
 *       If the batch size (a.size(0)) is at most 2048, the implementation
 *       internally transposes the input matrices to match the optimal memory
 *       access pattern for better GPU efficiency. This transformation is done
 *       within the kernel dispatch.
*/
void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn");
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn");
TORCH_CHECK(
output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf,
"output must be bfloat16 or half");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be float32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be float32");
TORCH_CHECK(stride_a.scalar_type() == torch::kInt64, "stride_a must be int64");
TORCH_CHECK(stride_b.scalar_type() == torch::kInt64, "stride_b must be int64");
TORCH_CHECK(stride_c.scalar_type() == torch::kInt64, "stride_c must be int64");
TORCH_CHECK(layout_sfa.scalar_type() == torch::kInt32, "layout_sfa must be int32");
TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
  TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 2D tensor");
  TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 2D tensor");
TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
bool can_implement = false;
auto sm_version = getSMVersion();
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
if (sm_version == 100
#if CUDA_VERSION >= 12090
|| sm_version == 103
#endif
) {
if (output.scalar_type() == torch::kBFloat16) {
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
can_implement = true;
}
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
if (sm_version == 90) {
if (output.scalar_type() == torch::kBFloat16) {
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
} else {
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace);
}
can_implement = true;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
}
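// ---------------------------------------------------------------------------
// Hedged usage sketch (illustrative only, kept out of compilation): one
// plausible way a caller could size the tensors that the TORCH_CHECKs above
// expect. The expert count, M/N/K values, the 128-wide scale granularity,
// the 5-int layout descriptors, and the workspace size are assumptions for
// this example, not definitions taken from this file.
// ---------------------------------------------------------------------------
#if 0
static void example_fp8_blockwise_scaled_grouped_mm_call() {
  const int64_t E = 8, M = 4096, N = 4096, K = 2048;
  auto dev = torch::Device(torch::kCUDA, 0);
  auto f32 = torch::dtype(torch::kFloat32).device(dev);
  auto i32 = torch::dtype(torch::kInt32).device(dev);
  auto i64 = torch::dtype(torch::kInt64).device(dev);

  // FP8 operands: a is (M, K) row-major, b holds one (N, K) matrix per expert.
  torch::Tensor a = torch::randn({M, K}, torch::dtype(torch::kHalf).device(dev))
                        .to(torch::kFloat8_e4m3fn);
  torch::Tensor b = torch::randn({E, N, K}, torch::dtype(torch::kHalf).device(dev))
                        .to(torch::kFloat8_e4m3fn);
  torch::Tensor output = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(dev));

  // Blockwise scales: per-token x per-128-wide K block for a, 128x128 blocks for b.
  torch::Tensor scales_a = torch::ones({M, K / 128}, f32);
  torch::Tensor scales_b = torch::ones({E, K / 128, N / 128}, f32);

  // One slot per expert; the device-side run_get_group_gemm_starts fills the
  // pointer arrays and layout descriptors before the grouped GEMM launches.
  torch::Tensor a_ptrs = torch::empty({E}, i64);
  torch::Tensor b_ptrs = torch::empty({E}, i64);
  torch::Tensor out_ptrs = torch::empty({E}, i64);
  torch::Tensor a_scales_ptrs = torch::empty({E}, i64);
  torch::Tensor b_scales_ptrs = torch::empty({E}, i64);
  torch::Tensor stride_a = torch::full({E}, K, i64);
  torch::Tensor stride_b = torch::full({E}, K, i64);
  torch::Tensor stride_c = torch::full({E}, N, i64);
  torch::Tensor layout_sfa = torch::empty({E, 5}, i32);    // 5 ints per layout: assumed
  torch::Tensor layout_sfb = torch::empty({E, 5}, i32);    // 5 ints per layout: assumed
  torch::Tensor problem_sizes = torch::empty({E, 3}, i32); // (M_i, N, K) per expert
  torch::Tensor expert_offsets = torch::empty({E}, i32);   // row starts into a
  torch::Tensor workspace =
      torch::empty({1 << 24}, torch::dtype(torch::kUInt8).device(dev));  // size assumed

  fp8_blockwise_scaled_grouped_mm(
      output, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs,
      a, b, scales_a, scales_b, stride_a, stride_b, stride_c,
      layout_sfa, layout_sfb, problem_sizes, expert_offsets, workspace);
}
#endif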

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
KERNEL_FILE_TEMPLATE = (
"// auto generated by generate.py\n"
"// clang-format off\n"
"#pragma once\n\n"
"{% for kernel_file in kernel_files %}"
'#include "{{ kernel_file }}"\n'
"{% endfor %}"
)
KERNEL_FILE_NAME = "kernel_marlin.cuh"
# The int8-with-zero-point case (sglang::kU8) is also supported,
# but we do not generate it here, to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
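# e.g. group_blocks=2 -> group_size=32, group_blocks=8 -> group_size=128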
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
kernel_files = set()
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
has_act_order = group_blocks == 0
if has_zp and has_act_order:
continue
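            # For 256-thread configs, pair the tile with the m-block count:
            # keep (128, 128) when m_blocks <= 1 and (64, 256) for larger m.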
if thread_configs[2] == 256:
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
kernel_files.add(filename)
kernel_files = list(kernel_files)
kernel_files.sort()
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
kernel_files=kernel_files
)
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
f.write(file_content)
if __name__ == "__main__":
remove_old_kernels()
generate_new_kernels()
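# Invoking this script directly deletes any existing kernel_*.cuh files next
# to it, regenerates one file per (dtype, scalar type) pair, and rewrites the
# aggregating include kernel_marlin.cuh.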

View File

@@ -0,0 +1,41 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \
bool use_atomic_add, bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <
    typename scalar_t,  // compute dtype, half or nv_bfloat16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}

View File

@@ -0,0 +1,90 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,110 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,110 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,90 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,110 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,110 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}

View File

@@ -0,0 +1,10 @@
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel_bf16_ku4.cuh"
#include "kernel_bf16_ku4b8.cuh"
#include "kernel_bf16_ku8b128.cuh"
#include "kernel_fp16_ku4.cuh"
#include "kernel_fp16_ku4b8.cuh"
#include "kernel_fp16_ku8b128.cuh"

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,366 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include "utils.h"
#define WARP_SIZE 64
#define VEC_SIZE 4
using Vec = int4;
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
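// NOTE: expert ids are shifted by one here (and in the kernels below),
// presumably so that topk_ids == -1 (masked tokens) lands in bucket 0.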
int32_t expert_id = topk_ids[i] + 1;
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
}
#ifdef __CUDA_ARCH__
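// Exclusive prefix sum across one warp built from __shfl_up.
// Worked example (WARP_SIZE = 4, lane values {3, 1, 4, 1}): the shuffle
// rounds produce the inclusive scan {3, 4, 8, 9}; subtracting each lane's
// original value gives the exclusive result {0, 3, 4, 8}.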
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
int original = v;
#pragma unroll
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int n = __shfl_up(v, offset);
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
}
return v - original;
}
#endif
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel,
int32_t* __restrict__ cumsum,
bool pad_sorted_token_ids,
const int32_t scan_size) {
extern __shared__ int32_t smem[];
int32_t* shared_counts = smem; // [num_experts]
int32_t* prefix = shared_counts + num_experts; // [num_experts + 1]
int32_t* scan_buf = prefix + num_experts + 1; // [scan_size]
__shared__ int32_t s_total_tokens_post_pad;
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
if (tid < num_experts) {
shared_counts[tid] = 0;
}
__syncthreads();
for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i] + 1;
atomicAdd(&shared_counts[expert_id], 1);
}
__syncthreads();
int32_t padded_count = 0;
if (tid < num_experts) {
int32_t count = shared_counts[tid];
padded_count = (count + block_size - 1) / block_size * block_size;
scan_buf[tid] = padded_count;
}
#ifndef __CUDA_ARCH__ // HIP
if (tid >= num_experts && tid < scan_size) {
scan_buf[tid] = 0;
}
__syncthreads();
// Blelloch scan
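// Worked example on scan_buf = {2, 3, 1, 4}: the up-sweep leaves
// {2, 5, 1, 10}; zeroing the root and running the down-sweep yields the
// exclusive prefix sum {0, 2, 5, 6}.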
int offset = 1;
#pragma unroll
for (int d = scan_size >> 1; d > 0; d >>= 1) {
if (tid < d) {
int ai = offset * (2 * tid + 1) - 1;
int bi = offset * (2 * tid + 2) - 1;
scan_buf[bi] += scan_buf[ai];
}
offset <<= 1;
__syncthreads();
}
// down-sweep
if (tid == 0) {
prefix[num_experts] = scan_buf[scan_size - 1];
scan_buf[scan_size - 1] = 0;
}
__syncthreads();
#pragma unroll
for (int d = 1; d < scan_size; d <<= 1) {
offset >>= 1;
if (tid < d) {
int ai = offset * (2 * tid + 1) - 1;
int bi = offset * (2 * tid + 2) - 1;
if (bi < scan_size) {
int temp = scan_buf[ai];
scan_buf[ai] = scan_buf[bi];
scan_buf[bi] += temp;
}
}
__syncthreads();
}
if (tid < num_experts) {
prefix[tid] = scan_buf[tid];
}
if (tid == 0) {
s_total_tokens_post_pad = prefix[num_experts];
*total_tokens_post_pad = s_total_tokens_post_pad;
}
__syncthreads();
#else // CUDA
// Intra warp prefix sum
int32_t* warp_sums = scan_buf + scan_size;  // capacity: WARP_SIZE entries
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid & (WARP_SIZE - 1);
const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
__syncthreads();
// warp 0 accumulates the per-warp sums into a block-level prefix sum
if (tid < WARP_SIZE) {
int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
int incl = warp_exclusive_scan(val) + val;
warp_sums[tid] = incl;
}
__syncthreads();
// Thread 0 publishes the block-wide total
if (tid == 0) {
prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
s_total_tokens_post_pad = prefix[num_experts];
*total_tokens_post_pad = s_total_tokens_post_pad;
}
__syncthreads();
// Zero-fill the padded tail of scan_buf (tid >= num_experts)
if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
__syncthreads();
// Two-level exclusive prefix sum over scan_buf (intra-warp, then across warps)
int v = (tid < scan_size) ? scan_buf[tid] : 0;
int pre = warp_exclusive_scan(v);
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
__syncthreads();
if (warp_id == 0) {
int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
warp_sums[lane_id] = warp_exclusive_scan(val);
}
__syncthreads();
int offset = warp_sums[warp_id];
if (tid < scan_size) scan_buf[tid] = pre + offset;
__syncthreads();
// Write prefix[0..num_experts - 1] and cumsum
if (tid < num_experts) prefix[tid] = scan_buf[tid];
#endif
if (tid <= num_experts) {
cumsum[tid] = prefix[tid];
}
// fill expert_ids
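// Binary search (upper_bound) over prefix: left ends at the first bucket whose
// cumulative padded count exceeds block_start, so left - 1 is the owning
// (shifted) bucket and the extra -1 undoes the +1 id shift applied above.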
const int32_t num_blocks = s_total_tokens_post_pad / block_size;
for (int32_t i = tid; i < num_blocks; i += stride) {
int32_t block_start = i * block_size;
int left = 0, right = num_experts;
while (left < right) {
int mid = (left + right) >> 1;
if (prefix[mid] <= block_start) {
left = mid + 1;
} else {
right = mid;
}
}
expert_ids[i] = left - 2;
}
if (pad_sorted_token_ids) {
Vec fill_vec;
fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
for (int32_t i = tid; i < total_vecs; i += stride) {
out_ptr[i] = fill_vec;
}
}
}
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel,
bool pad_sorted_token_ids) {
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
}
for (size_t i = tid; i < numel; i += stride) {
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
}
__syncthreads();
if (threadIdx.x < num_experts) {
tokens_cnts[threadIdx.x] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
}
}
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
}
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
expert_ids[i / block_size] = threadIdx.x - 1;
}
}
if (pad_sorted_token_ids) {
Vec fill_vec;
fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
for (int32_t i = tid; i < total_vecs; i += stride) {
out_ptr[i] = fill_vec;
}
}
__syncthreads();
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i] + 1;
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x * num_experts + expert_id];
}
}
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor cumsum_buffer,
bool pad_sorted_token_ids) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
if (small_batch_expert_mode) {
const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel(),
pad_sorted_token_ids);
} else {
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
const size_t scan_size = next_pow2(num_experts);
const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
align_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel(),
cumsum_buffer.data_ptr<int32_t>(),
pad_sorted_token_ids,
scan_size);
const int block_threads = std::min(256, (int)threads);
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(),
topk_ids.numel());
}
});
}
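A minimal host-side sketch of how this launcher can be driven. The buffer sizes below follow the usual sgl-kernel convention (worst-case padding of one partial block per expert; cumsum_buffer with num_experts + 1 entries, matching the kernel's write range) and are assumptions, not taken from this commit:

#include <torch/torch.h>

void moe_align_block_size_example() {
  const int64_t num_tokens = 8, topk = 2, num_experts = 16, block_size = 64;
  auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  auto topk_ids = torch::randint(0, num_experts, {num_tokens, topk}, i32);
  // Worst case: every expert's token count rounds up by (block_size - 1).
  const int64_t max_padded = num_tokens * topk + num_experts * (block_size - 1);
  auto sorted_token_ids = torch::empty({max_padded}, i32);
  auto experts_ids = torch::empty({max_padded / block_size}, i32);
  auto num_tokens_post_pad = torch::empty({1}, i32);
  auto cumsum_buffer = torch::zeros({num_experts + 1}, i32);
  moe_align_block_size(
      topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
      num_tokens_post_pad, cumsum_buffer, /*pad_sorted_token_ids=*/true);
}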

View File

@@ -0,0 +1,521 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
template <typename T, int N>
using AlignedArray = cutlass::AlignedArray<T, N>;
using bfloat16_t = cutlass::bfloat16_t;
using float16_t = cutlass::half_t;
using float32_t = float;
// QQ NOTE: comparing at::Half directly is ambiguous (more than one operator ">" matches: the built-in
// "arithmetic > arithmetic" and "operator>(const __half&, const __half&)"), so we compare through float.
template <typename T>
__device__ inline bool cmp_gt(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
return static_cast<float>(a) > static_cast<float>(b);
} else {
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
return a > b;
}
}
template <typename T>
__device__ inline bool cmp_eq(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
return static_cast<float>(a) == static_cast<float>(b);
} else {
return a == b;
}
}
// Fixed constants common to both dynamic and static template versions:
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_CTA = 6;
static constexpr int MAX_VPT = 32;  // maximum VPT we support; must be >= params.VPT = num_experts / num_expert_group
// Create an alias for Array using AlignedArray
template <typename T, int N>
using Array = AlignedArray<T, N>;
// QQ NOTE: the array extent must be a compile-time constant, and it has to be >= params.VPT
template <typename T>
using AccessType = AlignedArray<T, MAX_VPT>;
template <typename T, typename Params>
__device__ void moe_fused_gate_impl(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
if (thread_row >= num_rows) {
return;
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
auto* bias_ptr = reinterpret_cast<T*>(bias);
auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;
int thread_group_idx = tidx % params.THREADS_PER_ROW;
int first_elt_read_by_thread = thread_group_idx * params.VPT;
// Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
// AccessType.
T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> row_chunk;
AccessType<T> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(thread_read_ptr);
T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> bias_chunk;
AccessType<T> const* vec_bias_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(bias_thread_read_ptr);
// QQ NOTE: doing the following would be slower than the per-element loop assign and, more importantly,
// hits misaligned-address issues when params.VPT < 8 and mismatches MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = vec_thread_read_ptr[0][ii];
bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
}
__syncthreads();
////////////////////// Sigmoid //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
}
__syncthreads();
////////////////////// Add Bias //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
}
////////////////////// Exclude Groups //////////////////////
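// Each of the THREADS_PER_ROW (= num_expert_group) threads in a row owns one expert group.
// We run (num_expert_group - topk_group) rounds; each round arg-mins the per-group score
// (sum of its top-2 biased values) and masks that group with FLT_MAX. E.g. with 8 groups
// and topk_group = 4, four rounds eliminate the 4 weakest groups.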
#pragma unroll
for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
++k_idx) { // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
int expert = first_elt_read_by_thread;
// local argmax
T max_val = static_cast<T>(-FLT_MAX);
T max_val_second = static_cast<T>(-FLT_MAX);
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val_second = max_val;
max_val = val;
} else if (cmp_gt(val, max_val_second)) {
max_val_second = val;
}
}
// QQ NOTE: currently fixed to picking the top-2 sigmoid weight values in each expert group and summing them
// as the group weight used to select expert groups
T max_sum = max_val + max_val_second;
// argmin reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max_sum =
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
// ties broken toward the higher index
if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
max_sum = other_max_sum;
expert = other_expert;
}
}
// clear the max value in the thread
if (k_idx < params.THREADS_PER_ROW - topk_group) {
int const thread_to_clear_in_group = expert / params.VPT;
if (thread_group_idx == thread_to_clear_in_group) {
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = static_cast<T>(FLT_MAX);
}
}
}
}
__syncthreads();
////////////////////// Topk //////////////////////
float output_sum = 0.0f;
for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
// local argmax
T max_val = bias_chunk[0];
int expert = first_elt_read_by_thread;
if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
#pragma unroll
for (int ii = 1; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val = val;
expert = first_elt_read_by_thread + ii;
}
}
} else {
max_val = static_cast<T>(-FLT_MAX);
}
// argmax reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max =
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
// ties broken toward the lower index
if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
int thread_to_clear_in_group = expert / params.VPT;
int64_t idx = topk * thread_row + k_idx;
if (thread_group_idx == thread_to_clear_in_group) {
int expert_to_clear_in_thread = expert % params.VPT;
// clear the max value in the thread
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
// store output
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
indices_ptr[idx] = static_cast<int32_t>(expert);
}
// accumulate sum for all elements
if (thread_group_idx == 0) {
output_sum += output_ptr[idx];
}
__syncthreads();
}
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
int64_t expert_offset = 0;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
if (num_fused_shared_experts > 1) {
for (int i = 1; i < num_fused_shared_experts; ++i) {
++last_idx;
++expert_offset;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
}
}
}
__syncthreads();
////////////////////// Rescale Output //////////////////////
if (thread_group_idx == 0) {
#pragma unroll
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = output_ptr[idx] / output_sum;
if (apply_routed_scaling_factor_on_output) {
output_ptr[idx] *= routed_scaling_factor;
}
}
}
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
struct KernelParams {
static constexpr int VPT = VPT_;
static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
};
template <
typename T,
int VPT,
int NUM_EXPERTS,
int THREADS_PER_ROW,
int ROWS_PER_WARP,
int ROWS_PER_CTA,
int WARPS_PER_CTA>
__global__ void moe_fused_gate_kernel(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk, \
num_fused_shared_experts, \
routed_scaling_factor, \
apply_routed_scaling_factor_on_output); \
dispatched = true; \
} while (0)
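// Example expansion for the DeepSeek V3 shape (EXPERTS = 256, EXPERT_GROUP = 8):
// VPT = 32, ROWS_PER_WARP = 32 / 8 = 4, ROWS_PER_CTA = 6 * 4 = 24, so each
// 32x6 thread block routes 24 rows.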
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
struct KernelParamsDynamic {
int VPT;
int NUM_EXPERTS;
int THREADS_PER_ROW;
int ROWS_PER_WARP;
int ROWS_PER_CTA;
int WARPS_PER_CTA;
};
template <typename T>
__global__ void moe_fused_gate_kernel_dynamic(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t num_experts,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts;  // e.g., for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
params.THREADS_PER_ROW = num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, this is 8
params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto output = torch::empty({num_rows, topk}, options);
auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
TORCH_CHECK(
num_experts % num_expert_group == 0,
"num_experts must be divisible by num_expert_group, but got ",
num_experts,
" / ",
num_expert_group);
int computed_vpt = num_experts / num_expert_group;
// Check 3: Ensure that num_experts / num_expert_group does not exceed MAX_VPT = 32, the maximum number of
// values per thread we can process.
TORCH_CHECK(
computed_vpt <= MAX_VPT,
"Per group experts: num_experts / num_expert_group = (",
computed_vpt,
") exceeds the maximum supported (",
MAX_VPT,
")");
// Dispatch to templated kernel for known compile-time configurations.
// We currently provide specialized kernels for:
// Case 1: 256 experts, with 8 or 16 groups.
// Case 2: 128 experts, with 4 or 8 groups.
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
bool dispatched = false;
switch (num_experts) {
case 256:
if (num_expert_group == 8) {
// This is the deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
}
} else if (num_expert_group == 16) {
// Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
}
}
break;
case 128:
if (num_expert_group == 4) {
// VPT = 128/4 = 32, ROWS_PER_WARP = 32/4 = 8, ROWS_PER_CTA = 6 * 8 = 48.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
}
} else if (num_expert_group == 8) {
// VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
}
}
break;
default:
break;
}
if (!dispatched) {
// Fallback to the dynamic kernel if none of the supported combinations match.
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
if (input.scalar_type() == at::kBFloat16) {
moe_fused_gate_kernel_dynamic<bfloat16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
}
return {output, indices};
}
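A hedged call-site sketch for the launcher above. The shapes follow the function's own checks (input [num_rows, num_experts], bias [num_experts]); the DeepSeek V3-style argument values are illustrative assumptions:

#include <torch/torch.h>

std::vector<at::Tensor> moe_fused_gate_example() {
  auto opts = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
  auto logits = torch::randn({4, 256}, opts);  // [num_rows, num_experts]
  auto bias = torch::zeros({256}, opts);       // expert score correction bias
  // 256 experts in 8 groups, keep 4 groups, route top-8 experts per token.
  // Returns {weights [4, 8] float32, indices [4, 8] int32}.
  return moe_fused_gate(
      logits, bias, /*num_expert_group=*/8, /*topk_group=*/4, /*topk=*/8,
      /*num_fused_shared_experts=*/0, /*routed_scaling_factor=*/2.5,
      /*apply_routed_scaling_factor_on_output=*/false);
}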

View File

@@ -0,0 +1,591 @@
// Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu
// which is originally adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
#include "utils.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 12.9 and newer deprecate cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
/// Aligned array type
template <
typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
T data[N];
};
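// e.g. AlignedArray<__half, 8> is 16 bytes with 16-byte alignment, letting
// one thread pull 8 fp16 logits in a single 128-bit vector load.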
// ========================== Util functions to convert types ==========================
template <typename T>
__device__ float convert_to_float(T x) {
if constexpr (std::is_same_v<T, __half>) {
return __half2float(x);
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return __bfloat162float(x);
} else if constexpr (std::is_same_v<T, float>) {
return x;
} else {
return static_cast<float>(x);
}
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
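// Numerically stable form computed below: softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)),
// with the max and the sum each produced by one cub::BlockReduce pass.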
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
const int thread_row_offset = blockIdx.x * num_cols;
float threadData(-FLT_MAX);
// Don't touch finished rows.
if ((finished != nullptr) && finished[blockIdx.x]) {
return;
}
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(convert_to_float<T>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
int* indices,
const int num_experts,
const int k,
const int start_expert,
const int end_expert,
const bool renormalize) {
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x;
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
float row_sum_for_renormalize = 0;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
// Ignore experts the node isn't responsible for with expert parallelism
const int expert = result_kvp.key;
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
row_sum_for_renormalize += result_kvp.value;
}
__syncthreads();
}
if (renormalize && threadIdx.x == 0) {
float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] * row_sum_for_renormalize_inv;
}
}
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
are a small power of 2. This allows us to cleanly share the rows among the threads in
a single warp and eliminate communication between warps (so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small power of 2.
2) This implementation assumes k is small, but will work for any k.
*/
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
const T* input,
const bool* finished,
float* output,
const int num_rows,
int* indices,
const int k,
const int start_expert,
const int end_expert,
const bool renormalize) {
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// Thus, each block processes a chunk of rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) {
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belongs to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<T, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem
T row_chunk_temp[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
float row_chunk[VPT];
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii) {
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread_max in every thread holds the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottleneck and it seems better to match torch more closely and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, row_chunk contains the softmax of the row chunk, and we want to find the topk elements in each row, along
// with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float row_sum_for_renormalize = 0;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW);
int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0) {
// Add a guard to ignore experts not included by this node
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single thread per row of the input/output matrices.)
const int idx = k * thread_row + k_idx;
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
row_sum_for_renormalize += max_val;
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
// Fuse renormalization of topk_weights into this kernel
if (renormalize && thread_group_idx == 0) {
float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] * row_sum_for_renormalize_inv;
}
}
}
namespace detail {
// Constructs some constants needed to partition the work across threads at compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
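// Example instantiation: T = float, EXPERTS = 128, BYTES_PER_LDG = 16 and,
// assuming WARP_SIZE = 64 (one DCU wavefront): ELTS_PER_LDG = 4,
// VECs_PER_THREAD = 1, VPT = 4, THREADS_PER_ROW = 32, ROWS_PER_WARP = 2.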
template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(
const T* input,
const bool* finished,
float* output,
int* indices,
const int num_rows,
const int k,
const int start_expert,
const int end_expert,
const bool renormalize,
cudaStream_t stream) {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
}
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
template <typename T>
void topkGatingSoftmaxKernelLauncher(
const T* gating_output,
float* topk_weights,
int* topk_indices,
float* softmax_workspace,
const int num_tokens,
const int num_experts,
const int topk,
const bool renormalize,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB);
break;
case 2:
LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB);
break;
case 4:
LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB);
break;
case 8:
LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB);
break;
case 16:
LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB);
break;
case 32:
LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB);
break;
case 64:
LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB);
break;
case 128:
LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB);
break;
case 256:
LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB);
break;
default: {
TORCH_CHECK(
softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2.");
static constexpr int TPB = 256;
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
}
}
}
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& gating_output,  // [num_tokens, num_experts]
const bool renormalize) {
// Check data type
TORCH_CHECK(
gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
gating_output.scalar_type() == at::ScalarType::BFloat16,
"gating_output must be float32, float16, or bfloat16");
// Check dimensions
TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]");
TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]");
TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]");
// Check shapes
TORCH_CHECK(
gating_output.size(0) == topk_weights.size(0),
"First dimension of topk_weights must match num_tokens in gating_output");
TORCH_CHECK(
gating_output.size(0) == topk_indices.size(0),
"First dimension of topk_indices must match num_tokens in gating_output");
TORCH_CHECK(
topk_weights.size(-1) == topk_indices.size(-1),
"Second dimension of topk_indices must match topk in topk_weights");
TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts");
const int num_experts = static_cast<int>(gating_output.size(-1));
const int num_tokens = static_cast<int>(gating_output.size(0));
const int topk = static_cast<int>(topk_weights.size(-1));
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
const bool needs_workspace = !is_pow_2 || num_experts > 256;
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace =
torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
const at::ScalarType dtype = gating_output.scalar_type();
if (dtype == at::ScalarType::Float) {
topkGatingSoftmaxKernelLauncher<float>(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
renormalize,
stream);
} else if (dtype == at::ScalarType::Half) {
topkGatingSoftmaxKernelLauncher<__half>(
reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
renormalize,
stream);
} else if (dtype == at::ScalarType::BFloat16) {
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
renormalize,
stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
}
}
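A sketch of driving the entry point above: power-of-two expert counts up to 256 take the fused single-kernel path, while anything else falls back to the two-pass moeSoftmax + moeTopK route with the float workspace allocated internally. Shapes mirror the TORCH_CHECKs; the example values are arbitrary:

#include <torch/torch.h>

void topk_softmax_example() {
  const int64_t num_tokens = 32, num_experts = 8, topk = 2;
  auto gating = torch::randn(
      {num_tokens, num_experts},
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
  auto weights = torch::empty({num_tokens, topk}, gating.options());
  auto indices = torch::empty({num_tokens, topk}, gating.options().dtype(torch::kInt32));
  topk_softmax(weights, indices, gating, /*renormalize=*/true);
}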

View File

@@ -0,0 +1,471 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cutlass/arch/arch.h>
#include <torch/all.h>
#include <cassert>
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"
using namespace cute;
template <
typename ElementAB,
typename ElementC,
typename ElementSF,
typename ElementAccumulator,
typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void __get_group_gemm_starts(
ElementAB** a_offsets,
ElementAB** b_offsets,
ElementC** out_offsets,
ElementSF** a_scales_offsets,
ElementSF** b_scales_offsets,
ElementAccumulator** alpha_offsets,
LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int,
ElementAB* a_base_as_int,
ElementAB* b_base_as_int,
ElementC* out_base_as_int,
ElementSF* a_scales_base_as_int,
ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int,
const int32_t* expert_offsets,
const int32_t* sf_offsets,
const int32_t* problem_sizes_as_shapes,
const int K,
const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
// number of elements covered by one block scale factor.
int64_t group_size = 16;
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes");
int64_t half_k = static_cast<int64_t>(k / 2);
int64_t group_k = static_cast<int64_t>(k / group_size);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
// Shape of C = [M, N]
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
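// Offset arithmetic example: with K = 512, a row of A packs into K/2 = 256
// bytes (two fp4 values per byte) and carries K/16 = 32 block scale factors,
// which is why the pointers above advance by half_k and group_k respectively.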
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \
ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
K, \
N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts,
const torch::Tensor& b_starts,
const torch::Tensor& out_starts,
const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts,
const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& alphas,
torch::Tensor const& expert_offsets,
torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes,
int M,
int N,
int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape");
TORCH_CHECK(
K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
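// The empty "if (false)" above just seeds the dtype dispatch: each macro
// expansion below contributes an "else if" branch on the output dtype, and
// the trailing "else" rejects unsupported dtypes.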
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t,
cutlass::float_ue4m3_t,
torch::kBFloat16,
cutlass::bfloat16_t,
LayoutSFA,
LayoutSFB,
ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets,
int M,
int N,
int K) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size
using ClusterShape = Shape<_1, _1, _1>;
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
EpilogueOperatorClass,
typename MMA1SMConfig::MmaTileShape,
ClusterShape,
Shape<_128, _64>,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
MainloopOperatorClass,
ElementA,
LayoutA*,
AlignmentA,
ElementB,
LayoutB*,
AlignmentB,
ElementAccumulator,
typename MMA1SMConfig::MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
alpha_ptrs,
layout_sfa,
layout_sfb,
a,
b,
output,
a_blockscale,
b_blockscales,
alphas,
expert_offsets,
sf_offsets,
problem_sizes,
M,
N,
K);
// Create an instance of the GEMM
Gemm gemm_op;
  // Reinterpret problem_sizes ([num_experts, 3], int32) as CUTLASS group problem shapes
  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
  hw_info.sm_count = cached_sm_counts[hw_info.device_id];
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(ab_strides.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(ab_strides.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
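  // Per-expert alpha (global dequantization scale) applied in the epilogue:
  // alpha_ptr_array[expert] points at one float per group, broadcast within
  // the group (dAlpha strides are 0 within a problem, 1 across groups).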
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "GEMM arguments are not supported by this kernel");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, m, ": unexpected tensor dtype.")
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
  TORCH_CHECK(
      a_blockscale.dim() == 2,
      "expected a_blockscale to be a 2-D tensor of shape [rounded_m,"
      " k // group_size], observed rank: ",
      a_blockscale.dim())
TORCH_CHECK(
b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output,
a,
b,
a_blockscale,
b_blockscales,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
sf_offsets,
M,
N,
K);
} else {
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output,
a,
b,
a_blockscale,
b_blockscales,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
sf_offsets,
M,
N,
K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, sgl-kernel must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above.");
#endif
}

View File

@@ -0,0 +1,392 @@
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <flashinfer/vec_dtypes.cuh>
#include <iostream>
#include "cutlass/array.h"
#include "utils.h"
constexpr uint64_t THREADS_PER_EXPERT = 512;
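// One block per expert: threads cooperatively count how many entries of
// topk_ids route to this expert, then thread 0 writes the (m, n, k) problem
// shapes of both grouped GEMMs (GEMM1's output dimension is 2 * n, typically
// the fused gate/up projections).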
__global__ void compute_problem_sizes(
const int* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int64_t topk_length,
const int64_t n,
const int64_t k) {
int expert_id = blockIdx.x;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
occurrences += (topk_ids[i] == expert_id);
}
atomicAdd(&atomic_buffer[expert_id], occurrences);
__syncthreads();
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
}
}
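// Single-thread prefix sum over per-expert token counts: expert_offsets[i]
// is the first row owned by expert i; atomic_buffer is re-seeded with the
// same offsets for compute_arg_sorts below.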
__global__ void compute_expert_offsets(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int64_t num_experts) {
int32_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3];
expert_offsets[i + 1] = tot_offset;
}
}
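// Same prefix sum as above, but also emits blockscale offsets with each
// expert's token count rounded up to a multiple of 128 to match the row
// padding of the fp4 blockscale tensors.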
__global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* blockscale_offsets,
int32_t* atomic_buffer,
const int64_t num_experts) {
int32_t tot_offset = 0;
int32_t tot_rounded_offset = 0;
expert_offsets[0] = 0;
blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
int num_tokens = problem_sizes1[i * 3];
int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
tot_offset += num_tokens;
tot_rounded_offset += rounded_num_tokens;
expert_offsets[i + 1] = tot_offset;
blockscale_offsets[i + 1] = tot_rounded_offset;
}
}
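// Builds the permutation maps: input_permutation[dst_row] holds the source
// token row (i / topk) to gather from, and output_permutation[i] holds the
// destination row in the expert-major layout. atomic_buffer must contain the
// per-expert start offsets.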
__global__ void compute_arg_sorts(
const int32_t* __restrict__ topk_ids,
int32_t* input_permutation,
int32_t* output_permutation,
int32_t* atomic_buffer,
const int64_t topk_length,
const int64_t topk) {
int expert_id = blockIdx.x;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
if (topk_ids[i] == expert_id) {
int start = atomicAdd(&atomic_buffer[expert_id], 1);
input_permutation[start] = i / topk;
output_permutation[i] = start;
}
}
}
void get_moe_prepare_input_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
  uint32_t num_threads = static_cast<uint32_t>(std::min<int64_t>(THREADS_PER_EXPERT, topk_ids.numel()));
uint32_t num_blocks = static_cast<uint32_t>(num_experts);
compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
n,
k);
if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}
compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
topk_ids.size(1));
}
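// Entry point: computes per-expert problem sizes, expert (and optional
// blockscale) offsets, and the input/output permutations needed to run the
// grouped GEMMs. topk_ids must be an int32 tensor of shape [num_tokens, topk].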
void prepare_moe_input(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
get_moe_prepare_input_caller(
topk_ids,
expert_offsets,
blockscale_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k);
return;
}
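// Gathers rows of `input` into `output` following dst2src_map, one thread
// block per destination row, using 128-bit vectorized copies; num_cols must
// be a multiple of the per-thread element count.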
template <typename T>
__global__ void shuffleRowsKernel(
const T* input,
const int32_t* dst2src_map,
T* output,
int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;
  if (dest_row_idx < num_dst_rows) {
    // Look up the source row only after the bounds check to avoid an
    // out-of-bounds read of dst2src_map.
    int64_t const source_row_idx = dst2src_map[dest_row_idx];
// Load 128-bits per thread
constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
auto const start_offset = threadIdx.x;
auto const stride = blockDim.x;
auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
// Explicit instantiations for every dtype dispatched below.
#define INSTANTIATE_SHUFFLE_ROWS(T)              \
  template __global__ void shuffleRowsKernel<T>( \
      const T* input,                            \
      const int32_t* dst2src_map,                \
      T* output,                                 \
      int64_t num_src_rows,                      \
      int64_t num_dst_rows,                      \
      int64_t num_cols);
INSTANTIATE_SHUFFLE_ROWS(float);
INSTANTIATE_SHUFFLE_ROWS(half);
INSTANTIATE_SHUFFLE_ROWS(__nv_bfloat16);
INSTANTIATE_SHUFFLE_ROWS(__nv_fp8_e4m3);
INSTANTIATE_SHUFFLE_ROWS(uint8_t);
#define SHUFFLE_ROWS(T) \
shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>( \
reinterpret_cast<const T*>(input), \
static_cast<const int32_t*>(dst2src_map.data_ptr()), \
reinterpret_cast<T*>(output), \
num_src_rows, \
num_dst_rows, \
num_cols)
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
case T: \
SHUFFLE_ROWS(CUDA_T); \
break;
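// Host-side dispatch: gather rows of input_tensor into output_tensor
// according to dst2src_map for all supported element types.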
void shuffle_rows_caller(
const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
TORCH_CHECK(
input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index()).stream();
uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
uint32_t threads = 256;
int64_t num_dst_rows = output_tensor.size(0);
int64_t num_src_rows = input_tensor.size(0);
int64_t num_cols = input_tensor.size(1);
const void* input = input_tensor.data_ptr();
void* output = output_tensor.data_ptr();
switch (input_tensor.scalar_type()) {
DTYPE_DISPATCH_CASE(torch::kFloat16, half);
DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
DTYPE_DISPATCH_CASE(torch::kFloat32, float);
DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
default:
TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
}
return;
}
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
return;
}
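// Fused MoE epilogue: for each token, gather its topk expert outputs via
// `permutation`, scale each by the matching factor (top-k weight), and
// accumulate in fp32 before writing the reduced row to output_tensor.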
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
const scalar_t* __restrict__ input_tensor, // [m * topk, k] (expert-major layout)
scalar_t* __restrict__ output_tensor, // [m, k] (token-major layout)
const int32_t* __restrict__ permutation, // [m * topk] (c_map: token-major-idx -> expert-major-idx)
int m,
int topk,
int row_stride,
const scalar_t* __restrict__ factors) // [m * topk] (topk_weights, token-major layout)
{
int i = blockIdx.x;
if (i >= m) {
return;
}
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
using t = float;
using vec_t = flashinfer::vec_t<t, vec_size>;
int thread_idx = threadIdx.x;
int stride = blockDim.x;
for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
int d = d_vec_idx * vec_size;
vec_t sum_vec;
sum_vec.fill(0.0f);
for (int j = 0; j < topk; ++j) {
int token_major_idx = i * topk + j;
int src_row = permutation[token_major_idx];
vec_t val_vec;
val_vec.cast_load(input_tensor + src_row * row_stride + d);
t factor = 1.0;
if (factors != nullptr) {
factor = factors[token_major_idx];
}
#pragma unroll
for (int k = 0; k < vec_size; ++k) {
sum_vec[k] += factor * val_vec[k];
}
}
sum_vec.cast_store(output_tensor + i * row_stride + d);
}
// remainder part
int remainder_start = (row_stride / vec_size) * vec_size;
for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
t sum_val = 0.0;
for (int j = 0; j < topk; ++j) {
int token_major_idx = i * topk + j;
int src_row = permutation[token_major_idx];
t val = input_tensor[src_row * row_stride + d];
t factor = 1.0;
if (factors != nullptr) {
factor = factors[token_major_idx];
}
sum_val += factor * val;
}
output_tensor[i * row_stride + d] = sum_val;
}
}
void get_apply_shuffle_mul_sum_caller(
const torch::Tensor& input_tensor, // [m * topk, row_stride], bf16/f16
torch::Tensor& output_tensor, // [m, row_stride], bf16/f16
const torch::Tensor& permutation, // [m * topk], int32
const std::optional<torch::Tensor>& factors_opt) // optional [m * topk], bf16/f16
{
TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
int m = output_tensor.size(0);
int topk = int(permutation.size(0) / m);
int row_stride = output_tensor.size(1);
TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
  // Size the block to match the kernel's 16-byte vectorized accesses
  // (sizeof(scalar_type) measured the enum, not the element, so the old
  // computation always yielded 16).
  uint32_t vec_size = 16 / static_cast<uint32_t>(output_tensor.element_size());
  uint32_t block_size = std::max(1U, std::min(static_cast<uint32_t>(row_stride) / vec_size, 1024U));
  dim3 block(block_size);
  dim3 grid(m);  // one block per output token row (blockIdx.x = i)
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
const int32_t* perm_ptr = permutation.data_ptr<int32_t>();
void* factors_ptr = nullptr;
if (factors_opt.has_value()) {
TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
factors_ptr = factors_opt->data_ptr();
}
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input_tensor.data_ptr()),
static_cast<scalar_t*>(output_tensor.data_ptr()),
perm_ptr,
m,
topk,
row_stride,
static_cast<const scalar_t*>(factors_ptr));
return true;
});
}
/**
* @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
*
* This function performs the equivalent of the following PyTorch expression:
*
* (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
*
* Specifically:
* - `input` is shuffled using the `permutation` tensor.
* - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
* - The result is summed along dimension 1 (the top-k dimension), and stored in `output`.
*
* @param input Input tensor of shape (m * topk, k), representing c2.
* @param output Output tensor of shape (m, k), where the final reduced results are stored.
* @param permutation Index tensor (e.g., c_map) that maps positions in `input` to shuffled layout.
* @param factors Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
*/
void apply_shuffle_mul_sum(
const torch::Tensor& input,
torch::Tensor& output,
const torch::Tensor& permutation,
const std::optional<torch::Tensor>& factors) {
get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}