[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com> Co-authored-by: yicwang <yichen.wang@bytedance.com>
This commit is contained in:
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
91
sgl-kernel/csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,91 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// Returns the compute capability of CUDA device 0 encoded as
// major * 10 + minor (e.g. 90 for Hopper / SM90).
int32_t get_sm_version_num() {
  int32_t major = 0;
  int32_t minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  return major * 10 + minor;
}
|
||||
|
||||
// Forward declaration: Hopper (SM90) implementation of the W4A8 grouped MoE
// GEMM; defined in w4a8_grouped_mm_c3x.cu.
void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk);

// Forward declaration: computes per-expert problem sizes and row offsets for
// the grouped GEMM; defined in w4a8_moe_data.cu.
void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);
|
||||
|
||||
// Public entry point for the W4A8 (int4 weight x fp8 activation) grouped MoE
// GEMM.  Currently routes unconditionally to the SM90 (Hopper)
// implementation.
void cutlass_w4a8_moe_mm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  cutlass_w4a8_moe_mm_sm90(
      d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
      problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size,
      topk);
}
|
||||
|
||||
// Public entry point that prepares grouped-GEMM metadata (per-expert problem
// sizes and offsets) for the W4A8 MoE path; forwards to the CUDA
// implementation in w4a8_moe_data.cu.
void get_cutlass_w4a8_moe_mm_data(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  get_cutlass_w4a8_moe_mm_data_caller(
      topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
      input_permutation, output_permutation, num_experts, n, k);
}
|
||||
@@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
// Fills the per-expert base-pointer tables consumed by the grouped GEMM.
//
// Launched with a single block of num_experts threads (see the
// <<<1, num_experts>>> launch in this file): thread i computes the base
// pointers of expert i's slices of A, B, D and the two scale tensors.
// expert_offsets[e] is the starting row of expert e in the row-concatenated
// A/D tensors; B and b_scales are indexed per expert directly.
//
// NOTE(review): the offset math encodes the packed layouts --
//   * B advances by k * n / 2 elements per expert (int4 weights stored two
//     per int8 element),
//   * b_scales advances by n * 4 * k / 512 per expert when per_out_ch is set
//     (consistent with the [E, K//512, N*4] scale shape checked by the
//     grouped-GEMM caller) -- confirm if those layouts ever change.
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
    int32_t* expert_offsets,
    ElementA** a_offsets,
    ElementB** b_offsets,
    ElementC** out_offsets,
    ElementAccumulator** a_scales_offsets,
    cutlass::bfloat16_t** b_scales_offsets,
    ElementA* a_base_as_int,
    ElementB* b_base_as_int,
    ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    cutlass::bfloat16_t* b_scales_base_as_int,
    int64_t n,
    int64_t k,
    bool per_act_token,
    bool per_out_ch) {
  // One thread per expert.
  int expert_id = threadIdx.x;
  int32_t expert_offset = expert_offsets[expert_id];

  a_offsets[expert_id] = a_base_as_int + expert_offset * k;  // expert's first A row
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;  // packed int4: 2 weights/byte
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;  // expert's first D row
  // Scalar (non per-token) activation scale: every expert shares element 0.
  a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  // Per-output-channel weight scales vs. one scalar scale per expert.
  b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
}
|
||||
|
||||
// Extends an if/else-if chain over the output dtype and launches the
// pointer-table kernel with the matching C++ element type.  The expansion
// begins with `else if`, so the call site must open the chain with a leading
// `if` statement (run_int4_fp8_get_group_gemm_starts uses `if (false) {}`).
// Launch shape is one block of num_experts threads -- one thread per expert.
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                 \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                           \
    int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
        <<<1, num_experts, 0, stream>>>(                                     \
            static_cast<int32_t*>(expert_offsets.data_ptr()),                \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),         \
            static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()),               \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                      \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                  \
            static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()),    \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),       \
            static_cast<cutlass::int8_t*>(b_tensors.data_ptr()),             \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                    \
            static_cast<float*>(a_scales.data_ptr()),                        \
            static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()),          \
            out_tensors.size(1),                                             \
            a_tensors.size(1),                                               \
            per_act_token,                                                   \
            per_out_ch);                                                     \
  }
|
||||
|
||||
namespace {
|
||||
|
||||
// Validates input dtypes and launches int4_fp8_get_group_gemm_starts to fill
// the per-expert device-pointer tables, dispatching on the output dtype
// (bfloat16 or float16).
//
// The *_ptrs arguments are int64 tensors of length num_experts whose storage
// is reinterpreted as arrays of raw device pointers.  Raises via TORCH_CHECK
// on unexpected dtypes.
void run_int4_fp8_get_group_gemm_starts(
    torch::Tensor const& expert_offsets,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // A scalar a_scales / one-scale-per-expert b_scales disables the per-token /
  // per-output-channel offset math inside the kernel.
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

  // Empty `if` anchors the else-if chain emitted by the macro expansions
  // below.
  if (false) {
  }
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
|
||||
|
||||
} // namespace
|
||||
240
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
240
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu
Normal file
@@ -0,0 +1,240 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "w4a8_grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds the struct name of a pingpong-scheduled config from its tile shape
// (m, n, k) and cluster shape (a, b, c).
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c

// Same, for a cooperative-scheduled config.
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c

// Defines a config struct pairing the given tile/cluster shapes with the
// ptr-array TMA warp-specialized *pingpong* kernel and epilogue schedules.
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C)                                                       \
  struct JOIN_STRUCT_NAME(M, N, K, A, B, C) {                                                                \
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;                          \
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;                          \
    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                                 \
    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                              \
                                                                                                             \
    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
  };

// Defines a config struct pairing the given tile/cluster shapes with the
// ptr-array TMA warp-specialized *cooperative* kernel and epilogue schedules.
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C)                                                       \
  struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) {                                                             \
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;                       \
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;                       \
    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                                 \
    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                              \
                                                                                                             \
    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
  };

// Tuned tile/cluster configurations referenced by dispatch_w4a8_moe_mm_sm90.
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)

GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
|
||||
// Picks a tile/cluster configuration for the W4A8 grouped GEMM based on the
// problem shape and launches it.
//
// The (n, k) pairs tested below are the two grouped GEMMs of the target MoE
// layer (4096x7168 and 7168x2048); other shapes fall through to a generic
// configuration.  Within each shape, the average rows per expert
// (total rows / topk) selects a tile size tuned for that batch regime.
//
// All branches forward the identical argument list, so the body factors the
// call into a local macro and each branch only names the configuration.
// (The previous version repeated the 12-argument call in every branch and
// declared two unused KernelSchedule/EpilogueSchedule aliases.)
void dispatch_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  // Average number of tokens routed to each expert; used only as a heuristic
  // for config selection.
  uint32_t const m = a_tensors.size(0) / topk;
  uint32_t const n = d_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

// Instantiates and runs the grouped GEMM for one configuration struct.
#define DISPATCH_W4A8_MOE_MM(CONFIG)                                      \
  do {                                                                    \
    using Cutlass3xW4A8GemmSelected = typename CONFIG::Cutlass3xW4A8Gemm; \
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(            \
        d_tensors,                                                        \
        a_tensors,                                                        \
        b_tensors,                                                        \
        a_scales,                                                         \
        b_scales,                                                         \
        expert_offsets,                                                   \
        problem_sizes,                                                    \
        a_strides,                                                        \
        b_strides,                                                        \
        d_strides,                                                        \
        s_strides,                                                        \
        chunk_size);                                                      \
  } while (0)

  if (n == 4096 && k == 7168) {
    // group gemm 1
    if (m <= 4) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1));
    } else if (m <= 16) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1));
    } else if (m <= 256) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1));
    } else if (m <= 1024) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1));
    } else {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1));
    }
  } else if (n == 7168 && k == 2048) {
    // group gemm 2
    if (m <= 8) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1));
    } else if (m <= 512) {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1));
    } else {
      DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1));
    }
  } else {
    // Fallback for shapes not covered by the tuned cases above.
    DISPATCH_W4A8_MOE_MM(JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1));
  }

#undef DISPATCH_W4A8_MOE_MM
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// SM90 entry point for the W4A8 grouped MoE GEMM: hands the full argument
// list to the shape-based configuration dispatcher above.
void cutlass_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
  dispatch_w4a8_moe_mm_sm90(
      d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
      problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size,
      topk);
}
|
||||
276
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
276
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
Normal file
@@ -0,0 +1,276 @@
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* @file w4a8_grouped_mm_c3x.cuh
|
||||
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
|
||||
* precision
|
||||
*
|
||||
* This file implements a grouped GEMM operation that multiplies FP8 matrices
|
||||
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
|
||||
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
|
||||
* Cores for mixed precision arithmetic.
|
||||
*
|
||||
* Key features:
|
||||
* - Supports grouped GEMM operations with multiple experts
|
||||
* - Uses FP8 (e4m3) for matrix A
|
||||
* - Uses INT4 quantization for matrix B with per-block scaling
|
||||
* - Implements preprocessing for INT4 encoding and scale packing
|
||||
* - Optimized for Hopper architecture with Tensor Core operations
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
|
||||
#include "w4a8_get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
// Type definitions
using MmaType = cutlass::float_e4m3_t;     // FP8 e4m3 type (activations, A)
using QuantType = cutlass::int4b_t;        // 4-bit integer type (weights, B)
using ElementAccumulator = float;          // Accumulator type
using ElementScale = cutlass::bfloat16_t;  // Scale type
// Weight scales are consumed four at a time as a packed array element.
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t;  // Default output type (FP16)
using ElementD = ElementC;         // Default output type (FP16)
// Grouped problem shape: one (M, N, K) triple per expert.
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;

// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;

// Transposed layouts -- the mainloop below is built with B in the "A" slot
// and A in the "B" slot, so each operand uses its transposed layout tag.
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;

// Alignments, in elements, for 128-bit (16-byte) aligned accesses.
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Assembles the CUTLASS 3.x grouped mixed-input GEMM type for one
// tile/cluster/schedule configuration.  Exposes the adapter (GemmScaleOnly)
// plus the stride types the host-side caller needs to reinterpret the stride
// tensors it receives.
//
// Note the operand swap: the mainloop builder takes the quantized B operand
// (with its packed scales) in the first slot and the FP8 A operand in the
// second, each with transposed layout tags.
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutC_Transpose*,
      AlignmentC,
      ElementD,
      LayoutD_Transpose*,
      AlignmentD,
      EpilogueSchedule>::CollectiveOp;

  // Mixed-input mainloop: int4 weights + packed bf16 scales vs. fp8
  // activations.  Shared-memory stage count is derived automatically after
  // carving out the epilogue's shared storage.
  using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
      ArchTag,
      OperatorClass,
      cute::tuple<QuantType, ElementScalePacked>,
      LayoutB_Transpose*,
      AlignmentB,
      MmaType,
      LayoutA_Transpose*,
      AlignmentA,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  // Define the final kernel and GEMM operation types
  using GemmKernelScaleOnly =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;

  using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

  // Stride types used by the host caller to cast raw stride-tensor storage.
  using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
  using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
  using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
  using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
  using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
|
||||
|
||||
/**
|
||||
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
|
||||
*
|
||||
* This function performs multiple GEMM operations in parallel where each
|
||||
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
|
||||
* applying per-channel scaling factors. It's designed for efficient execution
|
||||
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
|
||||
* mixed precision arithmetic.
|
||||
*
|
||||
* The function includes preprocessing steps for both INT4 tensors and scale
|
||||
* factors to ensure optimal performance and correct operation.
|
||||
*
|
||||
* @param d_tensors Output tensor D with shape [total_m, total_n]
|
||||
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
|
||||
* [total_m, K]
|
||||
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
|
||||
* shape [E, N, K/2]
|
||||
* @param a_scales Tensor containing A matrix scale factors
|
||||
* @param b_scales Tensor containing B matrix scale factors with shape [E,
|
||||
* K//512, N*4]
|
||||
* @param expert_offsets Tensor containing expert offsets for determining group
|
||||
* boundaries (int32)
|
||||
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
|
||||
* 3] (M, N, K for each group) (int32)
|
||||
* @param a_strides Stride information for A tensors
|
||||
* @param b_strides Stride information for B tensors
|
||||
* @param d_strides Stride information for D tensors
|
||||
* @param s_strides Stride information for scale tensors
|
||||
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
|
||||
*/
|
||||
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
|
||||
// Validates inputs, builds the per-expert pointer tables, assembles CUTLASS
// grouped-GEMM arguments and runs the kernel on the current CUDA stream.
// See the doc comment above for the full parameter contract.
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size) {
  using Args = typename Gemm::GemmScaleOnly::Arguments;

  int num_experts = static_cast<int>(expert_offsets.size(0));

  // Check inputs
  TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
  TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
  TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
  TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");

  // Check tensor shapes
  TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
  TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
  TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
  TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
  TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
  TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");

  // Check tensor types
  TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
  TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
  TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
  TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
  auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());

  // Per-expert device-pointer tables (int64 storage reinterpreted as raw
  // pointers), filled by run_int4_fp8_get_group_gemm_starts below.
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a_tensors.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  Args arguments;
  decltype(arguments.epilogue.thread) fusion_args;
  // D = alpha * acc with alpha read from a_scales on device; beta is unused.
  fusion_args.alpha = 1.0f;
  fusion_args.beta = 0;
  fusion_args.alpha_ptr = a_scales.data_ptr<float>();
  fusion_args.beta_ptr = nullptr;
  fusion_args.alpha_ptr_array = nullptr;
  fusion_args.beta_ptr_array = nullptr;
  fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
  fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());

  run_int4_fp8_get_group_gemm_starts(
      expert_offsets,
      a_ptrs,
      b_ptrs,
      out_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      a_tensors,
      b_tensors,
      d_tensors,
      a_scales,
      b_scales);

  // B (quantized, with packed scales) occupies the mainloop's first operand
  // slot and A the second -- see the transposed-layout setup in
  // cutlass_3x_w4a8_group_gemm.
  arguments = Args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      {static_cast<const QuantType**>(b_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
       static_cast<const MmaType**>(a_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
       static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
       static_cast<int>(chunk_size)},
      {fusion_args,
       nullptr,
       nullptr,
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
      hw_info};

  // Instantiate and run GEMM
  typename Gemm::GemmScaleOnly gemm;
  size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "GEMM implementation not supported");

  status = gemm.initialize(arguments, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "GEMM initialization failed");

  status = gemm.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "GEMM execution failed");
}
|
||||
|
||||
} // namespace
|
||||
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
79
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu
Normal file
@@ -0,0 +1,79 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
// One CUDA block per expert: counts how many topk_ids entries route to this
// expert (the expert's M) and writes one (N, M, K) row into each of the two
// grouped-GEMM problem-size tables.
//
// atomic_buffer must be zero-initialized; each slot is written only by its
// own block, so atomicAdd + __syncthreads() suffices to make the total count
// visible to thread 0.
// NOTE(review): the loop strides by THREADS_PER_EXPERT rather than
// blockDim.x.  This is correct for the launch in this file
// (blockDim.x == min(THREADS_PER_EXPERT, topk_length)) but would skip
// elements if launched with a smaller block and larger topk_length.
__global__ void compute_problem_sizes_w4a8(
    const int32_t* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int topk_length,
    const int n,
    const int k) {
  int expert_id = blockIdx.x;

  // Each thread counts a strided slice of topk_ids; partial counts are
  // combined in atomic_buffer[expert_id].
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    // Rows are stored (N, M, K).  GEMM1's output width is 2 * n (presumably
    // fused gate/up projections -- confirm against the caller's weights).
    problem_sizes1[expert_id * 3] = 2 * n;
    problem_sizes1[expert_id * 3 + 1] = final_occurrences;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = k;
    problem_sizes2[expert_id * 3 + 1] = final_occurrences;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}
|
||||
|
||||
// Serial prefix sum over the per-expert M values (column 1 of
// problem_sizes1); launched as a single thread (<<<1, 1>>>).
//
// expert_offsets must hold num_experts + 1 entries: expert_offsets[e] is the
// first row of expert e and expert_offsets[num_experts] the total row count.
// atomic_buffer[i] receives the same start offset for expert i -- presumably
// reused as a scratch write cursor by a later kernel; it is not read again in
// this file.
__global__ void compute_expert_offsets_w4a8(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3 + 1];  // M (row count) of expert i
    expert_offsets[i + 1] = tot_offset;
  }
}
|
||||
|
||||
// Computes the grouped-GEMM metadata for the W4A8 MoE path: per-expert
// problem sizes (one block per expert) followed by the prefix-sum of expert
// offsets (single thread), both on the current CUDA stream.
//
// input_permutation / output_permutation are accepted for interface parity
// but are not touched here -- TODO(review): confirm whether callers expect
// them to be filled.
void get_cutlass_w4a8_moe_mm_data_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  // Scratch accumulator; compute_problem_sizes_w4a8 requires it zeroed.
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  // Cap the block size at the number of topk entries.  Comparison is done on
  // a single signed 64-bit type: the previous unqualified
  // min(uint64_t, int64_t) mixed signedness.
  int64_t num_threads = static_cast<int64_t>(THREADS_PER_EXPERT);
  if (topk_ids.numel() < num_threads) {
    num_threads = topk_ids.numel();
  }
  compute_problem_sizes_w4a8<<<num_experts, static_cast<int>(num_threads), 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      num_experts);
}
|
||||
Reference in New Issue
Block a user