adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,91 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// Returns the CUDA compute capability of device 0 encoded as
// major * 10 + minor (e.g. 90 for SM90 / Hopper).
// NOTE(review): queries device 0 unconditionally, not the current device —
// confirm callers only run on single-GPU-visible processes.
// Fix: the original ignored the cudaDeviceGetAttribute return codes, so a
// failed query would silently return garbage from uninitialized locals.
int32_t get_sm_version_num() {
  int32_t major_capability = 0;
  int32_t minor_capability = 0;
  cudaError_t err = cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
  TORCH_CHECK(err == cudaSuccess, "cudaDeviceGetAttribute(major) failed: ", cudaGetErrorString(err));
  err = cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
  TORCH_CHECK(err == cudaSuccess, "cudaDeviceGetAttribute(minor) failed: ", cudaGetErrorString(err));
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}
|
||||
|
||||
// Forward declaration of the SM90 (Hopper) grouped W4A8 MoE GEMM, defined in
// w4a8_grouped_mm_c3x.cu. Computes one GEMM per expert: D = A x B with fp8
// activations (A), packed-int4 weights (B), and grouped weight scales.
void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,            // output, written in place
    torch::Tensor const& a_tensors,      // fp8 activations
    torch::Tensor const& b_tensors,      // int4 weights packed as int8
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets, // per-expert row offsets into A/D
    torch::Tensor const& problem_sizes,  // per-expert (M, N, K) triples
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,      // strides for the scale tensors
    int64_t chunk_size,
    int64_t topk);
|
||||
|
||||
// Forward declaration of the data-preparation helper, defined in
// w4a8_moe_data.cu. Fills expert_offsets and the two per-expert problem-size
// tensors from topk_ids ahead of the grouped GEMM launches.
void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);
|
||||
|
||||
// Torch-facing wrapper for the grouped W4A8 MoE GEMM. Forwards every
// argument unchanged to the SM90 implementation.
void cutlass_w4a8_moe_mm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  cutlass_w4a8_moe_mm_sm90(
      d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
      problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size,
      topk);
}
|
||||
|
||||
// Torch-facing wrapper for the W4A8 MoE data-preparation step. Forwards
// every argument unchanged to the CUDA implementation.
void get_cutlass_w4a8_moe_mm_data(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  get_cutlass_w4a8_moe_mm_data_caller(
      topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
      input_permutation, output_permutation, num_experts, n, k);
}
|
||||
@@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
// Computes, on device, the per-expert start pointers for the grouped GEMM.
// Launched as <<<1, num_experts>>> — one thread per expert
// (expert_id == threadIdx.x), so num_experts must fit in one block.
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
    int32_t* expert_offsets,                 // [num_experts] row offset of each expert's slice in A/D
    ElementA** a_offsets,                    // out: per-expert A start pointers
    ElementB** b_offsets,                    // out: per-expert B start pointers
    ElementC** out_offsets,                  // out: per-expert D start pointers
    ElementAccumulator** a_scales_offsets,   // out: per-expert A-scale pointers
    cutlass::bfloat16_t** b_scales_offsets,  // out: per-expert B-scale pointers
    ElementA* a_base_as_int,                 // base pointer of the A tensor
    ElementB* b_base_as_int,                 // base pointer of the B tensor
    ElementC* out_base_as_int,               // base pointer of the D tensor
    ElementAccumulator* a_scales_base_as_int,
    cutlass::bfloat16_t* b_scales_base_as_int,
    int64_t n,
    int64_t k,
    bool per_act_token,  // true: one A scale per row; false: single scalar scale
    bool per_out_ch) {   // true: grouped B scales; false: one B scale per expert
  int expert_id = threadIdx.x;
  int32_t expert_offset = expert_offsets[expert_id];

  // A/D advance by whole rows (k resp. n elements per row).
  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  // B holds two int4 values per stored element, hence k * n / 2 per expert.
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  // Per-token A scales advance with the row offset; otherwise all experts
  // share the single scalar at the base.
  a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  // Grouped B scales: n * k / 128 scales per expert (group size 128);
  // otherwise one scale per expert.
  b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
}
|
||||
|
||||
// Continues an if / else-if chain on the output dtype: when out_tensors has
// dtype TENSOR_C_TYPE, launches int4_fp8_get_group_gemm_starts with C_TYPE as
// the output element type (single block, one thread per expert, on `stream`).
// Every name other than the two parameters (out_tensors, a_ptrs, b_ptrs,
// num_experts, stream, ...) is captured from the expansion site inside
// run_int4_fp8_get_group_gemm_starts below.
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                              \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                        \
    int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
        <<<1, num_experts, 0, stream>>>(                                                  \
            static_cast<int32_t*>(expert_offsets.data_ptr()),                             \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),                      \
            static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()),                            \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                                   \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                               \
            static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()),                 \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),                    \
            static_cast<cutlass::int8_t*>(b_tensors.data_ptr()),                          \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                 \
            static_cast<float*>(a_scales.data_ptr()),                                     \
            static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()),                       \
            out_tensors.size(1),                                                          \
            a_tensors.size(1),                                                            \
            per_act_token,                                                                \
            per_out_ch);                                                                  \
  }
|
||||
|
||||
namespace {

// Validates dtypes and launches the pointer-setup kernel that fills the
// per-expert device pointer arrays (a_ptrs, b_ptrs, out_ptrs, scale ptrs)
// consumed by the grouped GEMM. Dispatches on the output dtype via the
// __CALL_W4A8_GET_STARTS_KERNEL else-if chain.
void run_int4_fp8_get_group_gemm_starts(
    torch::Tensor const& expert_offsets,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // Scalar A scale => broadcast; more than one => per-token scales.
  bool per_act_token = a_scales.numel() != 1;
  // One B scale per expert => per-expert; otherwise grouped scales.
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

  // Dead `if` anchors the macro-generated `else if` chain below.
  if (false) {
  }
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
|
||||
386
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
386
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
@@ -0,0 +1,386 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "w4a8_grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
// Kernel schedule selector: PP = ping-pong, CO = cooperative.
enum class Sched { PP, CO };

// Compile-time configuration bundle for one SM90 W4A8 grouped-GEMM variant.
// M/N/K are the CTA tile shape, A/B/C the cluster shape, and S selects the
// ping-pong or cooperative TMA warp-specialized schedule (the kernel and
// epilogue schedules are kept in matching pairs).
template <int M, int N, int K, int A, int B, int C, Sched S>
struct SM90W4A8Config {
  using KernelSchedule = std::conditional_t<
      S == Sched::PP,
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;

  using EpilogueSchedule = std::conditional_t<
      S == Sched::PP,
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
      cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;

  using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
  using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
  // Fully-instantiated grouped GEMM type for this configuration.
  using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
};

// Shorthand aliases for the two schedules.
template <int M, int N, int K, int A, int B, int C>
using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;

template <int M, int N, int K, int A, int B, int C>
using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
|
||||
|
||||
// Thin adapter: instantiates the grouped-GEMM type carried by Config and
// forwards all tensor arguments to the generic caller.
template <typename Config>
inline void invoke_gemm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size) {
  using GemmT = typename Config::Cutlass3xW4A8Gemm;
  cutlass_w4a8_group_gemm_caller<GemmT>(
      d_tensors,
      a_tensors,
      b_tensors,
      a_scales,
      b_scales,
      expert_offsets,
      problem_sizes,
      a_strides,
      b_strides,
      d_strides,
      s_strides,
      chunk_size);
}
|
||||
|
||||
// Heuristic dispatcher for the SM90 W4A8 grouped MoE GEMM: selects a CUTLASS
// tile shape, cluster shape, and schedule (ping-pong vs cooperative) from the
// problem shape. The (n, k) pairs below are the shapes this kernel is tuned
// for; anything else falls through to a generic configuration.
// Fix: the original repeated the identical 13-argument call 13 times; the
// local variadic macro below forwards the shared argument list once
// (variadic so template arguments containing commas parse as one argument).
void dispatch_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  // a_tensors has (tokens * topk) rows, so m recovers the pre-expansion
  // token count used by the tuning thresholds below.
  uint32_t const m = a_tensors.size(0) / topk;
  uint32_t const n = d_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

#define W4A8_INVOKE_GEMM(...)        \
  invoke_gemm<__VA_ARGS__>(          \
      d_tensors,                     \
      a_tensors,                     \
      b_tensors,                     \
      a_scales,                      \
      b_scales,                      \
      expert_offsets,                \
      problem_sizes,                 \
      a_strides,                     \
      b_strides,                     \
      d_strides,                     \
      s_strides,                     \
      chunk_size)

  if (n == 4096 && k == 7168) {
    // group gemm 1
    if (m <= 4) {
      W4A8_INVOKE_GEMM(SM90_PP<64, 32, 512, 2, 1, 1>);
    } else if (m <= 16) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 16, 512, 2, 1, 1>);
    } else if (m <= 256) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 16, 512, 1, 1, 1>);
    } else if (m <= 1024) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 32, 512, 2, 1, 1>);
    } else {
      W4A8_INVOKE_GEMM(SM90_CO<128, 64, 512, 1, 1, 1>);
    }
  } else if (n == 7168 && k == 2048) {
    // group gemm 2
    if (m <= 8) {
      W4A8_INVOKE_GEMM(SM90_PP<64, 16, 512, 1, 1, 1>);
    } else if (m <= 512) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 32, 512, 1, 1, 1>);
    } else {
      W4A8_INVOKE_GEMM(SM90_CO<128, 64, 512, 1, 1, 1>);
    }
  } else if (n == 512 && k == 7168) {
    // group gemm 1 for tp
    if (m <= 4) {
      W4A8_INVOKE_GEMM(SM90_PP<64, 32, 512, 2, 1, 1>);
    } else if (m <= 16) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 16, 512, 2, 1, 1>);
    } else if (m <= 256) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 16, 512, 2, 1, 1>);
    } else if (m <= 1024) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 32, 512, 2, 1, 1>);
    } else {
      W4A8_INVOKE_GEMM(SM90_CO<128, 64, 512, 1, 1, 1>);
    }
  } else if (n == 7168 && k == 256) {
    // group gemm 2 for tp — k == 256 requires the 128-deep K tile
    if (m <= 8) {
      W4A8_INVOKE_GEMM(SM90_PP<64, 16, 128, 1, 1, 1>);
    } else if (m <= 512) {
      W4A8_INVOKE_GEMM(SM90_PP<128, 32, 128, 2, 1, 1>);
    } else {
      W4A8_INVOKE_GEMM(SM90_PP<128, 64, 128, 1, 1, 1>);
    }
  } else {
    // Untuned shape: prefer the 512-deep K tile when k divides evenly.
    if (k % 512 == 0) {
      W4A8_INVOKE_GEMM(SM90_CO<128, 32, 512, 1, 1, 1>);
    } else {
      W4A8_INVOKE_GEMM(SM90_PP<128, 64, 128, 1, 1, 1>);
    }
  }

#undef W4A8_INVOKE_GEMM
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Public SM90 entry point (declared in scaled_mm_entry.cu). Forwards to the
// shape-based dispatcher in the anonymous namespace above, which selects the
// tile/cluster/schedule configuration.
void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  dispatch_w4a8_moe_mm_sm90(
      d_tensors,
      a_tensors,
      b_tensors,
      a_scales,
      b_scales,
      expert_offsets,
      problem_sizes,
      a_strides,
      b_strides,
      d_strides,
      s_strides,
      chunk_size,
      topk);
}
|
||||
277
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
277
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
@@ -0,0 +1,277 @@
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* @file w4a8_grouped_mm_c3x.cuh
|
||||
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
|
||||
* precision
|
||||
*
|
||||
* This file implements a grouped GEMM operation that multiplies FP8 matrices
|
||||
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
|
||||
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
|
||||
* Cores for mixed precision arithmetic.
|
||||
*
|
||||
* Key features:
|
||||
* - Supports grouped GEMM operations with multiple experts
|
||||
* - Uses FP8 (e4m3) for matrix A
|
||||
* - Uses INT4 quantization for matrix B with per-block scaling
|
||||
* - Implements preprocessing for INT4 encoding and scale packing
|
||||
* - Optimized for Hopper architecture with Tensor Core operations
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
|
||||
#include "w4a8_get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
// Type definitions
// Element types shared by every W4A8 grouped-GEMM instantiation in this
// header.
using MmaType = cutlass::float_e4m3_t;     // FP8 e4m3: element type of A
using QuantType = cutlass::int4b_t;        // 4-bit integer: element type of B
using ElementAccumulator = float;          // accumulate in fp32
using ElementScale = cutlass::bfloat16_t;  // B-scale element type
using ElementC = cutlass::bfloat16_t;      // C (epilogue source) element type
using ElementD = ElementC;                 // D (output) element type
// Grouped problem shape: one (M, N, K) triple per expert.
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;  // Hopper tensor-core path
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;

// Layout configurations (logical layouts of the user-facing tensors)
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;

// Transposed layouts — used because the mainloop below takes B as its first
// operand and A as its second (operand swap), so each side is described by
// the transpose of its logical layout.
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;

// Alignments in elements, chosen so each access is 128 bits (16 bytes).
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Assembles the CUTLASS 3.x grouped-GEMM type for one tile/cluster/schedule
// combination. Note the operand swap: the int4 B matrix (with its packed
// group scales) is the first mainloop operand and the fp8 A matrix the
// second, using the transposed layouts defined above.
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
  // One B scale per 128-element group along K.
  static constexpr int GroupSize = 128;
  // Scales covering one K tile, packed into a single array element so the
  // mainloop loads them together.
  static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
  using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;

  // Epilogue built first: its shared-storage size carves out the stage count
  // available to the mainloop (see StageCountAutoCarveout below).
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutC_Transpose*,
      AlignmentC,
      ElementD,
      LayoutD_Transpose*,
      AlignmentD,
      EpilogueSchedule>::CollectiveOp;

  // Mixed-input mainloop: (int4 + packed scales) x fp8, scale-only mode.
  using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
      ArchTag,
      OperatorClass,
      cute::tuple<QuantType, ElementScalePacked>,
      LayoutB_Transpose*,
      AlignmentB,
      MmaType,
      LayoutA_Transpose*,
      AlignmentA,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  // Define the final kernel and GEMM operation types
  using GemmKernelScaleOnly =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;

  using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

  // Stride types exposed for the caller's static_casts of the user-supplied
  // stride tensors.
  using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
  using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
  using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
  using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
  using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
|
||||
|
||||
/**
|
||||
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
|
||||
*
|
||||
* This function performs multiple GEMM operations in parallel where each
|
||||
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
|
||||
* applying per-channel scaling factors. It's designed for efficient execution
|
||||
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
|
||||
* mixed precision arithmetic.
|
||||
*
|
||||
* The function includes preprocessing steps for both INT4 tensors and scale
|
||||
* factors to ensure optimal performance and correct operation.
|
||||
*
|
||||
* @param d_tensors Output tensor D with shape [total_m, total_n]
|
||||
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
|
||||
* [total_m, K]
|
||||
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
|
||||
* shape [E, N, K/2]
|
||||
* @param a_scales Tensor containing A matrix scale factors
|
||||
* @param b_scales Tensor containing B matrix scale factors with shape [E,
|
||||
* K//512, N*4]
|
||||
* @param expert_offsets Tensor containing expert offsets for determining group
|
||||
* boundaries (int32)
|
||||
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
|
||||
* 3] (M, N, K for each group) (int32)
|
||||
* @param a_strides Stride information for A tensors
|
||||
* @param b_strides Stride information for B tensors
|
||||
* @param d_strides Stride information for D tensors
|
||||
* @param s_strides Stride information for scale tensors
|
||||
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
|
||||
*/
|
||||
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
|
||||
// Validates inputs, builds CUTLASS grouped-GEMM arguments, and runs the GEMM
// described by `Gemm` (a cutlass_3x_w4a8_group_gemm instantiation). See the
// doxygen block above for the tensor contracts.
// Fixes vs. original: removed a stray empty statement after the alpha_ptr
// assignment, removed the unused per_act_token / per_out_ch locals (the
// pointer-setup helper derives its own), and surfaced the CUTLASS status
// string in the failure messages instead of a bare TORCH_CHECK(false, ...).
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size) {
  using Args = typename Gemm::GemmScaleOnly::Arguments;

  int num_experts = static_cast<int>(expert_offsets.size(0));

  // Check inputs
  TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
  TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
  TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
  TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");

  // Check tensor shapes
  TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
  TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
  TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
  TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");

  // Check tensor types
  TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
  TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
  TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
  TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Device-side arrays of per-expert raw pointers; int64 elements hold the
  // 64-bit device addresses written by run_int4_fp8_get_group_gemm_starts.
  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a_tensors.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  Args arguments;
  decltype(arguments.epilogue.thread) fusion_args;
  // Epilogue computes D = alpha * acc with alpha read from the scalar A scale.
  fusion_args.alpha = 1.0f;
  fusion_args.beta = 0;
  fusion_args.alpha_ptr = a_scales.data_ptr<float>();
  fusion_args.beta_ptr = nullptr;
  fusion_args.alpha_ptr_array = nullptr;
  fusion_args.beta_ptr_array = nullptr;
  fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
  fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());

  // Fill the per-expert pointer arrays on device.
  run_int4_fp8_get_group_gemm_starts(
      expert_offsets,
      a_ptrs,
      b_ptrs,
      out_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      a_tensors,
      b_tensors,
      d_tensors,
      a_scales,
      b_scales);

  // Operand order matches the swapped mainloop: B (int4 + scales) first,
  // then A (fp8).
  arguments = Args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      {static_cast<const QuantType**>(b_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
       static_cast<const MmaType**>(a_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
       static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
       static_cast<int>(chunk_size)},
      {fusion_args,
       nullptr,
       nullptr,
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
      hw_info};

  // Instantiate and run GEMM
  typename Gemm::GemmScaleOnly gemm;
  size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "GEMM implementation not supported: ",
      cutlass::cutlassGetStatusString(status));

  status = gemm.initialize(arguments, workspace.data_ptr(), stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "GEMM initialization failed: ",
      cutlass::cutlassGetStatusString(status));

  status = gemm.run(stream);
  TORCH_CHECK(
      status == cutlass::Status::kSuccess, "GEMM execution failed: ", cutlass::cutlassGetStatusString(status));
}
|
||||
|
||||
} // namespace
|
||||
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
@@ -0,0 +1,79 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
// One block per expert (expert_id == blockIdx.x). Each block counts how many
// topk_ids entries route to its expert and writes that expert's two GEMM
// problem-size triples, stored as (N, M, K) per row.
__global__ void compute_problem_sizes_w4a8(
    const int32_t* __restrict__ topk_ids,
    int32_t* problem_sizes1,   // [num_experts, 3] row = (2*n, m_e, k)
    int32_t* problem_sizes2,   // [num_experts, 3] row = (k, m_e, n)
    int32_t* atomic_buffer,    // [num_experts] scratch; must be zero-initialized
    const int topk_length,
    const int n,
    const int k) {
  int expert_id = blockIdx.x;

  // Strided count with a fixed stride of THREADS_PER_EXPERT. The launch uses
  // min(THREADS_PER_EXPERT, topk_length) threads; coverage is still complete
  // because when fewer threads are launched, topk_length itself is < stride.
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  // Each block only touches its own slot, so the atomic + barrier makes the
  // block-wide total visible to thread 0 below.
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    // First GEMM has output width 2*n (gate + up projections fused —
    // presumably; TODO confirm against the caller).
    problem_sizes1[expert_id * 3] = 2 * n;
    problem_sizes1[expert_id * 3 + 1] = final_occurrences;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = k;
    problem_sizes2[expert_id * 3 + 1] = final_occurrences;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}
|
||||
|
||||
// Serial exclusive prefix sum over the per-expert M values (element 1 of each
// problem_sizes1 triple). Launched <<<1, 1>>>: a single thread walks the
// experts in order, writing each expert's start offset into expert_offsets
// and mirroring it into atomic_buffer for the permutation step.
__global__ void compute_expert_offsets_w4a8(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,   // [num_experts + 1] cumulative row offsets
    int32_t* atomic_buffer,    // [num_experts] start offset per expert
    const int num_experts) {
  expert_offsets[0] = 0;
  int32_t running_total = 0;
  for (int expert = 0; expert < num_experts; ++expert) {
    atomic_buffer[expert] = running_total;
    running_total += problem_sizes1[expert * 3 + 1];
    expert_offsets[expert + 1] = running_total;
  }
}
|
||||
|
||||
// Computes per-expert problem sizes and cumulative expert offsets for the
// W4A8 grouped GEMM. input_permutation / output_permutation are accepted for
// interface parity but are not written by this implementation.
// Fixes vs. original: replaced the unqualified `min` over mixed
// uint64_t/int64_t operands with an explicit comparison, and clamped the
// block size to >= 1 — `topk_ids.numel() == 0` previously produced a
// zero-thread launch (invalid configuration) that silently left the problem
// sizes unwritten; with one thread the kernel still writes the
// (2n, 0, k) / (k, 0, n) rows for every expert.
void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  // Per-expert counters; compute_problem_sizes_w4a8 accumulates into these
  // with atomicAdd, so they must start at zero.
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  const int64_t topk_length = topk_ids.numel();
  int64_t capped_threads =
      topk_length < static_cast<int64_t>(THREADS_PER_EXPERT) ? topk_length : static_cast<int64_t>(THREADS_PER_EXPERT);
  const int num_threads = static_cast<int>(capped_threads > 0 ? capped_threads : 1);

  compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      static_cast<int>(topk_length),
      static_cast<int>(n),
      static_cast<int>(k));
  // Serial exclusive scan over the expert counts; single thread by design.
  compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      num_experts);
}
|
||||
142
sgl-kernel/csrc/moe/cutlass_moe_helper.cu
Normal file
142
sgl-kernel/csrc/moe/cutlass_moe_helper.cu
Normal file
@@ -0,0 +1,142 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
// Converts per-expert offsets into the pointer arrays consumed by a CUTLASS
// grouped GEMM: for each expert it computes where that expert's A/B/D slices
// and scale slices begin, and fills the blockwise scale-factor layouts
// (LayoutSFA / LayoutSFB).
//
// Launch shape: <<<1, num_experts>>> (see __CALL_GET_STARTS_KERNEL) — one
// thread per expert, expert_id taken from threadIdx.x. The scale strides
// hard-code a 128 blockwise-scaling granularity, matching the
// Sm100BlockwiseScaleConfig<., ., 128> configs used by the callers.
// When `transpose` is set, A and B swap roles (B varies per token, A per
// expert) and problem_sizes_transpose receives the (n, m, k) shapes.
template <
    typename ElementAB,
    typename ElementC,
    typename ElementAccumulator,
    typename LayoutSFA,
    typename LayoutSFB,
    typename ScaleConfig>
__global__ void get_group_gemm_starts(
    int32_t* expert_offsets,
    ElementAB** a_offsets,
    ElementAB** b_offsets,
    ElementC** out_offsets,
    ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementAB* a_base_as_int,
    ElementAB* b_base_as_int,
    ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    LayoutSFA* layout_sfa_base_as_int,
    LayoutSFB* layout_sfb_base_as_int,
    int* problem_sizes,
    int* problem_sizes_transpose,
    bool transpose = false) {
  int64_t expert_id = static_cast<int64_t>(threadIdx.x);

  // NOTE(review): threadIdx.x is always < blockDim.x <= gridDim.x * blockDim.x,
  // so this guard can never fire; presumably a num_experts bound was intended
  // (harmless under the <<<1, num_experts>>> launch) — confirm.
  if (expert_id >= gridDim.x * blockDim.x) {
    return;
  }

  int m = problem_sizes[expert_id * 3];
  int n = problem_sizes[expert_id * 3 + 1];
  int k = problem_sizes[expert_id * 3 + 2];
  if (transpose) {
    problem_sizes_transpose[expert_id * 3] = n;
    problem_sizes_transpose[expert_id * 3 + 1] = m;
    problem_sizes_transpose[expert_id * 3 + 2] = k;
  }

  // Token-dependent operand starts at this expert's token offset; the
  // weight-like operand starts at expert_id * (k * n).
  int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
  int64_t a_stride = 0;
  int64_t b_stride = 0;
  int64_t a_scale_stride = 0;
  int64_t b_scale_stride = 0;
  if (!transpose) {
    a_stride = expert_offset * k;                  // activations: per-token rows
    b_stride = expert_id * k * n;                  // weights: one k*n slab per expert
    a_scale_stride = expert_offset * k / 128;      // 1x128 activation scales
    b_scale_stride = expert_id * k * n / 128 / 128;  // 128x128 weight scales
  } else {
    a_stride = expert_id * k * n;
    b_stride = expert_offset * k;
    a_scale_stride = expert_id * k * n / 128 / 128;
    b_scale_stride = expert_offset * k / 128;
  }
  a_offsets[expert_id] = a_base_as_int + a_stride;
  b_offsets[expert_id] = b_base_as_int + b_stride;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] = a_scales_base_as_int + a_scale_stride;
  b_scales_offsets[expert_id] = b_scales_base_as_int + b_scale_stride;

  LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
  LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;

  // Scale-factor layouts depend on the (possibly swapped) problem shape.
  if (!transpose) {
    *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
    *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
  } else {
    *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(n, m, k, 1));
    *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(n, m, k, 1));
  }
}
|
||||
|
||||
// Dtype-dispatch helper: extends an if / else-if chain on out_tensors' dtype
// and, on a match, launches get_group_gemm_starts with one thread per expert
// (<<<1, num_experts>>>). Each expansion begins with `else if`, so the chain
// must be opened by a preceding `if (false) {}` (see run_get_group_gemm_starts).
// Relies on the caller having expert_offsets, a_ptrs, b_ptrs, out_ptrs,
// a/b_scales_ptrs, a/b_tensors, out_tensors, a/b_scales, layout_sfa/b,
// problem_sizes(_transpose), num_experts, stream, and transpose in scope.
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                        \
    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
        <<<1, num_experts, 0, stream>>>(                                                  \
            static_cast<int32_t*>(expert_offsets.data_ptr()),                             \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),                      \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),                      \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                                   \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                               \
            static_cast<float**>(b_scales_ptrs.data_ptr()),                               \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),                    \
            static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),                    \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                 \
            static_cast<float*>(a_scales.data_ptr()),                                     \
            static_cast<float*>(b_scales.data_ptr()),                                     \
            reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),                          \
            reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()),                          \
            static_cast<int*>(problem_sizes.data_ptr()),                                  \
            static_cast<int*>(problem_sizes_transpose.data_ptr()),                        \
            transpose);                                                                   \
  }
|
||||
|
||||
namespace {
// Host-side wrapper: validates operand dtypes/alignment, then launches
// get_group_gemm_starts (one thread per expert) to fill the pointer arrays
// and scale-factor layouts for the grouped FP8 blockwise GEMM.
// Output dtype dispatch (bf16 / fp16) is done via __CALL_GET_STARTS_KERNEL.
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& layout_sfa,
    torch::Tensor const& layout_sfb,
    torch::Tensor const& problem_sizes,
    torch::Tensor& problem_sizes_transpose,
    bool transpose = false) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  // One of the two dims must be 128-aligned (either orientation may be used
  // depending on `transpose`).
  TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
  TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);

  int num_experts = (int)expert_offsets.size(0);
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Dummy branch: each __CALL_GET_STARTS_KERNEL expansion starts with
  // `else if`, so the chain needs an opening `if`.
  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
}  // namespace
|
||||
181
sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
Normal file
181
sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
Normal file
@@ -0,0 +1,181 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Scatters each token's hidden vector into the expert-grouped gateup_input
// buffer, optionally multiplying by a (de)quantization scale.
// Launch: one block per token (blockIdx.x == token index); threads stripe
// over the hidden dimension. src2dst maps (token, k) -> destination row.
// Scale selection when a1_scales_ptr != nullptr:
//   - use_per_token_if_dynamic: scale = 1 / a1_scales[token]   (fixed per token)
//   - otherwise:                scale = 1 / a1_scales[expert - start_expert_id]
// Top-k entries routed to experts outside [start_expert_id, end_expert_id]
// are skipped.
template <typename scalar_t>
__global__ void ep_pre_reorder_cuda_kernel(
    const scalar_t* __restrict__ input_ptr,
    scalar_t* __restrict__ gateup_input_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const float* __restrict__ a1_scales_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size,
    bool use_per_token_if_dynamic) {
  int token_idx = blockIdx.x;
  int tid = threadIdx.x;

  const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;

  float scale = 1.0f;

  // Per-token scale is constant across all k, so compute it once up front.
  if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) {
    scale = 1.0f / a1_scales_ptr[token_idx];
  }

  for (int k = 0; k < topk; ++k) {
    int expert_id = token_topk_ids[k];
    if (expert_id < start_expert_id || expert_id > end_expert_id) continue;

    // Per-expert scale is recomputed for each routed expert.
    if (a1_scales_ptr != nullptr) {
      if (!use_per_token_if_dynamic) {
        scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
      }
    }

    int dst_idx = token_src2dst[k];
    scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;

    // 16-byte vectorized copy over the bulk of the row...
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

    int vec_elements = (hidden_size / vec_size) * vec_size;
    for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
      vec_t input_vec, output_vec;
      input_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float val = static_cast<float>(input_vec[i]);
        output_vec[i] = static_cast<scalar_t>(val * scale);
      }
      output_vec.cast_store(dst_ptr + idx * vec_size);
    }

    // ...then a scalar tail for the remainder when hidden_size is not a
    // multiple of vec_size.
    for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) {
      float val = static_cast<float>(src_ptr[idx]);
      dst_ptr[idx] = static_cast<scalar_t>(val * scale);
    }
  }
}
|
||||
|
||||
// Gathers the expert-grouped GEMM outputs back into token order, computing
// each token's weighted sum over its top-k experts:
//   output[token] = sum_k weight[token, k] * down_output[src2dst[token, k]]
// (only for experts inside [start_expert_id, end_expert_id]; out-of-range
// contributions are skipped, and a token with no in-range expert gets zeros).
// Launch: one block per token; threads stripe over 16-byte vectors of the row.
// NOTE(review): there is no scalar tail loop — if hidden_size is not a
// multiple of vec_size (16 / sizeof(scalar_t)) the trailing elements of each
// row are never written. Confirm callers guarantee this divisibility.
template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
    const scalar_t* __restrict__ down_output_ptr,
    scalar_t* __restrict__ output_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const scalar_t* __restrict__ topk_weights_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size) {
  const int token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
  const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;

  scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int vec_iters = hidden_size / vec_size;
  for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
    // Accumulate in float for precision, regardless of scalar_t.
    float acc[vec_size] = {0};

    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const int src_row = token_src2dst[k];
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
      const float weight = static_cast<float>(token_topk_weights[k]);

      vec_t src_vec;
      src_vec.cast_load(src_ptr + idx * vec_size);

#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        acc[i] += static_cast<float>(src_vec[i]) * weight;
      }
    }
    vec_t out_vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i)
      out_vec[i] = static_cast<scalar_t>(acc[i]);

    out_vec.cast_store(dst_ptr + idx * vec_size);
  }
}
|
||||
|
||||
// Host launcher for ep_pre_reorder_cuda_kernel: one block per input token,
// 512 threads striping over the hidden dimension. a1_scales may be an
// undefined tensor, in which case no scaling is applied.
//
// Fix vs. the original: the kernel is launched on PyTorch's current CUDA
// stream instead of the legacy default stream. Launching on stream 0
// serializes against every other stream and silently breaks stream ordering
// when the caller runs under a non-default stream.
void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic) {
  const int total_blocks = input.size(0);   // one block per token
  const int block_size = 512;
  dim3 grid(total_blocks);
  dim3 block(block_size);
  int hidden_size = input.size(1);

  auto stream = at::cuda::getCurrentCUDAStream(input.device().index());

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    ep_pre_reorder_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<scalar_t*>(gateup_input.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        use_per_token_if_dynamic);
    return true;
  });
}
|
||||
|
||||
// Host launcher for ep_post_reorder_cuda_kernel: one block per output token,
// 512 threads striping over the hidden dimension.
//
// Fix vs. the original: the kernel is launched on PyTorch's current CUDA
// stream instead of the legacy default stream, preserving stream ordering
// when the caller runs under a non-default stream.
void ep_moe_post_reorder(
    torch::Tensor down_output,
    torch::Tensor output,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor topk_weights,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk) {
  const int total_tokens = output.size(0);  // one block per token
  const int block_size = 512;
  dim3 grid(total_tokens);
  dim3 block(block_size);
  const int hidden_size = output.size(1);

  auto stream = at::cuda::getCurrentCUDAStream(output.device().index());

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
    ep_post_reorder_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(down_output.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        static_cast<scalar_t*>(topk_weights.data_ptr()),
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        static_cast<int>(topk),
        hidden_size);
    return true;
  });
}
|
||||
115
sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu
Normal file
115
sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu
Normal file
@@ -0,0 +1,115 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <algorithm>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// SiLU activation silu(x) = x * sigmoid(x) = x / (1 + e^-x), evaluated in
// float with the fast __expf intrinsic and converted (round-to-nearest) to
// the requested element type. Despite the name, no quantization happens
// here — only the final type conversion.
template <typename scalar_t>
__device__ inline scalar_t silu_quantize(float x);

template <>
__device__ inline float silu_quantize<float>(float x) {
  return x / (1.f + __expf(-x));
}

template <>
__device__ inline __half silu_quantize<__half>(float x) {
  return __float2half_rn(x / (1.f + __expf(-x)));
}

template <>
__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) {
  return __float2bfloat16_rn(x / (1.f + __expf(-x)));
}
|
||||
|
||||
// Fused SiLU-and-mul over the concatenated [gate | up] halves of each row:
//   down_input[t, i] = silu(gate[t, i]) * up[t, i] * scale
// where gate is the first half and up the second half of gateup_output's row,
// and scale = 1 / scales[expert - start_expert_id] when scales is provided
// (else 1). One block per token; rows whose expert lies outside
// [start_expert_id, end_expert_id] are left untouched.
// Note: scale is converted to scalar_t before the multiply, so scaling is
// applied in scalar_t precision.
template <typename scalar_t>
__global__ void ep_moe_act_and_mul_cuda_kernel(
    const scalar_t* __restrict__ gateup_output,
    scalar_t* __restrict__ down_input,
    const int* __restrict__ reorder_topk_ids,
    const float* __restrict__ scales,
    int start_expert_id,
    int end_expert_id,
    int hidden_size) {
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;

  const int half_hidden_size = hidden_size >> 1;
  const int expert_id = reorder_topk_ids[token_idx];

  // Whole-row early exit: safe because this kernel has no __syncthreads().
  if (expert_id < start_expert_id || expert_id > end_expert_id) return;
  const scalar_t* gate_output_ptr = gateup_output + static_cast<int64_t>(token_idx) * hidden_size;
  const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size;
  scalar_t* dst_ptr = down_input + static_cast<int64_t>(token_idx) * half_hidden_size;
  scalar_t scale_q = static_cast<scalar_t>(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f);

  // Vectorized main loop over 16-byte chunks of the half-row.
  const uint32_t vec_elements = half_hidden_size / vec_size;
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) {
    vec_t gate_vec, up_vec, out_vec;
    gate_vec.load(gate_output_ptr + idx * vec_size);
    up_vec.load(up_output_ptr + idx * vec_size);

#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float gate_f = static_cast<float>(gate_vec[i]);
      scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
      scalar_t prod = gate_q * up_vec[i] * scale_q;
      out_vec[i] = prod;
    }
    out_vec.store(dst_ptr + idx * vec_size);
  }

  // Scalar tail for the remainder when half_hidden_size is not a multiple
  // of vec_size.
  const int64_t scalar_start = static_cast<int64_t>(vec_elements) * vec_size + thread_idx;
#pragma unroll 1
  for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) {
    float gate_f = static_cast<float>(gate_output_ptr[idx]);
    scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
    dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q;
  }
}
|
||||
|
||||
// Host launcher for ep_moe_act_and_mul_cuda_kernel. Each gateup_output row
// holds [gate | up] halves, so its last dim must be even; down_input rows are
// hidden_size/2 wide. scales may be an undefined tensor (no requant scale).
//
// Fixes vs. the original: (1) launches on PyTorch's current CUDA stream
// instead of the legacy default stream; (2) checks the even-hidden_size
// precondition the kernel's `hidden_size >> 1` split relies on.
void ep_moe_silu_and_mul(
    torch::Tensor gateup_output,
    torch::Tensor down_input,
    torch::Tensor reorder_topk_ids,
    torch::Tensor scales,
    int64_t start_expert_id,
    int64_t end_expert_id) {
  const int total_tokens = gateup_output.size(0);
  const int hidden_size = gateup_output.size(1);
  TORCH_CHECK(hidden_size % 2 == 0, "gateup_output's last dim must be even, got ", hidden_size);

  auto stream = at::cuda::getCurrentCUDAStream(gateup_output.device().index());

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] {
    dim3 grid(total_tokens);  // one block per token
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    const int half_hidden_size = hidden_size >> 1;
    // Enough threads to cover the vectorized half-row in one pass, at least
    // 256, rounded up to a warp multiple, capped at the 1024-thread limit.
    uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size;
    threads = std::max<uint32_t>(threads, 256);
    threads = ((threads + 31) & ~31U);
    dim3 block(std::min(threads, 1024U));
    ep_moe_act_and_mul_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(gateup_output.data_ptr()),
        static_cast<scalar_t*>(down_input.data_ptr()),
        reorder_topk_ids.data_ptr<int>(),
        scales.defined() ? scales.data_ptr<float>() : nullptr,
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        hidden_size);
    return true;
  });
}
|
||||
810
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file
810
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file
@@ -0,0 +1,810 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass_moe_helper.cu"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Launches a CUTLASS SM90 (Hopper) grouped GEMM with FP8 (e4m3) A/B operands
// and blockwise scaling. Each group is one expert's (m, n, k) problem; the
// per-group pointer arrays, strides, and scale-factor layouts were prepared
// on device beforehand (see run_get_group_gemm_starts). C is void — the
// epilogue writes D directly from the accumulator (identity EVT), with
// blockwise scales applied in the mainloop.
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm90_fp8_blockwise_scaled_group_mm(
    torch::Tensor& out_ptrs,
    const torch::Tensor& a_ptrs,
    const torch::Tensor& b_ptrs,
    const torch::Tensor& a_scales_ptrs,
    const torch::Tensor& b_scales_ptrs,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace) {
  using ElementA = cutlass::float_e4m3_t;
  using ElementB = cutlass::float_e4m3_t;
  using ElementC = void;  // no C operand: epilogue reads the accumulator only
  using ElementD = OutType;
  using ElementAccumulator = float;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = LayoutD;

  // 128-bit-aligned vector access widths per element type.
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  // Identity epilogue visitor tree: D = acc.
  using CustomEVTIdentity =
      cutlass::epilogue::fusion::Sm90EVT<
          cutlass::epilogue::fusion::
              Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
          cutlass::epilogue::fusion::Sm90AccFetch>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,  // Use void to avoid load Matrix C
      LayoutC*,
      AlignmentC,
      ElementD,
      LayoutC*,
      AlignmentC,
      typename ScheduleConfig::EpilogueSchedule,
      CustomEVTIdentity>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
      AlignmentB,
      ElementAccumulator,
      typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      typename ScheduleConfig::KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;

  int num_experts = (int)expert_offsets.size(0);
  Gemm gemm_op;

  // Per-group device pointer arrays prepared by run_get_group_gemm_starts.
  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementA**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(stride_a.data_ptr()),
      static_cast<const ElementB**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(stride_b.data_ptr()),
      static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};

  // Persistent tile scheduler needs the real SM count of the current device.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = c10::cuda::current_device();
  hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // C pointer is nullptr (ElementC is void); only D is written.
  typename GemmKernel::EpilogueArguments epilogue_args{
      {},
      nullptr,
      static_cast<StrideC*>(stride_c.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(stride_c.data_ptr())};

  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info};

  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());

  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

  auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
|
||||
|
||||
// Launches a CUTLASS SM100 (Blackwell) grouped GEMM with FP8 (e4m3) A/B
// operands and blockwise scaling. Each group is one expert's (m, n, k)
// problem; per-group pointer arrays, strides, and scale-factor layouts were
// prepared on device beforehand (see run_get_group_gemm_starts). The epilogue
// source is void, so only D is written.
//
// Fix vs. the original: hw_info was hard-coded to device_id = 0 and
// sm_count = 148, which mis-schedules the persistent tile scheduler on SM100
// parts with a different SM count and on non-default devices. We now query
// the device the inputs live on, consistent with
// launch_sm90_fp8_blockwise_scaled_group_mm above.
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void launch_sm100_fp8_blockwise_scaled_group_mm(
    torch::Tensor& out_ptrs,
    const torch::Tensor& a_ptrs,
    const torch::Tensor& b_ptrs,
    const torch::Tensor& a_scales_ptrs,
    const torch::Tensor& b_scales_ptrs,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace) {
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
  using ElementA = cutlass::float_e4m3_t;
  using ElementB = cutlass::float_e4m3_t;
  using ElementC = OutType;
  using ElementD = ElementC;
  using ElementAccumulator = float;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = LayoutD;

  // 128-bit-aligned vector access widths per element type.
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      void,  // no C operand: epilogue reads the accumulator only
      LayoutC*,
      AlignmentC,
      ElementD,
      LayoutC*,
      AlignmentC,
      typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
      AlignmentB,
      ElementAccumulator,
      typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      typename ScheduleConfig::KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;

  int num_experts = (int)expert_offsets.size(0);
  // Create an instance of the GEMM
  Gemm gemm_op;

  // Per-group device pointer arrays prepared by run_get_group_gemm_starts.
  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementA**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(stride_a.data_ptr()),
      static_cast<const ElementB**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(stride_b.data_ptr()),
      static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};

  // Query the actual device instead of hard-coding device 0 / 148 SMs.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a_ptrs.get_device();
  hw_info.sm_count = at::cuda::getDeviceProperties(hw_info.device_id)->multiProcessorCount;

  // C pointer is nullptr (epilogue source is void); only D is written.
  typename GemmKernel::EpilogueArguments epilogue_args{
      {},
      nullptr,
      static_cast<StrideC*>(stride_c.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(stride_c.data_ptr())};

  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info};

  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());

  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

  auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
|
||||
|
||||
// Dispatch a grouped (per-expert) FP8 blockwise-scaled GEMM on SM100-class
// devices, choosing one of three CUTLASS kernel configurations from the
// overall problem shape (M = a.size(0), K = a.size(1)).
//
// Small-M / large-K problems swap A and B and run on transposed views with a
// column-major output so the large dimension maps onto the MMA tile's M axis;
// the other branches run the tensors as-is with a row-major output.
template <typename OutType>
void sm100_fp8_blockwise_group_mm_dispatch_shape(
    torch::Tensor& output,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace) {
  // Check the first matrix size to decide on the configuration
  // Assuming all matrices in the group have similar size characteristics
  // bool use_small_config = a[0].size(0) <= 128;

  // Config for the swapped (A/B transposed) small-M path: tall 256x32 tile,
  // 2-SM cluster, 2-SM ptr-array schedules, per-column (128x1) scale blocks.
  struct MmaConfig1 {
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_256, _32, _128>;
    using ClusterShape = Shape<_2, _1, _1>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
    using ScaleConfig =
        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };
  // Config for large-M / large-K problems: square 128x128 tile, single-SM
  // schedules, per-row (1x128) scale blocks.
  struct MmaConfig2 {
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_128, _128, _128>;
    using ClusterShape = Shape<_1, _1, _1>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
    using ScaleConfig =
        cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };
  // Fallback config (K < 2048): shorter 64x128 tile, single-SM schedules.
  struct MmaConfig3 {
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_64, _128, _128>;
    using ClusterShape = Shape<_1, _1, _1>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
    using ScaleConfig =
        cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };
  int num_experts = (int)expert_offsets.size(0);
  // Scratch buffer for the (M,N,K)-transposed problem sizes used by the
  // swapped path.
  // NOTE(review): allocated as int64 while `problem_sizes` is checked to be
  // int32 by the caller — confirm run_get_group_gemm_starts expects int64 here.
  torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
  // Transposed views used by the swapped small-M branch; these share storage
  // with the originals.
  torch::Tensor output_t = output.t();
  torch::Tensor a_t = a.t();
  torch::Tensor b_t = b.transpose(1, 2);
  torch::Tensor scales_a_t = scales_a.t();
  torch::Tensor scales_b_t = scales_b.transpose(1, 2);

  if (a.size(0) <= 2048 && a.size(1) >= 2048) {
    // Small M, large K: run B^T * A^T with swapped operand/scale roles and a
    // column-major output (note b_t/a_t and scales_b_t/scales_a_t order, plus
    // the trailing `true` transpose flag).
    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
        expert_offsets,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        b_t,
        a_t,
        output_t,
        scales_b_t,
        scales_a_t,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        problem_sizes_transpose,
        true);
    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes_transpose,
        expert_offsets,
        workspace);
    // output_t shares storage with output, so results are already in place;
    // this re-assignment restores the original orientation of the handle.
    // NOTE(review): appears redundant for the caller — confirm.
    output = output_t.t();
  } else if (a.size(0) > 2048 && a.size(1) >= 2048) {
    // Large M, large K: straightforward row-major path with MmaConfig2.
    run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
        expert_offsets,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        b,
        output,
        scales_a,
        scales_b,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        problem_sizes_transpose);
    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets,
        workspace);
  } else {
    // Remaining shapes (K < 2048): row-major path with the smaller-tile
    // MmaConfig3.
    run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
        expert_offsets,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        b,
        output,
        scales_a,
        scales_b,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        problem_sizes_transpose);
    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets,
        workspace);
  }
}
|
||||
|
||||
// Dispatch a grouped (per-expert) FP8 blockwise-scaled GEMM on SM90 (Hopper),
// choosing a CUTLASS configuration from the problem shape and, for the
// large-M path, from whether the device is an "NVIDIA H20" (detected by
// device-name string compare).
//
// Small-M problems (M <= 2048) swap A and B and run on transposed views with
// a column-major output; the other branches run the tensors as-is.
template <typename OutType>
void sm90_fp8_blockwise_group_mm_dispatch_shape(
    torch::Tensor& output,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace) {
  // Config for the swapped small-M path: 128x32 tile, 2-wide cluster,
  // pingpong FP8 block-scaled schedules, per-column (128x1) scale blocks.
  struct MmaConfigSmallM {
    // Swap A/B
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_128, _32, _128>;
    using ClusterShape = Shape<_2, _1, _1>;
    // TODO: Check Pingpong or Cooperative
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
    using ScaleConfig =
        cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };

  // Config for H20 devices with K > 128: 64x128 tile, pingpong schedules.
  struct MmaConfigH20LargeK {
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_64, _128, _128>;
    using ClusterShape = Shape<_2, _1, _1>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
    using ScaleConfig =
        cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };

  // Config for H100/H200/H800, and H20 with K <= 128: 128x128 tile,
  // cooperative schedules.
  struct MmaConfigHx00AndH20SmallK {
    using ElementA = cutlass::float_e4m3_t;
    using MmaTileShape = Shape<_128, _128, _128>;
    using ClusterShape = Shape<_1, _2, _1>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
    using ScaleConfig =
        cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  };

  int num_experts = (int)expert_offsets.size(0);
  // Scratch buffer for the transposed problem sizes used by the swapped path.
  // NOTE(review): allocated as int64 while `problem_sizes` is checked to be
  // int32 by the caller — confirm run_get_group_gemm_starts expects int64 here.
  torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
  // Transposed views for the swapped small-M branch; share storage with the
  // originals.
  torch::Tensor output_t = output.t();
  torch::Tensor a_t = a.t();
  torch::Tensor b_t = b.transpose(1, 2);
  torch::Tensor scales_a_t = scales_a.t();
  torch::Tensor scales_b_t = scales_b.transpose(1, 2);

  // H20 is distinguished by exact device-name match.
  const std::string H20_device_type_str("NVIDIA H20");
  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;

  if (a.size(0) <= 2048) {
    // Small M: run B^T * A^T with swapped operand/scale roles and a
    // column-major output (note b_t/a_t order and the trailing transpose flag).
    run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
        expert_offsets,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        b_t,
        a_t,
        output_t,
        scales_b_t,
        scales_a_t,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        problem_sizes_transpose,
        true);
    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes_transpose,
        expert_offsets,
        workspace);
    // output_t shares storage with output, so results are already in place;
    // this re-assignment restores the original orientation of the handle.
    // NOTE(review): appears redundant for the caller — confirm.
    output = output_t.t();
  } else {
    if (is_h20_device && a.size(1) > 128) {
      // For H20 with K > 128, use Pingpong Schedule
      run_get_group_gemm_starts<
          MmaConfigH20LargeK::LayoutSFA,
          MmaConfigH20LargeK::LayoutSFB,
          MmaConfigH20LargeK::ScaleConfig>(
          expert_offsets,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          output,
          scales_a,
          scales_b,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    } else {
      // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
      run_get_group_gemm_starts<
          MmaConfigHx00AndH20SmallK::LayoutSFA,
          MmaConfigHx00AndH20SmallK::LayoutSFB,
          MmaConfigHx00AndH20SmallK::ScaleConfig>(
          expert_offsets,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          output,
          scales_a,
          scales_b,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    }
  }
}
|
||||
|
||||
/**
 * @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs,
 *        with per-block scaling.
 *
 * This function dispatches to hardware-specific implementations (SM90 or SM100 FP8)
 * to compute:
 *     C_i = scale_a[i] * A_i * scale_b[i] * B_i
 * for each expert group `i`, using input `problem_sizes` and `expert_offsets`
 * to describe the individual matrix dimensions and their offsets.
 *
 * Input tensors A and B must be quantized to 8-bit formats and are dequantized before
 * multiplication. The output tensor is written with bfloat16 or half precision.
 *
 * @param output Output tensor (must be of type bfloat16 or half).
 * @param a Input tensor A (must be kFloat8_e4m3fn).
 * @param b Input tensor B (must be kFloat8_e4m3fn).
 * @param scales_a Scaling factors for tensor A, float32 per expert group.
 * @param scales_b Scaling factors for tensor B, float32 per expert group.
 * @param stride_a Stride information for tensor A (int64).
 * @param stride_b Stride information for tensor B (int64).
 * @param stride_c Stride information for output tensor C (int64).
 * @param layout_sfa Layout descriptor for A's scale factors (int32).
 * @param layout_sfb Layout descriptor for B's scale factors (int32).
 * @param problem_sizes 2D int32 tensor of shape (num_experts, 3), specifying (M, N, K)
 *        for each grouped matrix multiplication problem.
 * @param expert_offsets 1D int32 tensor of size (num_experts), used to index into
 *        the grouped input tensors for dispatch.
 * @note Performance Optimization:
 *       If the batch size (a.size(0)) is at most 2048, the implementation
 *       will internally transpose input matrices to align with the optimal memory access
 *       pattern for better GPU efficiency. This transformation is done within the kernel.
 */
|
||||
// Entry point for the FP8 blockwise-scaled grouped (MoE expert) GEMM.
// Validates dtypes and ranks of every tensor argument, then dispatches to the
// SM100 or SM90 implementation compiled into this build, selecting the output
// element type (bf16 vs half) at runtime. Raises NotImplementedError if no
// compiled backend matches the device's compute capability.
void fp8_blockwise_scaled_grouped_mm(
    torch::Tensor& output,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace) {
  // ---- dtype / shape validation ----
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
  TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn");
  TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn");
  TORCH_CHECK(
      output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf,
      "output must be bfloat16 or half");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be float32");
  TORCH_CHECK(stride_a.scalar_type() == torch::kInt64, "stride_a must be int64");
  TORCH_CHECK(stride_b.scalar_type() == torch::kInt64, "stride_b must be int64");
  TORCH_CHECK(stride_c.scalar_type() == torch::kInt64, "stride_c must be int64");
  TORCH_CHECK(layout_sfa.scalar_type() == torch::kInt32, "layout_sfa must be int32");
  TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
  TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");

  TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
  TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
  TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
  TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
  TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
  TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
  TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
  TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
  // Bug fix: these two checks require rank 2, but the original error messages
  // claimed "must be 1D tensor", contradicting the condition.
  TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 2D tensor");
  TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 2D tensor");
  TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
  TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
  TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
  TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
  TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
  TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");

  // ---- architecture dispatch ----
  bool can_implement = false;
  auto sm_version = getSMVersion();

#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  // SM100 (Blackwell); SM103 is accepted only with CUDA >= 12.9.
  if (sm_version == 100
#if CUDA_VERSION >= 12090
      || sm_version == 103
#endif
  ) {
    if (output.scalar_type() == torch::kBFloat16) {
      sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
          output,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          scales_a,
          scales_b,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    } else {
      sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
          output,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          scales_a,
          scales_b,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    }
    can_implement = true;
  }
#endif
#endif

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  // SM90 (Hopper) path.
  if (sm_version == 90) {
    if (output.scalar_type() == torch::kBFloat16) {
      sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
          output,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          scales_a,
          scales_b,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    } else {
      sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
          output,
          a_ptrs,
          b_ptrs,
          out_ptrs,
          a_scales_ptrs,
          b_scales_ptrs,
          a,
          b,
          scales_a,
          scales_b,
          stride_a,
          stride_b,
          stride_c,
          layout_sfa,
          layout_sfb,
          problem_sizes,
          expert_offsets,
          workspace);
    }
    can_implement = true;
  }
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
}
|
||||
129
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
129
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
# Header prepended to every generated kernel_*.cuh file.
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

# Jinja template producing one explicit instantiation of the Marlin kernel
# template; the placeholders map 1:1 onto Marlin's template parameters
# (see kernel.h).
TEMPLATE = (
    "template __global__ void Marlin<"
    "{{scalar_t}}, "
    "{{w_type_id}}, "
    "{{threads}}, "
    "{{thread_m_blocks}}, "
    "{{thread_n_blocks}}, "
    "{{thread_k_blocks}}, "
    "{{'true' if m_block_size_8 else 'false'}}, "
    "{{stages}}, "
    "{{'true' if has_act_order else 'false'}}, "
    "{{'true' if has_zp else 'false'}}, "
    "{{group_blocks}}, "
    "{{'true' if is_zp_float else 'false'}}>"
    "( MARLIN_KERNEL_PARAMS );"
)

# Jinja template for the umbrella header that #includes every generated file.
KERNEL_FILE_TEMPLATE = (
    "// auto generated by generate.py\n"
    "// clang-format off\n"
    "#pragma once\n\n"
    "{% for kernel_file in kernel_files %}"
    '#include "{{ kernel_file }}"\n'
    "{% endfor %}"
)

# Name of the umbrella header written alongside the generated kernels.
KERNEL_FILE_NAME = "kernel_marlin.cuh"

# int8 with zero point case (sglang::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
# Tile configurations as (thread_k, thread_n, threads) — divided by 16 into
# thread_k_blocks / thread_n_blocks when rendered (see generate_new_kernels).
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

# 0.5 selects the m_block_size_8 variant (thread_m_blocks is clamped to 1).
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
# Compute dtypes: fp16 -> half, bf16 -> nv_bfloat16.
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
    """Delete all previously generated kernel_*.cuh files next to this script.

    Uses ``os.remove`` instead of shelling out to ``rm -f`` (as the original
    did): it is portable (works on Windows), avoids spawning a process per
    file, and the FileNotFoundError guard preserves ``rm -f``'s tolerance of
    files disappearing between the glob and the delete.
    """
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass
|
||||
|
||||
|
||||
def generate_new_kernels():
    """Render one kernel_<dtype>_<type>.cuh of explicit Marlin template
    instantiations per (scalar type, dtype) pair, plus an umbrella header
    (KERNEL_FILE_NAME) that #includes them all, written next to this script.
    """
    kernel_files = set()
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        # Types without a "B" (bias) marker, i.e. plain kU4/kU8, use zero-points.
        has_zp = "B" not in scalar_type
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
        ):

            # group_blocks == 0 encodes the act-order case; zero-point types
            # do not support act order, so skip that combination.
            has_act_order = group_blocks == 0
            if has_zp and has_act_order:
                continue
            # For 256-thread configs, pair small-m tiles with thread_k == 128
            # and larger-m tiles with thread_k == 64; other combinations are
            # skipped.
            if thread_configs[2] == 256:
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # Convert tile sizes to 16x16-block counts for the template params.
            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            # m_blocks == 0.5 selects the m_block_size_8 variant with
            # thread_m_blocks clamped to 1.
            template_str = jinja2.Template(TEMPLATE).render(
                scalar_t=c_dtype,
                w_type_id=scalar_type + ".id()",
                threads=threads,
                thread_m_blocks=max(m_blocks, 1),
                thread_n_blocks=n_blocks,
                thread_k_blocks=k_blocks,
                m_block_size_8=m_blocks == 0.5,
                stages="pipe_stages",
                has_act_order=has_act_order,
                has_zp=has_zp,
                group_blocks=group_blocks,
                is_zp_float=False,
            )

            all_template_str_list.append(template_str)

        # Assemble the file: shared header, instantiations, closing namespace.
        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # scalar_type[8:] strips the "sglang::" prefix (8 chars), e.g.
        # "sglang::kU4" -> "ku4" -> kernel_bf16_ku4.cuh.
        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)
        kernel_files.add(filename)

    # Sort for a deterministic umbrella header.
    kernel_files = list(kernel_files)
    kernel_files.sort()

    file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
        kernel_files=kernel_files
    )
    with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
        f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Regenerate from scratch: drop stale kernel_*.cuh files, then emit the
    # current set plus the umbrella header.
    remove_old_kernels()
    generate_new_kernels()
|
||||
41
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
41
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@@ -0,0 +1,41 @@
|
||||
#pragma once

// Namespace is overridable so the same kernel sources can be compiled under
// different names.
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"

// Shared parameter list for every Marlin kernel instantiation; kept as a
// macro so the generated explicit instantiations (see generate_kernels.py)
// stay on one line each.
#define MARLIN_KERNEL_PARAMS                                                                                          \
  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp,             \
      const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,            \
      const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr,                   \
      const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k,  \
      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks,              \
      bool use_atomic_add, bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
// Forward declaration of the Marlin MoE WNA16 GEMM kernel; explicit
// instantiations live in the generated kernel_*.cuh files.
template <
    typename scalar_t,                    // compute dtype, half or nv_float16
    const sglang::ScalarTypeId w_type_id, // weight ScalarType id
    const int threads,                    // number of threads in a threadblock
    const int thread_m_blocks,            // number of 16x16 blocks in the m
                                          // dimension (batchsize) of the
                                          // threadblock
    const int thread_n_blocks,            // same for n dimension (output)
    const int thread_k_blocks,            // same for k dimension (reduction)
    const bool m_block_size_8,            // whether m_block_size == 8
                                          // only works when thread_m_blocks == 1
    const int stages,                     // number of stages for the async global->shared
                                          // fetch pipeline
    const bool has_act_order,             // whether act_order is enabled
    const bool has_zp,                    // whether zero-points are enabled
    const int group_blocks,               // number of consecutive 16x16 blocks
                                          // with a separate quantization scale
    const bool is_zp_float                // is zero point of float16 type?
    >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}
|
||||
90
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh
Normal file
90
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh
Normal file
@@ -0,0 +1,90 @@
|
||||
// auto generated by generate.py
// clang-format off
// Explicit instantiations of the Marlin kernel for nv_bfloat16 compute with
// 4-bit unsigned weights plus zero-points (sglang::kU4): has_zp=true,
// has_act_order=false, covering group_blocks in {-1, 2, 4, 8} and the
// standard tile configurations. Do not edit by hand — regenerate with
// generate_kernels.py.
#pragma once

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

}
|
||||
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh
Normal file
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh
Normal file
@@ -0,0 +1,110 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh
Normal file
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh
Normal file
@@ -0,0 +1,110 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
90
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh
Normal file
90
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh
Normal file
@@ -0,0 +1,90 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh
Normal file
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh
Normal file
@@ -0,0 +1,110 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh
Normal file
110
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh
Normal file
@@ -0,0 +1,110 @@
|
||||
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

// Explicit instantiations of the Marlin GEMM kernel for half (fp16)
// activations with sglang::kU8B128 (8-bit, bias-128) quantized weights.
// NOTE(review): the non-type template arguments sweep the thread count
// (256/128), a tile/blocking shape, two boolean flags, and a trailing value
// in {0, -1, 2, 4, 8} that looks like a quantization group-size selector —
// confirm exact parameter meanings against the Marlin template declaration
// in marlin_template.h. Do not edit by hand; regenerate with generate.py.

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

}  // namespace MARLIN_NAMESPACE_NAME
|
||||
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
@@ -0,0 +1,10 @@
|
||||
// auto generated by generate.py
// clang-format off
#pragma once

// Umbrella header: pulls in every generated Marlin kernel-instantiation
// header (bf16 and fp16 activations x {kU4, kU4B8, kU8B128} weight formats).
// Do not edit by hand; regenerate with generate.py.
#include "kernel_bf16_ku4.cuh"
#include "kernel_bf16_ku4b8.cuh"
#include "kernel_bf16_ku8b128.cuh"
#include "kernel_fp16_ku4.cuh"
#include "kernel_fp16_ku4b8.cuh"
#include "kernel_fp16_ku8b128.cuh"
|
||||
1805
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1805
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
1111
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
1111
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
File diff suppressed because it is too large
Load Diff
366
sgl-kernel/csrc/moe/moe_align_kernel.cu
Normal file
366
sgl-kernel/csrc/moe/moe_align_kernel.cu
Normal file
@@ -0,0 +1,366 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define WARP_SIZE 64
|
||||
|
||||
#define VEC_SIZE 4
|
||||
using Vec = int4;
|
||||
|
||||
// Grid-stride scatter pass: places every token index into its expert's
// contiguous region of sorted_token_ids.
//
// topk_ids:         flat [numel] array of expert ids per (token, k) slot.
// sorted_token_ids: output; slot for each token index within its expert run.
// cumsum_buffer:    per-(shifted-)expert running write cursors; entry e must
//                   start at that expert's exclusive-prefix offset (produced
//                   by moe_align_block_size_kernel) and is advanced atomically
//                   here.
// numel:            number of (token, k) entries.
//
// Expert ids are shifted by +1 so an id of -1 maps to slot 0 — the same shift
// used when the cumsum was built (topk_ids[i] + 1 there as well).
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and can silently overflow for very large launches.
  const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i] + 1;  // +1 shift: -1 (masked) -> slot 0
    // Atomically claim the next free slot for this expert, then record the
    // token index there. Ordering within an expert is arbitrary.
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}
|
||||
|
||||
#ifdef __CUDA_ARCH__
// Exclusive prefix sum of `v` across one warp via shuffle-up: returns the sum
// of the values held by all lower-numbered lanes (0 for lane 0).
// Guarded by __CUDA_ARCH__, so this helper only exists in the CUDA
// device-compilation pass; the HIP/DCU build takes the #ifndef __CUDA_ARCH__
// (Blelloch scan) path inside moe_align_block_size_kernel instead.
// NOTE(review): the loop assumes WARP_SIZE (64 in this file) lanes per warp —
// confirm this branch is never compiled for 32-lane hardware. `mask` is
// accepted for API symmetry but unused (__shfl_up takes no mask).
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
  int original = v;
#pragma unroll
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    int n = __shfl_up(v, offset);
    // Only lanes that actually have a neighbor `offset` below accumulate.
    if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
  }
  return v - original;  // inclusive sum minus own value = exclusive sum
}
#endif
|
||||
|
||||
// Single-block kernel that builds the block-aligned MoE routing metadata:
//  1. histogram tokens per expert (ids are shifted +1 so -1/masked -> slot 0),
//  2. pad each expert's count up to a multiple of block_size,
//  3. exclusive-prefix-sum the padded counts into `prefix` / `cumsum`,
//  4. publish the total padded token count,
//  5. fill expert_ids with the owning (unshifted) expert of each output block,
//  6. optionally pre-fill sorted_token_ids with `numel` via vectorized int4
//     stores (presumably a sentinel for padding slots — confirm with consumer).
//
// Launch: exactly ONE block. Dynamic shared memory layout (int32 words):
//   shared_counts[num_experts] | prefix[num_experts + 1] | scan_buf[scan_size]
//   | (CUDA path only) warp_sums[...]
// `scan_size` must be a power of two >= num_experts (host passes
// next_pow2(num_experts) and sizes shared memory accordingly).
//
// Two scan implementations: Blelloch scan on HIP/DCU (__CUDA_ARCH__ not
// defined) and a two-level warp-shuffle scan on CUDA.
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum,
    bool pad_sorted_token_ids,
    const int32_t scan_size) {
  extern __shared__ int32_t smem[];
  int32_t* shared_counts = smem;                  // [num_experts]
  int32_t* prefix = shared_counts + num_experts;  // [num_experts + 1]
  int32_t* scan_buf = prefix + num_experts + 1;   // [scan_size]
  __shared__ int32_t s_total_tokens_post_pad;

  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  // Zero the per-expert histogram.
  if (tid < num_experts) {
    shared_counts[tid] = 0;
  }

  __syncthreads();

  // Histogram: one shared-memory atomic per (token, k) entry.
  for (size_t i = tid; i < numel; i += stride) {
    int expert_id = topk_ids[i] + 1;  // +1 shift: -1 (masked) -> slot 0
    atomicAdd(&shared_counts[expert_id], 1);
  }

  __syncthreads();

  // Pad each expert's count up to a multiple of block_size.
  int32_t padded_count = 0;
  if (tid < num_experts) {
    int32_t count = shared_counts[tid];
    padded_count = (count + block_size - 1) / block_size * block_size;
    scan_buf[tid] = padded_count;
  }

#ifndef __CUDA_ARCH__  // HIP / DCU path: shared-memory Blelloch scan.

  // Zero the tail of scan_buf (scan_size is a power of two >= num_experts).
  if (tid >= num_experts && tid < scan_size) {
    scan_buf[tid] = 0;
  }

  __syncthreads();

  // Blelloch scan: up-sweep (reduce).
  int offset = 1;
#pragma unroll
  for (int d = scan_size >> 1; d > 0; d >>= 1) {
    if (tid < d) {
      int ai = offset * (2 * tid + 1) - 1;
      int bi = offset * (2 * tid + 2) - 1;
      scan_buf[bi] += scan_buf[ai];
    }
    offset <<= 1;
    __syncthreads();
  }

  // Down-sweep: seed the root with 0, remember the grand total.
  if (tid == 0) {
    prefix[num_experts] = scan_buf[scan_size - 1];
    scan_buf[scan_size - 1] = 0;
  }
  __syncthreads();

#pragma unroll
  for (int d = 1; d < scan_size; d <<= 1) {
    offset >>= 1;
    if (tid < d) {
      int ai = offset * (2 * tid + 1) - 1;
      int bi = offset * (2 * tid + 2) - 1;
      if (bi < scan_size) {
        int temp = scan_buf[ai];
        scan_buf[ai] = scan_buf[bi];
        scan_buf[bi] += temp;
      }
    }
    __syncthreads();
  }

  // scan_buf now holds the exclusive prefix sums.
  if (tid < num_experts) {
    prefix[tid] = scan_buf[tid];
  }

  if (tid == 0) {
    s_total_tokens_post_pad = prefix[num_experts];
    *total_tokens_post_pad = s_total_tokens_post_pad;
  }
  __syncthreads();

#else  // CUDA path: two-level warp-shuffle scan.

  // Level 1: intra-warp exclusive scan of the padded counts; last lane of
  // each warp records the warp's inclusive total.
  int32_t* warp_sums = scan_buf + scan_size;  // [<= 32]
  const int warp_id = tid / WARP_SIZE;
  const int lane_id = tid & (WARP_SIZE - 1);
  const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
  const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
  __syncthreads();

  // Warp 0 turns the per-warp totals into inclusive prefix sums.
  if (tid < WARP_SIZE) {
    int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
    int incl = warp_exclusive_scan(val) + val;
    warp_sums[tid] = incl;
  }
  __syncthreads();

  // Publish the grand total (last warp's inclusive sum).
  if (tid == 0) {
    prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
    s_total_tokens_post_pad = prefix[num_experts];
    *total_tokens_post_pad = s_total_tokens_post_pad;
  }
  __syncthreads();

  // Zero scan_buf's tail (tid >= num_experts), then redo the scan in
  // shared memory so prefix[] can be filled for all experts.
  if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
  __syncthreads();

  int v = (tid < scan_size) ? scan_buf[tid] : 0;
  int pre = warp_exclusive_scan(v);
  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
  __syncthreads();

  if (warp_id == 0) {
    int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
    warp_sums[lane_id] = warp_exclusive_scan(val);
  }
  __syncthreads();

  int offset = warp_sums[warp_id];
  if (tid < scan_size) scan_buf[tid] = pre + offset;
  __syncthreads();

  if (tid < num_experts) prefix[tid] = scan_buf[tid];
  // FIX: barrier so every thread observes the fully written `prefix` before
  // the cross-thread reads below (the binary search reads arbitrary prefix
  // entries). The HIP branch above already ends with an equivalent
  // __syncthreads(); without this one the CUDA path races.
  __syncthreads();
#endif

  // Export the exclusive prefix (including the total at index num_experts).
  if (tid <= num_experts) {
    cumsum[tid] = prefix[tid];
  }

  // Fill expert_ids: for each output block, binary-search the (shifted)
  // prefix for its owner; `left - 1` is the shifted expert, minus 1 more to
  // undo the +1 id shift -> `left - 2` is the original expert id.
  const int32_t num_blocks = s_total_tokens_post_pad / block_size;
  for (int32_t i = tid; i < num_blocks; i += stride) {
    int32_t block_start = i * block_size;
    int left = 0, right = num_experts;
    while (left < right) {
      int mid = (left + right) >> 1;
      if (prefix[mid] <= block_start) {
        left = mid + 1;
      } else {
        right = mid;
      }
    }
    expert_ids[i] = left - 2;
  }

  // Optionally pre-fill sorted_token_ids with `numel` using int4 stores.
  // NOTE(review): assumes sorted_token_ids is 16-byte aligned and its
  // capacity covers total_vecs * 4 entries — confirm at the allocation site.
  if (pad_sorted_token_ids) {
    Vec fill_vec;
    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
    int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
    for (int32_t i = tid; i < total_vecs; i += stride) {
      out_ptr[i] = fill_vec;
    }
  }
}
|
||||
|
||||
// Single-block variant for small batches (host uses it when numel < 1024 and
// num_experts <= 64): computes the padded cumsum AND scatters token indices
// in one launch, using per-thread counters instead of atomics.
//
// Shared memory layout (int32 words):
//   cumsum[num_experts + 1] | tokens_cnts[(blockDim.x + 1) * num_experts]
// Row 0 of tokens_cnts is a zero row; row t+1 holds thread t's private
// counts, which are later turned into column-wise running sums so that
// row t becomes "tokens counted by threads < t" (the scatter rank base).
// Expert ids are shifted by +1 so an id of -1 maps to slot 0.
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    bool pad_sorted_token_ids) {
  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;                                  // [num_experts + 1]
  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);  // [(blockDim.x + 1) * num_experts]

  // Zero this thread's private counter row (row threadIdx.x + 1).
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
  }

  // Private histogram over this thread's strided slice of topk_ids.
  for (size_t i = tid; i < numel; i += stride) {
    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
  }

  __syncthreads();

  // Column-wise running sums: after this, tokens_cnts[t * E + e] is the
  // number of tokens of expert e seen by threads 0..t-1, and the last row
  // holds the per-expert totals.
  if (threadIdx.x < num_experts) {
    tokens_cnts[threadIdx.x] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
    }
  }

  __syncthreads();

  // Serial exclusive prefix sum of the block_size-padded totals (thread 0).
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  // Each expert's thread stamps its id on every output block it owns;
  // -1 undoes the +1 id shift.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x - 1;
    }
  }

  // Optional sentinel pre-fill of sorted_token_ids with `numel` (int4 stores).
  // NOTE(review): assumes sorted_token_ids is 16-byte aligned with capacity
  // covering total_vecs * 4 entries — confirm at the allocation site.
  if (pad_sorted_token_ids) {
    Vec fill_vec;
    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
    int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
    for (int32_t i = tid; i < total_vecs; i += stride) {
      out_ptr[i] = fill_vec;
    }
  }

  __syncthreads();

  // Scatter: this thread revisits the same strided slice it counted, so
  // tokens_cnts[threadIdx.x * E + e] (tokens of e owned by lower threads)
  // plus cumsum[e] gives a collision-free rank; bump the private count as
  // we emit.
  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i] + 1;
    int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
  }
}
|
||||
|
||||
// Host launcher: aligns MoE token routing to block_size-sized chunks.
// Small inputs (numel < 1024, num_experts <= 64) take a fused one-kernel
// path; otherwise a single-block alignment kernel builds the cumsum and a
// multi-block scatter kernel fills sorted_token_ids from it.
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor cumsum_buffer,
    bool pad_sorted_token_ids) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // 1024 rounded up to a multiple of the wavefront size (no-op for 64).
  const int threads = ((1024 + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    const bool use_small_batch_path = (topk_ids.numel() < 1024) && (num_experts <= 64);

    if (use_small_batch_path) {
      // One thread per expert (at least a full wavefront); shared memory
      // holds cumsum[E+1] plus (threads+1) rows of per-thread counters.
      const int32_t sb_threads = max((int32_t)num_experts, WARP_SIZE);
      const int32_t sb_shared_bytes = ((sb_threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
      moe_align_block_size_small_batch_expert_kernel<scalar_t><<<1, sb_threads, sb_shared_bytes, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          block_size,
          topk_ids.numel(),
          pad_sorted_token_ids);
    } else {
      // Pass 1 (single block): histogram + padded exclusive prefix sum.
      // scan_size must be a power of two >= num_experts.
      const size_t scan_size = next_pow2(num_experts);
      const size_t align_shared_bytes = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
      moe_align_block_size_kernel<scalar_t><<<1, threads, align_shared_bytes, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          block_size,
          topk_ids.numel(),
          cumsum_buffer.data_ptr<int32_t>(),
          pad_sorted_token_ids,
          scan_size);

      // Pass 2 (multi block): scatter token indices via the cumsum cursors.
      const int block_threads = std::min(256, (int)threads);
      const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
      const int max_blocks = 65535;
      const int actual_blocks = std::min(num_blocks, max_blocks);
      count_and_sort_expert_tokens_kernel<scalar_t><<<actual_blocks, block_threads, 0, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          cumsum_buffer.data_ptr<int32_t>(),
          topk_ids.numel());
    }
  });
}
|
||||
521
sgl-kernel/csrc/moe/moe_fused_gate.cu
Normal file
521
sgl-kernel/csrc/moe/moe_fused_gate.cu
Normal file
@@ -0,0 +1,521 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <stdio.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <type_traits>
|
||||
template <typename T, int N>
|
||||
using AlignedArray = cutlass::AlignedArray<T, N>;
|
||||
using bfloat16_t = cutlass::bfloat16_t;
|
||||
using float16_t = cutlass::half_t;
|
||||
using float32_t = float;
|
||||
|
||||
// QQ NOTE: greater-than that is safe for all gate dtypes. For at::Half the
// plain `a > b` is ambiguous ("more than one operator > matches these
// operands: built-in arithmetic > arithmetic vs operator>(const __half&,
// const __half&)"), so that type is compared through float.
// The if-constexpr/else split is load-bearing: putting `a > b` outside the
// else would compile the ambiguous overload for at::Half as well.
template <typename T>
__device__ inline bool cmp_gt(const T& a, const T& b) {
  if constexpr (std::is_same<T, at::Half>::value) {
    // at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
    return static_cast<float>(a) > static_cast<float>(b);
  } else {
    // For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
    return a > b;
  }
}
|
||||
|
||||
// Equality counterpart of cmp_gt: at::Half is compared through float to
// avoid the same ambiguous-operator overload set; every other type uses its
// own operator==.
template <typename T>
__device__ inline bool cmp_eq(const T& a, const T& b) {
  if constexpr (std::is_same<T, at::Half>::value) {
    return static_cast<float>(a) == static_cast<float>(b);
  } else {
    return a == b;
  }
}
|
||||
|
||||
// Fixed constants common to both dynamic and static template versions:
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int WARPS_PER_CTA = 6;
|
||||
static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
|
||||
|
||||
// Create an alias for Array using AlignedArray
|
||||
template <typename T, int N>
|
||||
using Array = AlignedArray<T, N>;
|
||||
// QQ: NOTE expression must have a constant value, this has to be > params.VPT
|
||||
template <typename T>
|
||||
using AccessType = AlignedArray<T, MAX_VPT>;
|
||||
|
||||
// Fused MoE gating, one router row handled by params.THREADS_PER_ROW threads:
//   1. sigmoid over the row's router logits,
//   2. bias-added scores per expert group; the lowest-scoring
//      (THREADS_PER_ROW - topk_group) groups are excluded,
//   3. top-k expert selection over the surviving groups,
//   4. optional fused shared-expert slots appended after the routed experts,
//   5. normalization of the k weights (and optional routed_scaling_factor).
// `input`/`bias` are reinterpreted as T; outputs are float weights and int32
// expert indices, both laid out [num_rows, topk] row-major.
// From the access pattern: the input row stride is params.NUM_EXPERTS, each
// thread owns params.VPT consecutive experts (params.VPT <= MAX_VPT), and
// THREADS_PER_ROW == number of expert groups. Shuffle reductions use width
// THREADS_PER_ROW, so one row's threads are assumed to share a warp —
// TODO confirm for THREADS_PER_ROW close to warp width.
template <typename T, typename Params>
__device__ void moe_fused_gate_impl(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t topk_group,
    int64_t topk,
    int64_t num_fused_shared_experts,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output,
    Params params) {
  int tidx = threadIdx.x;
  // Row index this thread works on: CTA tile + warp tile + group within warp.
  int64_t thread_row =
      blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
  if (thread_row >= num_rows) {
    return;
  }

  // Calculate topk_excluding_share_expert_fusion from topk
  int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;

  // Cast pointers to type T:
  auto* input_ptr = reinterpret_cast<T*>(input);
  auto* bias_ptr = reinterpret_cast<T*>(bias);
  auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;

  int thread_group_idx = tidx % params.THREADS_PER_ROW;
  int first_elt_read_by_thread = thread_group_idx * params.VPT;

  // Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
  // AccessType.
  T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
  Array<T, MAX_VPT> row_chunk;
  AccessType<T> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(thread_read_ptr);

  T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
  Array<T, MAX_VPT> bias_chunk;
  AccessType<T> const* vec_bias_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(bias_thread_read_ptr);

// QQ NOTE: doing the follow will be slower than loop assign and more importantly
// have misaligned address issue when params.VPT < 8 and mismatch with MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    row_chunk[ii] = vec_thread_read_ptr[0][ii];
    bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
  }

  __syncthreads();

////////////////////// Sigmoid //////////////////////
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
  }
  __syncthreads();

////////////////////// Add Bias //////////////////////
// bias_chunk now holds the selection scores; row_chunk keeps the raw
// sigmoid weights that are eventually written to output.
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
  }

////////////////////// Exclude Groups //////////////////////
// Repeatedly find the group with the LOWEST (top1+top2) score and knock it
// out by setting its scores to +FLT_MAX (a marker checked in the Topk loop).
#pragma unroll
  for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
       ++k_idx) {  // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
    int expert = first_elt_read_by_thread;
    // local argmax: top-2 scores within this thread's group
    T max_val = static_cast<T>(-FLT_MAX);
    T max_val_second = static_cast<T>(-FLT_MAX);
#pragma unroll
    for (int ii = 0; ii < params.VPT; ++ii) {
      T val = bias_chunk[ii];

      if (cmp_gt(val, max_val)) {
        max_val_second = max_val;
        max_val = val;
      } else if (cmp_gt(val, max_val_second)) {
        max_val_second = val;
      }
    }

    // QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert group and sum them as the group weight
    // to select expert groups
    T max_sum = max_val + max_val_second;

    // argmin reduce across the row's groups (keeps the smallest group score)
#pragma unroll
    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      T other_max_sum =
          static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
      int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);

      // higher indices win ties
      if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
        max_sum = other_max_sum;
        expert = other_expert;
      }
    }

    // clear the max value in the thread owning the losing group
    if (k_idx < params.THREADS_PER_ROW - topk_group) {
      int const thread_to_clear_in_group = expert / params.VPT;

      if (thread_group_idx == thread_to_clear_in_group) {
#pragma unroll
        for (int ii = 0; ii < params.VPT; ++ii) {
          // FLT_MAX marks "group excluded" — detected in the Topk loop below.
          bias_chunk[ii] = static_cast<T>(FLT_MAX);
        }
      }
    }
  }

  __syncthreads();

  ////////////////////// Topk //////////////////////
  float output_sum = 0.0f;
  for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
    // local argmax over this thread's group (skipped if the group was excluded)
    T max_val = bias_chunk[0];
    int expert = first_elt_read_by_thread;

    if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
#pragma unroll
      for (int ii = 1; ii < params.VPT; ++ii) {
        T val = bias_chunk[ii];
        if (cmp_gt(val, max_val)) {
          max_val = val;
          expert = first_elt_read_by_thread + ii;
        }
      }
    } else {
      // Excluded group: contribute -FLT_MAX so it never wins the argmax.
      max_val = static_cast<T>(-FLT_MAX);
    }

// argmax reduce across the row's groups
#pragma unroll
    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      T other_max =
          static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
      int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);

      // lower indices to win ties
      if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    int thread_to_clear_in_group = expert / params.VPT;
    int64_t idx = topk * thread_row + k_idx;

    if (thread_group_idx == thread_to_clear_in_group) {
      int expert_to_clear_in_thread = expert % params.VPT;

      // clear the max value in the thread so the next k_idx picks a new expert
      bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);

      // store output: weight is the raw sigmoid value, not the biased score
      output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
      indices_ptr[idx] = static_cast<int32_t>(expert);
    }

    // accumulate sum for all elements
    // NOTE(review): lane 0 reads output_ptr[idx] written by another lane in
    // this same iteration before any barrier (__syncthreads comes after) —
    // this relies on warp-synchronous visibility; verify on targets with
    // independent thread scheduling.
    if (thread_group_idx == 0) {
      output_sum += output_ptr[idx];
    }

    __syncthreads();
  }

  // Append fused shared-expert slots (indices NUM_EXPERTS, NUM_EXPERTS+1, ...).
  if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
    int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
    int64_t expert_offset = 0;
    indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);

    // Set the weight to the sum of all weights divided by routed_scaling_factor
    output_ptr[last_idx] = output_sum / routed_scaling_factor;

    if (num_fused_shared_experts > 1) {
      for (int i = 1; i < num_fused_shared_experts; ++i) {
        ++last_idx;
        ++expert_offset;
        indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
        // Set the weight to the sum of all weights divided by routed_scaling_factor
        output_ptr[last_idx] = output_sum / routed_scaling_factor;
      }
    }
  }
  __syncthreads();

  ////////////////////// Rescale Output //////////////////////
  // Normalize the row's k weights by the routed-expert sum; optionally apply
  // the routed scaling factor to the final weights.
  if (thread_group_idx == 0) {
#pragma unroll
    for (int ii = 0; ii < topk; ++ii) {
      int64_t const idx = topk * thread_row + ii;
      output_ptr[idx] = output_ptr[idx] / output_sum;
      if (apply_routed_scaling_factor_on_output) {
        output_ptr[idx] *= routed_scaling_factor;
      }
    }
  }
}
|
||||
|
||||
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
// Compile-time launch-geometry bundle consumed by moe_fused_gate_impl:
//   VPT             - expert values processed per thread
//                     (num_experts / num_expert_group; see launch macro)
//   NUM_EXPERTS     - experts per row (also the input row stride)
//   THREADS_PER_ROW - threads cooperating on one row (== expert groups)
//   ROWS_PER_WARP / ROWS_PER_CTA / WARPS_PER_CTA - row tiling across the CTA
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
struct KernelParams {
  static constexpr int VPT = VPT_;
  static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
  static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
  static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
  static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
  static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
};
|
||||
|
||||
// Templated entry point for the fused MoE gate: materializes the
// compile-time KernelParams and forwards every argument unchanged to the
// shared implementation moe_fused_gate_impl<T>.
// Launch layout (set by the host launcher): blockDim = (WARP_SIZE,
// WARPS_PER_CTA), one token row per THREADS_PER_ROW-thread sub-group.
// input/bias are type-erased void* because T is resolved here.
template <
    typename T,
    int VPT,
    int NUM_EXPERTS,
    int THREADS_PER_ROW,
    int ROWS_PER_WARP,
    int ROWS_PER_CTA,
    int WARPS_PER_CTA>
__global__ void moe_fused_gate_kernel(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t topk_group,
    int64_t topk,
    int64_t num_fused_shared_experts,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output) {
  // Zero-size tag object carrying the compile-time geometry into the impl.
  KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
  moe_fused_gate_impl<T>(
      input,
      bias,
      output_ptr,
      indices_ptr,
      num_rows,
      topk_group,
      topk,
      num_fused_shared_experts,
      routed_scaling_factor,
      apply_routed_scaling_factor_on_output,
      params);
}
|
||||
|
||||
// Macro to compute compile-time constants and launch the kernel.
// NOTE: relies on names in the caller's scope: `input`, `bias`, `output`,
// `indices`, `num_rows`, `topk_group`, `topk`, `num_fused_shared_experts`,
// `routed_scaling_factor`, `apply_routed_scaling_factor_on_output`,
// `num_blocks`, `block_dim`, `stream`, and a mutable `bool dispatched`,
// which it sets to true after issuing the launch.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP)                                                 \
  do {                                                                                                   \
    constexpr int VPT = (EXPERTS) / (EXPERT_GROUP);                                                      \
    /* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */                                       \
    constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1;      \
    constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;                                          \
    moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
        <<<num_blocks, block_dim, 0, stream>>>(                                                          \
            input.data_ptr(),                                                                            \
            bias.data_ptr(),                                                                             \
            output.data_ptr<float>(),                                                                    \
            indices.data_ptr<int32_t>(),                                                                 \
            num_rows,                                                                                    \
            topk_group,                                                                                  \
            topk,                                                                                        \
            num_fused_shared_experts,                                                                    \
            routed_scaling_factor,                                                                       \
            apply_routed_scaling_factor_on_output);                                                      \
    dispatched = true;                                                                                   \
  } while (0)
|
||||
|
||||
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
// Runtime counterpart of KernelParams: same field meanings, but plain ints
// filled in inside moe_fused_gate_kernel_dynamic, for expert/group counts
// that have no templated specialization.
struct KernelParamsDynamic {
  int VPT;              // values processed per thread = NUM_EXPERTS / num_expert_group
  int NUM_EXPERTS;      // total routed experts
  int THREADS_PER_ROW;  // threads cooperating on one token row
  int ROWS_PER_WARP;    // token rows packed into one warp
  int ROWS_PER_CTA;     // token rows handled by one block
  int WARPS_PER_CTA;    // warps per block
};
|
||||
|
||||
// Fallback entry point used when (num_experts, num_expert_group) has no
// templated specialization: computes the launch geometry at runtime and
// forwards to the shared implementation moe_fused_gate_impl<T>.
template <typename T>
__global__ void moe_fused_gate_kernel_dynamic(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t num_experts,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t topk,
    int64_t num_fused_shared_experts,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output) {
  KernelParamsDynamic params;
  params.NUM_EXPERTS = num_experts;             // e.g, for deepseek v3, this is 256
  params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
  params.THREADS_PER_ROW = num_expert_group;    // fixed as num_expert_group, e.g., for deepseek v3, this is 8
  params.WARPS_PER_CTA = WARPS_PER_CTA;         // fixed as 6
  // WARP_SIZE is fixed as 32
  // NOTE(review): WARP_SIZE comes from a macro; on DCU/HIP hardware the wave
  // size may be 64 — confirm the macro's value for the target platform.
  params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
  params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;

  moe_fused_gate_impl<T>(
      input,
      bias,
      output_ptr,
      indices_ptr,
      num_rows,
      topk_group,
      topk,
      num_fused_shared_experts,
      routed_scaling_factor,
      apply_routed_scaling_factor_on_output,
      params);
}
|
||||
|
||||
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
// Fused MoE gating launcher.
//
// input:  [num_rows, num_experts] gating logits (bf16 / fp16 / fp32, CUDA).
// bias:   per-expert bias tensor consumed by the kernel alongside input.
// Returns {output, indices}: output is [num_rows, topk] float32 routing
// weights, indices is [num_rows, topk] int32 expert ids.
//
// Dispatches to a fully-templated kernel for the known (num_experts,
// num_expert_group) combinations and falls back to the dynamic kernel
// otherwise.
//
// BUGFIX(review): the original dispatch switch omitted braces around the
// per-group branches, so the `else if (num_expert_group == 16)` (case 256)
// and `else if (num_expert_group == 8)` (case 128) clauses bound to the
// inner dtype if/else chain (dangling else) and were unreachable; those
// configurations always fell through to the slower dynamic kernel. Braces
// added so each group branch is tested at the intended level.
std::vector<at::Tensor> moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t topk,
    int64_t num_fused_shared_experts,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output) {
  int64_t num_rows = input.size(0);
  int32_t num_experts = input.size(1);
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto output = torch::empty({num_rows, topk}, options);
  auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));

  // Compute grid dimensions based on runtime value for num_expert_group.
  int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
  int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;  // ceil-div
  int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);

  // Check 1: Ensure that num_experts is a power of 2.
  TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);

  // Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
  TORCH_CHECK(
      num_experts % num_expert_group == 0,
      "num_experts must be divisible by num_expert_group, but got ",
      num_experts,
      " / ",
      num_expert_group);

  int computed_vpt = num_experts / num_expert_group;
  // Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
  // threads we can process.
  TORCH_CHECK(
      computed_vpt <= MAX_VPT,
      "Per group experts: num_experts / num_expert_group = (",
      computed_vpt,
      ") exceeds the maximum supported (",
      MAX_VPT,
      ")");

  // Dispatch to templated kernel for known compile-time configurations.
  // We currently only support for:
  //   Case 1: 256 experts, with 8 or 16 groups.
  //   Case 2: 128 experts, with 4 or 8 groups.
  //   Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
  bool dispatched = false;
  switch (num_experts) {
    case 256:
      if (num_expert_group == 8) {
        // This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
        }
      } else if (num_expert_group == 16) {
        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
        }
      }
      break;
    case 128:
      if (num_expert_group == 4) {
        // VPT = 128/4 = 32, ROWS_PER_WARP = 32/4 = 8, ROWS_PER_CTA = 6 * 8 = 48.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
        }
      } else if (num_expert_group == 8) {
        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
        }
      }
      break;
    default:
      break;
  }
  if (!dispatched) {
    // Fallback to the dynamic kernel if none of the supported combinations match.
    // currently only support num_experts / num_expert_group <= 32 for dynamic kernels
    if (input.scalar_type() == at::kBFloat16) {
      moe_fused_gate_kernel_dynamic<bfloat16_t><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor,
          apply_routed_scaling_factor_on_output);
    } else if (input.scalar_type() == at::kHalf) {
      moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor,
          apply_routed_scaling_factor_on_output);
    } else if (input.scalar_type() == at::kFloat) {
      moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor,
          apply_routed_scaling_factor_on_output);
    } else {
      TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
    }
  }
  return {output, indices};
}
|
||||
591
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
Normal file
591
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
Normal file
@@ -0,0 +1,591 @@
|
||||
// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu
|
||||
// which is originally adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cuda/functional>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_type.hpp>
|
||||
#endif
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Parenthesized min/max helper macros. NOTE: arguments are evaluated twice,
// so do not pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Define reduction operators based on CUDA version
// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
|
||||
|
||||
/// Aligned array type
/// Used purely as an over-aligned unit for vectorized global-memory loads:
/// callers reinterpret_cast T pointers to AlignedArray pointers and copy
/// whole elements. The `data` member is private and never accessed directly
/// in this file — the alignment and size are the only properties relied on.
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
  T data[N];
};
|
||||
|
||||
// ========================== Util functions to convert types ==========================
// Widen a scalar of any supported element type to float on the device.
// Uses the dedicated half/bfloat16 conversion intrinsics where they exist
// and falls back to static_cast for everything else; float passes through
// unchanged.
template <typename T>
__device__ float convert_to_float(T x) {
  if constexpr (std::is_same_v<T, float>) {
    // Already the target type; no conversion needed.
    return x;
  } else if constexpr (std::is_same_v<T, __half>) {
    return __half2float(x);
  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
    return __bfloat162float(x);
  } else {
    // Generic fallback for any other arithmetic-convertible type.
    return static_cast<float>(x);
  }
}
|
||||
|
||||
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
//
// Row-wise numerically stable softmax over the gating logits.
// Launch layout: one block per token row (blockIdx.x == row index), TPB
// threads cooperating across the row's num_cols experts; three passes
// (max, sum-of-exp, normalize), each followed by a block reduce/barrier.
// Rows flagged in `finished` are skipped entirely (output left untouched).
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  // Broadcast slots: thread 0 publishes the row max and 1/sum here.
  __shared__ float normalizing_factor;
  __shared__ float float_max;

  const int thread_row_offset = blockIdx.x * num_cols;

  float threadData(-FLT_MAX);

  // Don't touch finished rows.
  if ((finished != nullptr) && finished[blockIdx.x]) {
    return;
  }

  // Pass 1: per-thread max over a strided slice of the row.
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData = max(convert_to_float<T>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());

  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  threadData = 0;

  // Pass 2: sum of exp(x - max) for the stable softmax denominator.
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData += exp((convert_to_float<T>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Sum(threadData);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  // Pass 3: write the normalized probabilities.
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = val;
  }
}
|
||||
|
||||
// Iterative top-k over pre-softmaxed probabilities: one block per token row;
// for each of the k slots the block does an argmax reduce, excluding experts
// already selected in earlier iterations. Thread 0 writes the winning
// (weight, index) pair and, if `renormalize` is set, rescales the k weights
// so they sum to 1 at the end. Experts outside [start_expert, end_expert)
// (expert parallelism) get index num_experts instead of a local id.
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(
    const float* inputs_after_softmax,
    const bool* finished,
    float* output,
    int* indices,
    const int num_experts,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize) {
  using cub_kvp = cub::KeyValuePair<int, float>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x;

  const bool row_is_active = finished ? !finished[block_row] : true;
  const int thread_read_offset = blockIdx.x * num_experts;
  // Accumulated only by thread 0, which is also the only reader below.
  float row_sum_for_renormalize = 0;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = -1.f;  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      // Mask out experts that already won an earlier k slot by replacing
      // their candidate with the current (losing) thread value.
      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      // Ignore experts the node isn't responsible for with expert parallelism
      const int expert = result_kvp.key;
      const bool node_uses_expert = expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      const int idx = k * block_row + k_idx;
      output[idx] = result_kvp.value;
      indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
      assert(indices[idx] >= 0);
      row_sum_for_renormalize += result_kvp.value;
    }
    // All threads must see the just-written indices[] before the next k
    // iteration reads them in the masking loop above.
    __syncthreads();
  }

  // Fused renormalization: divide each of the k weights by their sum.
  if (renormalize && threadIdx.x == 0) {
    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int idx = k * block_row + k_idx;
      output[idx] = output[idx] * row_sum_for_renormalize_inv;
    }
  }
}
|
||||
|
||||
// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit when the number of experts in the MoE layers
  are a small power of 2. This allows us to cleanly share the rows among the threads in
  a single warp and eliminate communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/

// Launch layout: blockDim = (WARP_SIZE, WARPS_PER_CTA). Each warp handles
// ROWS_PER_WARP token rows; within a warp, groups of THREADS_PER_ROW lanes
// cooperate on one row, each lane holding VPT values in registers. All
// cross-lane communication uses warp shuffles — no shared memory.
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
    const T* input,
    const bool* finished,
    float* output,
    const int num_rows,
    int* indices,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize) {
  // We begin by enforcing compile time assertions and setting up compile time constants.
  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

  // Number of bytes each thread pulls in per load
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

  // Restrictions based on previous section.
  static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
  static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
  static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

  // We have NUM_EXPERTS elements per row. We specialize for small #experts
  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

  // Restrictions for previous section.
  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

  // ===================== From this point, we finally start computing run-time variables. ========================

  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
  // This, each block processes a chunk of rows. We start by computing the start row for each block.
  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

  // Now, using the base row per thread block, we compute the base row per warp.
  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

  // The threads in a warp are split into sub-groups that will work on a row.
  // We compute row offset for each thread sub-group
  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  const int thread_row = warp_base_row + thread_row_in_warp;

  // Threads with indices out of bounds should early exit here.
  if (thread_row >= num_rows) {
    return;
  }
  const bool row_is_active = finished ? !finished[thread_row] : true;

  // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
  // row it will read.
  const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

  // Now, we compute the group each thread belong to in order to determine the first column to start loads.
  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

  // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
  // this can support all powers of 2 up to 16.
  // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
  // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
  using AccessType = AlignedArray<T, ELTS_PER_LDG>;

  // Finally, we pull in the data from global mem
  T row_chunk_temp[VPT];
  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
  const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  }

  // Widen the loaded values to float once so all math below is in fp32.
  float row_chunk[VPT];
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
  }

  // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
  // convert to float afterwards for the exp + sum reduction.
  float thread_max = row_chunk[0];
#pragma unroll
  for (int ii = 1; ii < VPT; ++ii) {
    thread_max = max(thread_max, row_chunk[ii]);
  }

  // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
  }

  // From this point, thread max in all the threads have the max within the row.
  // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
  float row_sum = 0;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
    row_sum += row_chunk[ii];
  }

  // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW);
  }

  // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
  // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
  // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
  // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
  // argmax after computing the softmax.
  const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  }

  // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
  // with the max index.
  int start_col = first_elt_read_by_thread;
  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

  // Accumulated (and later read) only by the group's lead thread.
  float row_sum_for_renormalize = 0;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    // First, each thread does the local argmax
    float max_val = row_chunk[0];
    int expert = start_col;
#pragma unroll
    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
        float val = row_chunk[ldg * ELTS_PER_LDG + ii];

        // No check on the experts here since columns with the smallest index are processed first and only
        // updated if > (not >=)
        if (val > max_val) {
          max_val = val;
          expert = col + ii;
        }
      }
    }

    // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
    // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
    // then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW);
      int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW);

      // We want lower indices to "win" in every thread so we break ties this way
      if (other_max > max_val || (other_max == max_val && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Write the max for this k iteration to global memory.
    if (thread_group_idx == 0) {
      // Add a guard to ignore experts not included by this node
      const bool node_uses_expert = expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      // The lead thread from each sub-group will write out the final results to global memory. (This will be a
      // single) thread per row of the input/output matrices.
      const int idx = k * thread_row + k_idx;
      output[idx] = max_val;
      indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
      row_sum_for_renormalize += max_val;
    }

    // Finally, we clear the value in the thread with the current max if there is another iteration to run.
    if (k_idx + 1 < k) {
      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
      const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

      // Only the thread in the group which produced the max will reset the "winning" value to -inf.
      if (thread_group_idx == thread_to_clear_in_group) {
        const int offset_for_expert = expert % ELTS_PER_LDG;
        // Safe to set to any negative value since row_chunk values must be between 0 and 1.
        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
      }
    }
  }

  // Fuse renormalization of topk_weights into this kernel
  if (renormalize && thread_group_idx == 0) {
    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
#pragma unroll
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int idx = k * thread_row + k_idx;
      output[idx] = output[idx] * row_sum_for_renormalize_inv;
    }
  }
}
|
||||
|
||||
namespace detail {
// Constructs some constants needed to partition the work across threads at compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
  // Elements of T delivered by one vectorized load.
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  // Either a whole row fits in one warp-load sweep, or rows tile evenly.
  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
  // Vector loads issued per thread (at least one).
  static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  // Values held per thread in registers.
  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  // Lanes cooperating on one row.
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  // Rows packed into one warp.
  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
}  // namespace detail
|
||||
|
||||
// Host-side helper: derives the vector-load width and grid shape for a fixed
// expert count EXPERTS, then launches topkGatingSoftmax on `stream`.
// Grid: ceil(num_rows / rows-per-block) blocks of (WARP_SIZE, WARPS_PER_TB).
template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(
    const T* input,
    const bool* finished,
    float* output,
    int* indices,
    const int num_rows,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize,
    cudaStream_t stream) {
  // Widest supported vector load is 16 bytes (float4-sized).
  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

  static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
  using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
  static constexpr int VPT = Constants::VPT;
  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;  // ceil-div
  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;  // ceil-div

  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
      input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
}
|
||||
|
||||
// Instantiates the fused top-k softmax launcher for a compile-time expert
// count; relies on `gating_output`, `topk_weights`, `topk_indices`,
// `num_tokens`, `topk`, `num_experts`, `renormalize`, and `stream` being in
// scope at the call site.
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB)                 \
  topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>(     \
      gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
|
||||
|
||||
// Dispatches the fused top-k softmax by expert count. Power-of-2 counts up
// to 256 get the single fused kernel; anything else falls back to a
// two-kernel path (moeSoftmax into softmax_workspace, then moeTopK), which
// is why softmax_workspace is only required in the default case.
template <typename T>
void topkGatingSoftmaxKernelLauncher(
    const T* gating_output,
    float* topk_weights,
    int* topk_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    const bool renormalize,
    cudaStream_t stream) {
  static constexpr int WARPS_PER_TB = 4;
  switch (num_experts) {
    case 1:
      LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB);
      break;
    case 2:
      LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB);
      break;
    case 4:
      LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB);
      break;
    case 8:
      LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB);
      break;
    case 16:
      LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB);
      break;
    case 32:
      LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB);
      break;
    case 64:
      LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB);
      break;
    case 128:
      LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB);
      break;
    case 256:
      LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB);
      break;
    default: {
      // Fallback: unfused softmax + top-k, one block per token each.
      TORCH_CHECK(
          softmax_workspace != nullptr,
          "softmax_workspace must be provided for num_experts that are not a power of 2.");
      static constexpr int TPB = 256;
      moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
      moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
          softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
    }
  }
}
|
||||
|
||||
// Compute per-token gating: softmax over experts followed by top-k selection.
//
// topk_weights:  [num_tokens, topk] output, float32.
// topk_indices:  [num_tokens, topk] output, int32.
// gating_output: [num_tokens, num_experts] input logits
//                (float32, float16 or bfloat16).
// renormalize:   renormalize the selected top-k weights to sum to 1.
//
// A float workspace of num_tokens * num_experts elements is allocated only for
// the fallback path (non-power-of-two expert count or num_experts > 256).
void topk_softmax(
    torch::Tensor& topk_weights,   // [num_tokens, topk]
    torch::Tensor& topk_indices,   // [num_tokens, topk]
    torch::Tensor& gating_output,  // [num_tokens, num_experts]
    const bool renormalize) {
  // Check data type
  TORCH_CHECK(
      gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
          gating_output.scalar_type() == at::ScalarType::BFloat16,
      "gating_output must be float32, float16, or bfloat16");

  // Check dimensions
  TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]");
  TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]");
  TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]");

  // Check shapes
  TORCH_CHECK(
      gating_output.size(0) == topk_weights.size(0),
      "First dimension of topk_weights must match num_tokens in gating_output");
  TORCH_CHECK(
      gating_output.size(0) == topk_indices.size(0),
      "First dimension of topk_indices must match num_tokens in gating_output");
  TORCH_CHECK(
      topk_weights.size(-1) == topk_indices.size(-1),
      "Second dimension of topk_indices must match topk in topk_weights");
  TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts");

  const int num_experts = static_cast<int>(gating_output.size(-1));
  const int num_tokens = static_cast<int>(gating_output.size(0));
  const int topk = static_cast<int>(topk_weights.size(-1));

  const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
  const bool needs_workspace = !is_pow_2 || num_experts > 256;
  // Widen before multiplying: num_tokens * num_experts can overflow int32 for
  // large batches; the product is used as a tensor size.
  const int64_t workspace_size =
      needs_workspace ? static_cast<int64_t>(num_tokens) * static_cast<int64_t>(num_experts) : 0;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Empty (size 0) when the fused power-of-two path is taken.
  torch::Tensor softmax_workspace =
      torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));

  const at::ScalarType dtype = gating_output.scalar_type();
  if (dtype == at::ScalarType::Float) {
    topkGatingSoftmaxKernelLauncher<float>(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        stream);
  } else if (dtype == at::ScalarType::Half) {
    topkGatingSoftmaxKernelLauncher<__half>(
        reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        stream);
  } else if (dtype == at::ScalarType::BFloat16) {
    // __hip_bfloat16: this build targets the DCU/HIP toolchain.
    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        stream);
  } else {
    TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
  }
}
|
||||
471
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu
Normal file
471
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu
Normal file
@@ -0,0 +1,471 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Per-expert pointer/layout setup for the grouped FP4 GEMM.
//
// Launched as <<<1, num_experts>>>: each thread fills one expert's slot in the
// pointer arrays consumed by CUTLASS grouped GEMM (A/B/C base pointers, scale
// factor pointers, per-expert alpha, and SFA/SFB layouts).
//
// Preconditions (asserted below): problem sizes match (N, K) and K is even;
// scale-factor pointers end up 128-byte aligned (TMA requirement).
template <
    typename ElementAB,
    typename ElementC,
    typename ElementSF,
    typename ElementAccumulator,
    typename LayoutSFA,
    typename LayoutSFB,
    typename ScaleConfig>
__global__ void __get_group_gemm_starts(
    ElementAB** a_offsets,
    ElementAB** b_offsets,
    ElementC** out_offsets,
    ElementSF** a_scales_offsets,
    ElementSF** b_scales_offsets,
    ElementAccumulator** alpha_offsets,
    LayoutSFA* layout_sfa_base_as_int,
    LayoutSFB* layout_sfb_base_as_int,
    ElementAB* a_base_as_int,
    ElementAB* b_base_as_int,
    ElementC* out_base_as_int,
    ElementSF* a_scales_base_as_int,
    ElementSF* b_scales_base_as_int,
    ElementAccumulator* alphas_base_as_int,
    const int32_t* expert_offsets,
    const int32_t* sf_offsets,
    const int32_t* problem_sizes_as_shapes,
    const int K,
    const int N) {
  int64_t expert_id = threadIdx.x;
  // NOTE(review): threadIdx.x < blockDim.x always holds, so this guard can
  // never fire; it presumably intended to bound expert_id by num_experts.
  // Harmless with the current <<<1, num_experts>>> launch.
  if (expert_id >= gridDim.x * blockDim.x) {
    return;
  }
  // Originally int32_t but upcasting to int64_t to avoid overflow
  // during offset calculations
  int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
  int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
  // size for block in block scale.
  int64_t group_size = 16;
  int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
  int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
  int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
  assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes");

  // Two FP4 values are packed per byte, hence the k/2 element strides.
  int64_t half_k = static_cast<int64_t>(k / 2);
  int64_t group_k = static_cast<int64_t>(k / group_size);
  // Shape of A as uint8/byte = [M, K // 2]
  // Shape of B as uint8/byte = [E, N, K // 2]
  a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;

  b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
  // Shape of C = [M, N]
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  // Shape of a_scale = [sum(sf_sizes), K // group_size]
  a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;

  assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");

  // Shape of B scale = [E, N, K // group_size]
  b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
  assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
  // Shape of alpha = [E]
  alpha_offsets[expert_id] = alphas_base_as_int + expert_id;

  LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
  LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;

  // Per-expert scale-factor layouts for this expert's (m, n, k) problem.
  *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
  *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(
      cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
|
||||
|
||||
// Output-dtype dispatch arm for __get_group_gemm_starts. Deliberately begins
// with `else if` so consecutive expansions chain off a preceding `if (false)`
// in run_get_group_gemm_starts; one thread per expert is launched.
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(                                                           \
    ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig)                \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                                     \
    __get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
        <<<1, num_experts, 0, stream>>>(                                                               \
            static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()),                                       \
            static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()),                                       \
            static_cast<C_TYPE**>(out_starts.data_ptr()),                                              \
            static_cast<SF_TYPE**>(a_scales_starts.data_ptr()),                                        \
            static_cast<SF_TYPE**>(b_scales_starts.data_ptr()),                                        \
            static_cast<float**>(alpha_starts.data_ptr()),                                             \
            reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),                                       \
            reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()),                                       \
            static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()),                                       \
            static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()),                                       \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                              \
            static_cast<SF_TYPE*>(a_scales.data_ptr()),                                                \
            static_cast<SF_TYPE*>(b_scales.data_ptr()),                                                \
            static_cast<float*>(alphas.data_ptr()),                                                    \
            static_cast<int32_t*>(expert_offsets.data_ptr()),                                          \
            static_cast<int32_t*>(sf_offsets.data_ptr()),                                              \
            static_cast<int32_t*>(problem_sizes.data_ptr()),                                           \
            K,                                                                                         \
            N);                                                                                        \
  }
|
||||
|
||||
// Host wrapper that launches __get_group_gemm_starts, dispatching on the
// output dtype (bfloat16 or float16). The *_starts tensors receive per-expert
// device pointers; the remaining tensors supply the base addresses.
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
    const torch::Tensor& a_starts,
    const torch::Tensor& b_starts,
    const torch::Tensor& out_starts,
    const torch::Tensor& a_scales_starts,
    const torch::Tensor& b_scales_starts,
    const torch::Tensor& alpha_starts,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    /*these are used for their base addresses*/
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& alphas,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& sf_offsets,
    torch::Tensor const& problem_sizes,
    int M,
    int N,
    int K) {
  int num_experts = (int)expert_offsets.size(0);
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape");
  // FP4 packs two values per byte, so the packed K dimension is K/2.
  TORCH_CHECK(
      K / 2 == b_tensors.size(2),
      "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
      " dimension must match");
  // Intentional no-op `if`: the dispatch macros below each expand to an
  // `else if (...) { ... }`, so this anchors the chain.
  if (false) {
  }
  //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
  // ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
      cutlass::float_e2m1_t,
      cutlass::float_ue4m3_t,
      torch::kBFloat16,
      cutlass::bfloat16_t,
      LayoutSFA,
      LayoutSFB,
      ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
      cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
|
||||
|
||||
// Run the NVFP4 block-scaled grouped GEMM (one GEMM per expert) via CUTLASS.
//
// a:          packed FP4 activations, byte tensor [sum_m, K/2].
// b:          packed FP4 weights, byte tensor [E, N, K/2] (column-major per expert).
// a_blockscale / b_blockscales: ue4m3 scale factors (group size 16 along K).
// alphas:     per-expert output scales [E].
// ab_strides / c_strides: per-expert stride arrays consumed by CUTLASS.
// problem_sizes: int32 [E, 3] = (m, n, k) per expert.
// expert_offsets / sf_offsets: row offsets into a/output and a_blockscale.
//
// Requires SM100 (Blackwell) — the kernel schedule and ScaleConfig are
// SM100-specific.
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& a_blockscale,
    const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas,
    const torch::Tensor& ab_strides,
    const torch::Tensor& c_strides,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& sf_offsets,
    int M,
    int N,
    int K) {
  // ---- Type configuration for the CUTLASS collective builders ----
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
  using ElementType = cutlass::float_e2m1_t;
  using ElementSFType = cutlass::float_ue4m3_t;
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;

  using ElementC = OutType;
  using ElementD = ElementC;
  using ElementAccumulator = float;
  // Layout definitions
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;

  // Alignment constraints
  static constexpr int AlignmentA = 32;
  static constexpr int AlignmentB = 32;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  // Architecture definitions
  using ArchTag = cutlass::arch::Sm100;
  using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp;           // Epilogue Operator class tag
  using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;  // Mainloop Operator class tag
  using StageCountType = cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based
                                                                     // on the tile size

  using ClusterShape = Shape<_1, _1, _1>;
  struct MMA1SMConfig {
    using MmaTileShape = Shape<_128, _128, _128>;
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100;  // Kernel to launch
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;           // Epilogue to launch
  };

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      EpilogueOperatorClass,
      typename MMA1SMConfig::MmaTileShape,
      ClusterShape,
      Shape<_128, _64>,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutC*,
      AlignmentC,
      ElementD,
      LayoutC*,
      AlignmentD,
      typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      MainloopOperatorClass,
      ElementA,
      LayoutA*,
      AlignmentA,
      ElementB,
      LayoutB*,
      AlignmentB,
      ElementAccumulator,
      typename MMA1SMConfig::MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      typename MMA1SMConfig::KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using Gemm = Gemm1SM;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;

  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
  using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;

  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
  int num_experts = static_cast<int>(expert_offsets.size(0));
  // int64 buffers that hold per-expert device pointers / layout structs.
  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());

  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);

  // Fill the pointer/layout arrays on-device.
  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
      a_ptrs,
      b_ptrs,
      out_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      alpha_ptrs,
      layout_sfa,
      layout_sfb,
      a,
      b,
      output,
      a_blockscale,
      b_blockscales,
      alphas,
      expert_offsets,
      sf_offsets,
      problem_sizes,
      M,
      N,
      K);

  // Create an instance of the GEMM
  Gemm gemm_op;

  // Initialize problem_sizes_as_shapes correctly
  UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());

  // Set the Scheduler info
  cutlass::KernelHardwareInfo hw_info;
  using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<
      typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
  scheduler.raster_order = RasterOrderOptions::AlongM;
  hw_info.device_id = a.get_device();
  // SM count query is expensive; cache it per device for the process lifetime.
  static std::unordered_map<int, int> cached_sm_counts;
  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
    cached_sm_counts[hw_info.device_id] =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
  }
  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);

  // Mainloop Arguments
  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementType**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(ab_strides.data_ptr()),
      static_cast<const ElementType**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(ab_strides.data_ptr()),
      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};

  // Epilogue Arguments: C source is nullptr (no residual add), D is per-expert.
  typename GemmKernel::EpilogueArguments epilogue_args{
      {},  // epilogue.thread
      nullptr,
      static_cast<StrideC*>(c_strides.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides.data_ptr())};
  auto& fusion_args = epilogue_args.thread;
  // Per-expert alpha pointers; dAlpha stride (0,0,1) = one alpha per group.
  fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
  fusion_args.dAlpha = {_0{}, _0{}, 1};

  // Gemm Arguments
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info,
      scheduler};

  size_t workspace_size = Gemm::get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

  // Run the GEMM
  auto status = gemm_op.initialize(args, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
|
||||
|
||||
// Input-validation helpers: each raises a torch error (TORCH_CHECK) naming the
// offending tensor `m` when the condition fails.
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
// Combined check: CUDA-resident, contiguous, and of scalar type `st`.
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)
|
||||
|
||||
// Public entry point for the NVFP4 block-scaled grouped GEMM.
//
// Validates inputs and dispatches run_fp4_blockwise_scaled_group_mm on the
// output dtype (bfloat16 or float16). Packed FP4 tensors are passed as bytes
// (two e2m1 values per byte); scale factors are float8_e4m3fn.
// Only compiled when ENABLE_NVFP4 is set (SM100+, CUDA 12.8+).
void cutlass_fp4_group_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& a_blockscale,
    const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas,
    const torch::Tensor& ab_strides,
    const torch::Tensor& c_strides,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4

  constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
  constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
  // Input validation
  CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
  CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
  CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
  CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");

  TORCH_CHECK(
      a_blockscale.dim() == 2,
      "expected a_blockscale to be of shape [num_experts, rounded_m,"
      " k // group_size], observed rank: ",
      a_blockscale.dim())
  TORCH_CHECK(
      b_blockscales.dim() == 3,
      "expected b_blockscale to be of shape: "
      " [num_experts, n, k // group_size], observed rank: ",
      b_blockscales.dim())
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");
  // Reject unsupported output dtypes up front instead of silently treating
  // them as fp16 in the dispatch below.
  TORCH_CHECK(
      output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf,
      "output must be bfloat16 or float16");

  int M = static_cast<int>(a.size(0));
  int N = static_cast<int>(b.size(1));
  // K counts unpacked FP4 elements: two per stored byte.
  int K = static_cast<int>(2 * b.size(2));

  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output,
        a,
        b,
        a_blockscale,
        b_blockscales,
        alphas,
        ab_strides,
        c_strides,
        problem_sizes,
        expert_offsets,
        sf_offsets,
        M,
        N,
        K);
  } else {
    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
        output,
        a,
        b,
        a_blockscale,
        b_blockscales,
        alphas,
        ab_strides,
        c_strides,
        problem_sizes,
        expert_offsets,
        sf_offsets,
        M,
        N,
        K);
  }
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel, sgl-kernel must "
      "be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
      "12.8 or above.");
#endif
}
|
||||
392
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file
392
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file
@@ -0,0 +1,392 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Stride used by the per-expert counting/sorting kernels below; the launch
// uses min(THREADS_PER_EXPERT, topk_length) threads, so striding by this
// constant still covers every element.
constexpr uint64_t THREADS_PER_EXPERT = 512;

// Count, per expert, how many (token, expert) assignments in topk_ids match
// this block's expert, and write the two grouped-GEMM problem shapes:
//   problem_sizes1[e] = (count, 2n, k)   — gate+up projection
//   problem_sizes2[e] = (count, k, n)    — down projection
// One block per expert (expert_id = blockIdx.x); atomic_buffer[e] must be
// zero-initialized by the caller and accumulates the count.
__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t n,
    const int64_t k) {
  int expert_id = blockIdx.x;

  // Each thread counts its strided slice, then folds into the shared counter.
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
    problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
  }
}
|
||||
|
||||
// Serial exclusive prefix sum over per-expert token counts (run with a single
// thread: <<<1, 1>>>). Produces:
//   expert_offsets[e]  — start row of expert e (expert_offsets[0] = 0, and
//                        expert_offsets[num_experts] = total rows),
//   atomic_buffer[e]   — a copy of each expert's start offset, used as the
//                        write cursor by compute_arg_sorts.
__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t running_total = 0;
  expert_offsets[0] = 0;
  for (int e = 0; e < num_experts; ++e) {
    atomic_buffer[e] = running_total;
    // problem_sizes1 rows are (m, 2n, k); the token count is the first entry.
    running_total += problem_sizes1[e * 3];
    expert_offsets[e + 1] = running_total;
  }
}
|
||||
|
||||
// Variant of compute_expert_offsets that additionally produces offsets into
// the block-scale tensor, where each expert's row count is rounded up to a
// multiple of 128 (the block-scale granularity). Run with <<<1, 1>>>.
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  int32_t tot_rounded_offset = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    // atomic_buffer[i] = this expert's start row; reused by compute_arg_sorts.
    atomic_buffer[i] = tot_offset;
    int num_tokens = problem_sizes1[i * 3];
    // Round up to the next multiple of 128 for the block-scale layout.
    int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
    tot_offset += num_tokens;
    tot_rounded_offset += rounded_num_tokens;
    expert_offsets[i + 1] = tot_offset;
    blockscale_offsets[i + 1] = tot_rounded_offset;
  }
}
|
||||
|
||||
// Build the permutations that group token rows by expert. One block per
// expert; atomic_buffer[e] must hold expert e's start offset (written by a
// compute_expert_*offsets kernel) and serves as the write cursor.
//   input_permutation[dst] = src token row (dst2src, for gathering inputs)
//   output_permutation[i]  = dst row for assignment i (src2dst, for scatter)
// Note: within one expert the relative order of its tokens is nondeterministic
// (atomic cursor), which is fine for MoE since rows are permuted back.
__global__ void compute_arg_sorts(
    const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      // i indexes the flattened [num_tokens, topk] assignment matrix, so
      // i / topk recovers the source token row.
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}
|
||||
|
||||
// Launch the three preparation kernels for grouped-GEMM MoE:
//   1. compute_problem_sizes — per-expert token counts and GEMM shapes;
//   2. expert offsets (with or without 128-row-aligned blockscale offsets,
//      depending on whether blockscale_offsets is provided);
//   3. compute_arg_sorts — input/output permutations grouping rows by expert.
// All launches go on the current stream of topk_ids' device; atomic_buffer is
// zero-initialized here as required by compute_problem_sizes.
void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  // Cap threads at the element count; the kernels stride by
  // THREADS_PER_EXPERT, so coverage is still complete.
  uint32_t num_threads = static_cast<uint32_t>(min(THREADS_PER_EXPERT, topk_ids.numel()));
  uint32_t num_blocks = static_cast<uint32_t>(num_experts);

  compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  if (blockscale_offsets.has_value()) {
    // Single-thread serial scan (num_experts is small).
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  }
  compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      topk_ids.size(1));
}
|
||||
|
||||
// Public entry point: validate topk_ids' dtype and delegate all kernel
// launches to get_moe_prepare_input_caller. See that function for the
// semantics of each output tensor.
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  // The kernels read topk_ids as int32; enforce that up front.
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
}
|
||||
|
||||
// Gather/permute rows: output[row] = input[dst2src_map[row]].
// One block per destination row; each thread copies 128-bit chunks across the
// row, so num_cols * sizeof(T) must be a multiple of 16 bytes and both
// buffers must be 16-byte aligned.
template <typename T>
__global__ void shuffleRowsKernel(
    const T* input,
    const int32_t* dst2src_map,
    T* output,
    int64_t num_src_rows,
    int64_t num_dst_rows,
    int64_t num_cols) {
  int64_t dest_row_idx = blockIdx.x;

  if (dest_row_idx < num_dst_rows) {
    // Read the row mapping only after the bounds check: the original code
    // indexed dst2src_map before it, an out-of-bounds read whenever
    // gridDim.x exceeds num_dst_rows.
    int64_t const source_row_idx = dst2src_map[dest_row_idx];

    // Load 128-bits per thread
    constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

    // Duplicate and permute rows
    auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
    auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);

    auto const start_offset = threadIdx.x;
    auto const stride = blockDim.x;
    auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;

    for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
    }
  }
}
|
||||
|
||||
// Declarations of shuffleRowsKernel for each supported element type.
// NOTE(review): these declare non-template overloads rather than explicit
// template instantiations; the dispatch macro below calls the template
// directly (shuffleRowsKernel<T>), so these declarations appear unused —
// confirm whether they were meant to be `template __global__ void
// shuffleRowsKernel<T>(...)` instantiations.
#define DECLARE_SHUFFLE_ROWS(T)      \
  __global__ void shuffleRowsKernel( \
      const T* input,                \
      const int32_t* dst2src_map,    \
      T* output,                     \
      int64_t num_src_rows,          \
      int64_t num_dest_rows,         \
      int64_t num_cols);

DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);
|
||||
|
||||
// Launches shuffleRowsKernel<T> on the caller's stream. This macro captures
// identifiers from the enclosing scope by name: `blocks`, `threads`, `stream`,
// `input`, `dst2src_map`, `output`, `num_src_rows`, `num_dst_rows`, and
// `num_cols` must all be defined at the expansion site (see
// shuffle_rows_caller below).
#define SHUFFLE_ROWS(T) \
  shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>( \
      reinterpret_cast<const T*>(input), \
      static_cast<const int32_t*>(dst2src_map.data_ptr()), \
      reinterpret_cast<T*>(output), \
      num_src_rows, \
      num_dst_rows, \
      num_cols)

// One `case` arm of a torch scalar-type switch: maps a torch dtype tag T to
// the CUDA element type CUDA_T and launches the corresponding kernel.
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
  case T: \
    SHUFFLE_ROWS(CUDA_T); \
    break;
|
||||
|
||||
// Host launcher for shuffleRowsKernel: gathers rows of `input_tensor` into
// `output_tensor` so that output row i is input row dst2src_map[i].
//
// Launch layout: one thread block per output row, 256 threads per block, on
// the current CUDA stream. Supports f32/f16/bf16/fp8e4m3/uint8 elements.
// No result is returned; errors surface via TORCH_CHECK or later CUDA calls.
void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
  uint32_t threads = 256;
  int64_t num_dst_rows = output_tensor.size(0);
  int64_t num_src_rows = input_tensor.size(0);
  int64_t num_cols = input_tensor.size(1);

  // Robustness fix: a zero-sized grid is an invalid CUDA launch configuration.
  // With an empty output tensor there is nothing to copy, so return early.
  if (blocks == 0) {
    return;
  }

  const void* input = input_tensor.data_ptr();
  void* output = output_tensor.data_ptr();

  // Dispatch on element dtype; each case expands to a kernel launch that
  // reads the local variables declared above (see SHUFFLE_ROWS).
  switch (input_tensor.scalar_type()) {
    DTYPE_DISPATCH_CASE(torch::kFloat16, half);
    DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
    DTYPE_DISPATCH_CASE(torch::kFloat32, float);
    DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
    DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
  }
  return;
}
|
||||
|
||||
// Public entry point: permute rows of `input_tensor` into `output_tensor`
// according to `dst2src_map` (output row i <- input row dst2src_map[i]).
// Thin pass-through to shuffle_rows_caller, which performs validation,
// dtype dispatch, and the kernel launch.
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
}
|
||||
|
||||
// Fused gather + scale + reduce kernel. For each output row i (one block per
// row, gridDim.x == m), gathers the topk input rows
// input_tensor[permutation[i*topk + j]], scales each by
// factors[i*topk + j] (1.0 when factors == nullptr), and accumulates the sum
// into output_tensor[i]. Accumulation is done in float regardless of scalar_t.
// Threads of a block stride across the row's columns; a scalar tail loop
// handles the last row_stride % vec_size columns.
// NOTE(review): the vectorized path issues 16-byte loads/stores at offset
// row * row_stride + d — this assumes those addresses are 16-byte aligned
// whenever row_stride >= vec_size; confirm with callers' tensor layouts.
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,  // [m * topk, k] (expert-major layout)
    scalar_t* __restrict__ output_tensor,       // [m, k] (token-major layout)
    const int32_t* __restrict__ permutation,    // [m * topk] (c_map: token-major-idx -> expert-major-idx)
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors)  // [m * topk] (topk_weights, token-major layout)
{
  // One block per output (token) row; excess blocks exit immediately.
  int i = blockIdx.x;
  if (i >= m) {
    return;
  }

  // 16 bytes per vector access; accumulate in float for precision.
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using t = float;
  using vec_t = flashinfer::vec_t<t, vec_size>;
  int thread_idx = threadIdx.x;
  int stride = blockDim.x;

  // Vectorized main loop over the full vec_size-sized chunks of the row.
  for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
    int d = d_vec_idx * vec_size;
    vec_t sum_vec;
    sum_vec.fill(0.0f);

    // Reduce over the topk gathered rows for this token.
    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      int src_row = permutation[token_major_idx];

      vec_t val_vec;
      val_vec.cast_load(input_tensor + src_row * row_stride + d);

      // Default weight 1.0 when no factors tensor was supplied.
      t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }

#pragma unroll
      for (int k = 0; k < vec_size; ++k) {
        sum_vec[k] += factor * val_vec[k];
      }
    }
    sum_vec.cast_store(output_tensor + i * row_stride + d);
  }

  // remainder part: scalar loop over the trailing row_stride % vec_size columns.
  int remainder_start = (row_stride / vec_size) * vec_size;
  for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
    t sum_val = 0.0;
    for (int j = 0; j < topk; ++j) {
      int token_major_idx = i * topk + j;
      int src_row = permutation[token_major_idx];
      t val = input_tensor[src_row * row_stride + d];

      t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }
      sum_val += factor * val;
    }
    output_tensor[i * row_stride + d] = sum_val;
  }
}
|
||||
|
||||
// Host launcher for apply_shuffle_mul_sum_kernel. Computes, for each token i:
//   output[i] = sum_j factors[i*topk + j] * input[permutation[i*topk + j]]
// i.e. the fused equivalent of
//   (input[permutation].view(m, topk, k) * factors.view(m, topk, 1)).sum(dim=1)
// Launch layout: one thread block per output row, up to 1024 threads per
// block, on the input tensor's current CUDA stream.
void get_apply_shuffle_mul_sum_caller(
    const torch::Tensor& input_tensor,               // [m * topk, row_stride], bf16/f16
    torch::Tensor& output_tensor,                    // [m, row_stride], bf16/f16
    const torch::Tensor& permutation,                // [m * topk], int32
    const std::optional<torch::Tensor>& factors_opt) // optional [m * topk], bf16/f16
{
  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
  TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");

  int m = output_tensor.size(0);
  // Robustness fix: with an empty output there is nothing to compute, and the
  // original code would have divided by zero below and launched a 0-block grid.
  if (m == 0) {
    return;
  }
  int topk = int(permutation.size(0) / m);
  int row_stride = output_tensor.size(1);

  TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");

  // Bug fix: the original computed `16 / sizeof(scalar_type)`, where
  // scalar_type is a c10::ScalarType enum VALUE — that is the size of the enum
  // object, not of the tensor element. Use the tensor's element size so the
  // block size matches the kernel's vec_size (= 16 / sizeof(scalar_t)).
  uint32_t vec_size = 16 / static_cast<uint32_t>(output_tensor.element_size());
  uint32_t blockDim = std::min<uint32_t>(row_stride / vec_size, 1024U);
  // Robustness fix: a 0-thread block is an invalid launch configuration
  // (possible when row_stride < vec_size); the kernel's tail loop still
  // covers every column with a single thread.
  if (blockDim == 0) {
    blockDim = 1;
  }
  dim3 block(blockDim);

  dim3 grid(m);  // 1-D grid: blockIdx.x = token index i
  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());

  // data_ptr<int32_t>() also validates that permutation is int32.
  const int32_t* perm_ptr = permutation.data_ptr<int32_t>();

  // nullptr factors means "weight every gathered row by 1.0" in the kernel.
  void* factors_ptr = nullptr;
  if (factors_opt.has_value()) {
    TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
    TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
    factors_ptr = factors_opt->data_ptr();
  }

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
    apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<const scalar_t*>(input_tensor.data_ptr()),
        static_cast<scalar_t*>(output_tensor.data_ptr()),
        perm_ptr,
        m,
        topk,
        row_stride,
        static_cast<const scalar_t*>(factors_ptr));
    return true;
  });
}
|
||||
|
||||
/**
 * @brief Permutation-gather, element-wise scale, and sum over the top-k dimension.
 *
 * Equivalent to the PyTorch expression:
 *
 *   (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
 *
 * The steps performed are:
 *   - gather rows of `input` through the `permutation` index tensor,
 *   - multiply the gathered rows element-wise by `factors` (e.g. top-k weights),
 *   - reduce along the top-k dimension and write the result into `output`.
 *
 * @param input       Input tensor of shape (m * topk, k), representing c2.
 * @param output      Output tensor of shape (m, k) receiving the reduced result.
 * @param permutation Index tensor (e.g. c_map) mapping positions in `input` to the shuffled layout.
 * @param factors     Optional scaling factors (e.g. top-k weights), shape (m * topk) or (m, topk).
 */
void apply_shuffle_mul_sum(
    const torch::Tensor& input,
    torch::Tensor& output,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors) {
  // Delegate to the caller helper, which validates shapes and launches the kernel.
  get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}
|
||||
Reference in New Issue
Block a user