[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
chenxj
2025-09-02 13:17:26 +08:00
committed by GitHub
parent 21e1bc475c
commit d4a938417d
11 changed files with 291 additions and 120 deletions

View File

@@ -31,7 +31,7 @@ __global__ void int4_fp8_get_group_gemm_starts(
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
}
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \

View File

@@ -2,6 +2,8 @@
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <type_traits>
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
@@ -9,38 +11,60 @@ using namespace cute;
namespace {
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
enum class Sched { PP, CO };
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
template <int M, int N, int K, int A, int B, int C, Sched S>
struct SM90W4A8Config {
using KernelSchedule = std::conditional_t<
S == Sched::PP,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
using EpilogueSchedule = std::conditional_t<
S == Sched::PP,
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
};
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
// Shorthand for SM90W4A8Config with Sched::PP, i.e. the configuration whose
// kernel and epilogue schedules are the ...WarpSpecializedPingpong variants.
// M, N, K form the TileShape; A, B, C form the ClusterShape.
template <int M, int N, int K, int A, int B, int C>
using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
// Shorthand for SM90W4A8Config with Sched::CO, i.e. the configuration whose
// kernel and epilogue schedules are the ...WarpSpecializedCooperative variants.
// M, N, K form the TileShape; A, B, C form the ClusterShape.
template <int M, int N, int K, int A, int B, int C>
using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
// Launches the grouped GEMM described by `Config`.
//
// `Config` must expose a nested type `Cutlass3xW4A8Gemm`; that type is used as
// the template argument of cutlass_w4a8_group_gemm_caller. Every argument is
// forwarded unchanged and in the same order, so this wrapper adds no logic of
// its own beyond selecting the kernel type.
template <typename Config>
inline void invoke_gemm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size) {
  cutlass_w4a8_group_gemm_caller<typename Config::Cutlass3xW4A8Gemm>(
      d_tensors, a_tensors, b_tensors,
      a_scales, b_scales,
      expert_offsets, problem_sizes,
      a_strides, b_strides, d_strides, s_strides,
      chunk_size);
}
void dispatch_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
@@ -56,9 +80,6 @@ void dispatch_w4a8_moe_mm_sm90(
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
uint32_t const m = a_tensors.size(0) / topk;
uint32_t const n = d_tensors.size(1);
uint32_t const k = a_tensors.size(1);
@@ -66,8 +87,7 @@ void dispatch_w4a8_moe_mm_sm90(
if (n == 4096 && k == 7168) {
// group gemm 1
if (m <= 4) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -81,8 +101,7 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else if (m <= 16) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -96,8 +115,7 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else if (m <= 256) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 16, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -111,8 +129,7 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else if (m <= 1024) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -126,8 +143,7 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -144,8 +160,7 @@ void dispatch_w4a8_moe_mm_sm90(
} else if (n == 7168 && k == 2048) {
// group gemm 2
if (m <= 8) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_PP<64, 16, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -159,8 +174,7 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else if (m <= 512) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -174,8 +188,125 @@ void dispatch_w4a8_moe_mm_sm90(
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 512 && k == 7168) {
// group gemm 1 for tp
if (m <= 4) {
invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 256) {
// group gemm 2 for tp
if (m <= 8) {
invoke_gemm<SM90_PP<64, 16, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
invoke_gemm<SM90_PP<128, 32, 128, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
@@ -190,20 +321,35 @@ void dispatch_w4a8_moe_mm_sm90(
chunk_size);
}
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
if (k % 512 == 0) {
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
}
}

View File

@@ -41,9 +41,8 @@ using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
using QuantType = cutlass::int4b_t; // 4-bit integer type
using ElementAccumulator = float; // Accumulator type
using ElementScale = cutlass::bfloat16_t; // Scale type
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t; // Default output type (FP16)
using ElementD = ElementC; // Default output type (FP16)
using ElementC = cutlass::half_t; // Default output type (FP16)
using ElementD = ElementC; // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Architecture-specific configurations
@@ -73,6 +72,10 @@ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
static constexpr int GroupSize = 128;
static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
@@ -184,8 +187,6 @@ void cutlass_w4a8_group_gemm_caller(
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
@@ -241,7 +242,7 @@ void cutlass_w4a8_group_gemm_caller(
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
static_cast<const MmaType**>(a_ptrs.data_ptr()),
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
static_cast<int>(chunk_size)},
{fusion_args,

View File

@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
alignment = 4 if k % 512 == 0 else 1
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
ref_scale.shape[0],
ref_scale.shape[1],
(ref_scale.shape[2] // alignment),
alignment,
) # [E, N, K/alignment, alignment]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/alignment, N, alignment]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
ref_scale.shape[0],
ref_scale.shape[2] // alignment,
ref_scale.shape[1] * alignment,
) # [E, K/alignment, N*alignment]
w_scale = scale_interleaved.contiguous()
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
@pytest.mark.parametrize("k", [256, 512, 1024])
@pytest.mark.parametrize("n", [1024, 2048, 7168])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
torch.manual_seed(0)