[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)
Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
@@ -31,7 +31,7 @@ __global__ void int4_fp8_get_group_gemm_starts(
|
||||
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
|
||||
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
|
||||
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
|
||||
}
|
||||
|
||||
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "w4a8_grouped_mm_c3x.cuh"
|
||||
|
||||
@@ -9,38 +11,60 @@ using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
enum class Sched { PP, CO };
|
||||
|
||||
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
template <int M, int N, int K, int A, int B, int C, Sched S>
|
||||
struct SM90W4A8Config {
|
||||
using KernelSchedule = std::conditional_t<
|
||||
S == Sched::PP,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
|
||||
|
||||
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
||||
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
\
|
||||
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||
};
|
||||
using EpilogueSchedule = std::conditional_t<
|
||||
S == Sched::PP,
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;
|
||||
|
||||
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
\
|
||||
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
|
||||
};
|
||||
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
|
||||
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
|
||||
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
|
||||
template <int M, int N, int K, int A, int B, int C>
|
||||
using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;
|
||||
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
|
||||
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
|
||||
template <int M, int N, int K, int A, int B, int C>
|
||||
using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
|
||||
|
||||
template <typename Config>
|
||||
inline void invoke_gemm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size) {
|
||||
using GemmT = typename Config::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<GemmT>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
}
|
||||
|
||||
void dispatch_w4a8_moe_mm_sm90(
|
||||
torch::Tensor& d_tensors,
|
||||
@@ -56,9 +80,6 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk) {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
|
||||
uint32_t const m = a_tensors.size(0) / topk;
|
||||
uint32_t const n = d_tensors.size(1);
|
||||
uint32_t const k = a_tensors.size(1);
|
||||
@@ -66,8 +87,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
if (n == 4096 && k == 7168) {
|
||||
// group gemm 1
|
||||
if (m <= 4) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -81,8 +101,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 16) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -96,8 +115,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 256) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 16, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -111,8 +129,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 1024) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -126,8 +143,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -144,8 +160,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
} else if (n == 7168 && k == 2048) {
|
||||
// group gemm 2
|
||||
if (m <= 8) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_PP<64, 16, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -159,8 +174,7 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 512) {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -174,8 +188,125 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
}
|
||||
} else if (n == 512 && k == 7168) {
|
||||
// group gemm 1 for tp
|
||||
if (m <= 4) {
|
||||
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 16) {
|
||||
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 256) {
|
||||
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 1024) {
|
||||
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else {
|
||||
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
}
|
||||
} else if (n == 7168 && k == 256) {
|
||||
// group gemm 2 for tp
|
||||
if (m <= 8) {
|
||||
invoke_gemm<SM90_PP<64, 16, 128, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else if (m <= 512) {
|
||||
invoke_gemm<SM90_PP<128, 32, 128, 2, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else {
|
||||
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
@@ -190,20 +321,35 @@ void dispatch_w4a8_moe_mm_sm90(
|
||||
chunk_size);
|
||||
}
|
||||
} else {
|
||||
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
|
||||
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
if (k % 512 == 0) {
|
||||
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
} else {
|
||||
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
|
||||
d_tensors,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
a_scales,
|
||||
b_scales,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,9 +41,8 @@ using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
|
||||
using QuantType = cutlass::int4b_t; // 4-bit integer type
|
||||
using ElementAccumulator = float; // Accumulator type
|
||||
using ElementScale = cutlass::bfloat16_t; // Scale type
|
||||
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
|
||||
using ElementC = cutlass::half_t; // Default output type (FP16)
|
||||
using ElementD = ElementC; // Default output type (FP16)
|
||||
using ElementC = cutlass::half_t; // Default output type (FP16)
|
||||
using ElementD = ElementC; // Default output type (FP16)
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Architecture-specific configurations
|
||||
@@ -73,6 +72,10 @@ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
|
||||
struct cutlass_3x_w4a8_group_gemm {
|
||||
static constexpr int GroupSize = 128;
|
||||
static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
|
||||
using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -184,8 +187,6 @@ void cutlass_w4a8_group_gemm_caller(
|
||||
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
|
||||
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
|
||||
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
|
||||
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
|
||||
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
|
||||
|
||||
// Check tensor types
|
||||
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
|
||||
@@ -241,7 +242,7 @@ void cutlass_w4a8_group_gemm_caller(
|
||||
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
|
||||
static_cast<const MmaType**>(a_ptrs.data_ptr()),
|
||||
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
|
||||
static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
|
||||
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
|
||||
static_cast<int>(chunk_size)},
|
||||
{fusion_args,
|
||||
|
||||
@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
|
||||
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
|
||||
w_q = w_q.contiguous()
|
||||
|
||||
alignment = 4 if k % 512 == 0 else 1
|
||||
scale_interleaved = ref_scale.reshape(
|
||||
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
|
||||
ref_scale.shape[0],
|
||||
ref_scale.shape[1],
|
||||
(ref_scale.shape[2] // alignment),
|
||||
alignment,
|
||||
) # [E, N, K/4, 4]
|
||||
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
|
||||
scale_interleaved = scale_interleaved.reshape(
|
||||
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
|
||||
ref_scale.shape[0],
|
||||
ref_scale.shape[2] // alignment,
|
||||
ref_scale.shape[1] * alignment,
|
||||
) # [E, K/4, N*4]
|
||||
w_scale = scale_interleaved.contiguous()
|
||||
|
||||
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
||||
reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
|
||||
@pytest.mark.parametrize("k", [512, 1024])
|
||||
@pytest.mark.parametrize("n", [1024, 2048])
|
||||
@pytest.mark.parametrize("k", [256, 512, 1024])
|
||||
@pytest.mark.parametrize("n", [1024, 2048, 7168])
|
||||
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
||||
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user