[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)
This commit is contained in:
1
sgl-kernel/CMakeLists.txt
Executable file → Normal file
1
sgl-kernel/CMakeLists.txt
Executable file → Normal file
@@ -207,6 +207,7 @@ set(SOURCES
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
"csrc/moe/fp8_blockwise_moe_kernel.cu"
|
||||
"csrc/moe/prepare_moe_input.cu"
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
|
||||
@@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
|
||||
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
|
||||
"expert_offsets) -> ()");
|
||||
"expert_offsets, Tensor workspace) -> ()");
|
||||
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
|
||||
|
||||
m.def(
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
|
||||
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
|
||||
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
@@ -49,23 +51,16 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
// Layout definitions
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignment constraints
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
// Architecture definitions
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
// For fp8 block scale.
|
||||
// using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
|
||||
// ScaleGranularityK, cute::UMMA::Major::K, cute::UMMA::Major::K>; using LayoutSFA =
|
||||
// decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -124,9 +119,8 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = 1;
|
||||
// Currently, we are only able to do broadcast on either all or none a_scales
|
||||
// and on either all or none b_scales
|
||||
// sm_count is the number of SMs on the current device, since we only support SM100 blackwell, so we set it to 148
|
||||
hw_info.sm_count = 148;
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{},
|
||||
nullptr,
|
||||
@@ -134,9 +128,7 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
// Initialize problem_sizes_as_shapes correctly
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
// Use prob_shape in the GEMM arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
@@ -144,21 +136,27 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
||||
|
||||
// Run the GEMM
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run();
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs,
|
||||
torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs,
|
||||
torch::Tensor& b_scales_ptrs,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
@@ -169,11 +167,23 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets) {
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
// Check the first matrix size to decide on the configuration
|
||||
// Assuming all matrices in the group have similar size characteristics
|
||||
// bool use_small_config = a[0].size(0) <= 128;
|
||||
struct MMALargeConfig {
|
||||
struct MmaConfig1 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
struct MmaConfig2 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
|
||||
@@ -184,35 +194,28 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MMASmallConfig {
|
||||
struct MmaConfig3 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _16, _128>;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||
torch::Tensor workspace = torch::empty(100, options_int);
|
||||
torch::Tensor output_t = output.t();
|
||||
torch::Tensor a_t = a.t();
|
||||
torch::Tensor b_t = b.transpose(1, 2);
|
||||
torch::Tensor scales_a_t = scales_a.t();
|
||||
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
|
||||
|
||||
if (a.size(0) <= 512) {
|
||||
run_get_group_gemm_starts<MMASmallConfig::LayoutSFA, MMASmallConfig::LayoutSFB, MMASmallConfig::ScaleConfig>(
|
||||
if (a.size(0) <= 512 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -229,7 +232,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
problem_sizes,
|
||||
problem_sizes_transpose,
|
||||
true);
|
||||
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMASmallConfig, cutlass::layout::ColumnMajor>(
|
||||
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -244,8 +247,8 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
expert_offsets,
|
||||
workspace);
|
||||
output = output_t.t();
|
||||
} else {
|
||||
run_get_group_gemm_starts<MMALargeConfig::LayoutSFA, MMALargeConfig::LayoutSFB, MMALargeConfig::ScaleConfig>(
|
||||
} else if (a.size(0) > 512 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -261,7 +264,38 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMALargeConfig, cutlass::layout::RowMajor>(
|
||||
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
} else {
|
||||
run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -312,6 +346,11 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
*/
|
||||
void fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs,
|
||||
torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs,
|
||||
torch::Tensor& b_scales_ptrs,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
@@ -322,7 +361,8 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets) {
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(
|
||||
@@ -342,6 +382,29 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
|
||||
TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
|
||||
TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
|
||||
TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 1D tensor");
|
||||
TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 1D tensor");
|
||||
TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
|
||||
TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
|
||||
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
|
||||
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
|
||||
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(
|
||||
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
|
||||
|
||||
bool can_implement = false;
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
@@ -351,6 +414,11 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -361,10 +429,16 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets);
|
||||
expert_offsets,
|
||||
workspace);
|
||||
} else {
|
||||
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -375,7 +449,8 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets);
|
||||
expert_offsets,
|
||||
workspace);
|
||||
}
|
||||
can_implement = true;
|
||||
}
|
||||
|
||||
128
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file
128
sgl-kernel/csrc/moe/prepare_moe_input.cu
Normal file
@@ -0,0 +1,128 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(
|
||||
const int* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length,
|
||||
const int n,
|
||||
const int k) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
int occurrences = 0;
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
occurrences += (topk_ids[i] == expert_id);
|
||||
}
|
||||
atomicAdd(&atomic_buffer[expert_id], occurrences);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1,
|
||||
int32_t* expert_offsets,
|
||||
int32_t* atomic_buffer,
|
||||
const int num_experts) {
|
||||
int32_t tot_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
atomic_buffer[i] = tot_offset;
|
||||
tot_offset += problem_sizes1[i * 3];
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(
|
||||
const int* __restrict__ topk_ids,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_moe_prepare_input_caller(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
topk_ids.numel(),
|
||||
n,
|
||||
k);
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
topk_ids.numel(),
|
||||
topk_ids.size(1));
|
||||
}
|
||||
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
|
||||
get_moe_prepare_input_caller(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k);
|
||||
return;
|
||||
}
|
||||
@@ -211,6 +211,11 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
|
||||
void fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs,
|
||||
torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs,
|
||||
torch::Tensor& b_scales_ptrs,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
@@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets);
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
|
||||
@@ -47,6 +47,7 @@ from sgl_kernel.moe import (
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
moe_fused_gate,
|
||||
prepare_moe_input,
|
||||
topk_softmax,
|
||||
)
|
||||
from sgl_kernel.sampling import (
|
||||
|
||||
@@ -64,6 +64,11 @@ def moe_fused_gate(
|
||||
|
||||
def fp8_blockwise_scaled_grouped_mm(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -75,9 +80,15 @@ def fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
):
|
||||
torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -89,4 +100,29 @@ def fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
)
|
||||
|
||||
|
||||
def prepare_moe_input(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
):
|
||||
torch.ops.sgl_kernel.prepare_moe_input.default(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
@@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
c_strides = torch.full(
|
||||
(num_experts,), c_out.stride(0), device=device, dtype=torch.int64
|
||||
)
|
||||
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||
a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
|
||||
b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
|
||||
out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
|
||||
a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
|
||||
b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
|
||||
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
c_out,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a_stack,
|
||||
b_stack,
|
||||
a_scale_stack,
|
||||
@@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
|
||||
Reference in New Issue
Block a user