Files
sglang/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh
Qi Yuhang 9a30914e94 [sgl-kernel][1/N]Support Expert Specialization Grouped GEMM (#11432)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
2025-10-12 20:19:21 -07:00

269 lines
9.1 KiB
CUDA C++ (.cuh header)

#pragma once
#include <cuda.h>
#include <iostream>
#include "cute/tensor.hpp"
#include "es_fp8_blockwise_traits.cuh"
namespace expert_specialization {
using namespace cute;
// Device functor that fills per-expert pointer arrays for a grouped GEMM.
// For each expert it derives where that expert's A/B/D slices and their
// scale factors start inside the packed base tensors.
template <typename ElementAB, typename ElementSF, typename ElementD>
struct Fp8BlockwiseGroupedGemmOffsetFunctor {
  // Input: per-expert token offsets (expert_offsets[i] is the starting row
  // of expert i in the packed A / output tensors).
  int* expert_offsets{nullptr};
  // Base pointers of the packed tensors.
  ElementAB* a_base{nullptr};
  ElementAB* b_base{nullptr};
  ElementD* out_base{nullptr};
  ElementSF* a_scales_base{nullptr};
  ElementSF* b_scales_base{nullptr};
  // Output: per-expert pointer arrays (one entry per expert).
  ElementAB** a_offsets{nullptr};
  ElementAB** b_offsets{nullptr};
  ElementSF** a_scales_offsets{nullptr};
  ElementSF** b_scales_offsets{nullptr};
  ElementD** out_offsets{nullptr};
  Fp8BlockwiseGroupedGemmOffsetFunctor() = default;
  Fp8BlockwiseGroupedGemmOffsetFunctor(
      int* _expert_offsets,
      ElementAB* _a_base,
      ElementAB* _b_base,
      ElementD* _out_base,
      ElementSF* _a_scales_base,
      ElementSF* _b_scales_base,
      ElementAB** _a_offsets,
      ElementAB** _b_offsets,
      ElementSF** _a_scales_offsets,
      ElementSF** _b_scales_offsets,
      ElementD** _out_offsets)
      : expert_offsets(_expert_offsets),
        a_base(_a_base),
        b_base(_b_base),
        out_base(_out_base),
        a_scales_base(_a_scales_base),
        b_scales_base(_b_scales_base),
        a_offsets(_a_offsets),
        b_offsets(_b_offsets),
        a_scales_offsets(_a_scales_offsets),
        b_scales_offsets(_b_scales_offsets),
        out_offsets(_out_offsets) {}
  // Compute the pointers for one expert's (m, n, k) problem.
  // A, its scales, and the output are indexed by the expert's token offset;
  // B and its scales are dense per-expert tensors indexed by expert_id.
  // NOTE(review): the /128 factors assume 128-wide blockwise scaling for A
  // and 128x128 blocks for B -- confirm against es_fp8_blockwise_traits.cuh.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    const int64_t token_offset = static_cast<int64_t>(expert_offsets[expert_id]);
    a_offsets[expert_id] = a_base + token_offset * k;
    b_offsets[expert_id] = b_base + expert_id * k * n;
    a_scales_offsets[expert_id] = a_scales_base + token_offset * k / 128;
    b_scales_offsets[expert_id] = b_scales_base + expert_id * k * n / 128 / 128;
    out_offsets[expert_id] = out_base + token_offset * n;
  }
};
// Device functor that records, per expert, the CUTLASS scale-factor layouts
// (SFA/SFB) matching that expert's (m, n, k) problem shape.
template <typename PerfConfig>
struct Fp8BlockwiseGroupedGemmSFLayoutFunctor {
  using ScaleConfig = typename PerfConfig::ScaleConfig;
  using LayoutSFA = typename PerfConfig::LayoutSFA;
  using LayoutSFB = typename PerfConfig::LayoutSFB;
  // Output arrays: one layout entry per expert.
  LayoutSFA* layout_sfa_base{nullptr};
  LayoutSFB* layout_sfb_base{nullptr};
  Fp8BlockwiseGroupedGemmSFLayoutFunctor() = default;
  Fp8BlockwiseGroupedGemmSFLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
      : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}
  // Store the SFA/SFB layouts for this expert, built from (m, n, k, 1).
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    auto problem_shape = cute::make_shape(m, n, k, 1);
    layout_sfa_base[expert_id] = ScaleConfig::tile_atom_to_shape_SFA(problem_shape);
    layout_sfb_base[expert_id] = ScaleConfig::tile_atom_to_shape_SFB(problem_shape);
  }
};
// [Unused]: Specialization for Swap A/B -- the low-M H20 config runs the
// GEMM with A and B swapped, so the layout shape is built as (n, m, k, 1).
template <>
struct Fp8BlockwiseGroupedGemmSFLayoutFunctor<PerfConfigLowMH20> {
  using ScaleConfig = typename PerfConfigLowMH20::ScaleConfig;
  using LayoutSFA = typename PerfConfigLowMH20::LayoutSFA;
  using LayoutSFB = typename PerfConfigLowMH20::LayoutSFB;
  // Output arrays: one layout entry per expert.
  LayoutSFA* layout_sfa_base{nullptr};
  LayoutSFB* layout_sfb_base{nullptr};
  Fp8BlockwiseGroupedGemmSFLayoutFunctor() = default;
  Fp8BlockwiseGroupedGemmSFLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
      : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}
  // Store the SFA/SFB layouts for this expert with m and n exchanged.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    auto swapped_shape = cute::make_shape(n, m, k, 1);
    layout_sfa_base[expert_id] = ScaleConfig::tile_atom_to_shape_SFA(swapped_shape);
    layout_sfb_base[expert_id] = ScaleConfig::tile_atom_to_shape_SFB(swapped_shape);
  }
};
// Primary template (declaration only): per-PerfConfig filters that route each
// expert's problem size to the matching kernel configuration. Explicit
// specializations below define the behavior for each supported config.
template <typename PerfConfig>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor;
// Filter for the low-M H20 config: keeps experts with m <= 48, writing the
// problem swapped as (n, m, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m <= 48);
    // Swap A/B: kept problems are recorded as (n, m, k).
    entry[0] = keep ? n : 0;
    entry[1] = keep ? m : 0;
    entry[2] = keep ? k : 0;
  }
};
// Filter for the low-M Hx00 config: keeps experts with m <= 32, writing the
// problem swapped as (n, m, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m <= 32);
    // Swap A/B: kept problems are recorded as (n, m, k).
    entry[0] = keep ? n : 0;
    entry[1] = keep ? m : 0;
    entry[2] = keep ? k : 0;
  }
};
// Filter for the middle-M H20 config: keeps experts with 48 < m <= 96,
// recorded unswapped as (m, n, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m > 48 && m <= 96);
    entry[0] = keep ? m : 0;
    entry[1] = keep ? n : 0;
    entry[2] = keep ? k : 0;
  }
};
// Filter for the middle-M Hx00 config: keeps experts with 32 < m <= 64,
// writing the problem swapped as (n, m, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m > 32 && m <= 64);
    // Swap A/B: kept problems are recorded as (n, m, k).
    entry[0] = keep ? n : 0;
    entry[1] = keep ? m : 0;
    entry[2] = keep ? k : 0;
  }
};
// Filter for the high-M H20 config: keeps experts with m > 96, recorded
// unswapped as (m, n, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m > 96);
    entry[0] = keep ? m : 0;
    entry[1] = keep ? n : 0;
    entry[2] = keep ? k : 0;
  }
};
// Filter for the high-M Hx00 config: keeps experts with m > 64, recorded
// unswapped as (m, n, k); all other experts are zeroed out.
template <>
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> {
  int* problem_sizes{nullptr};
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default;
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {}
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int* entry = problem_sizes + expert_id * 3;
    const bool keep = (m > 64);
    entry[0] = keep ? m : 0;
    entry[1] = keep ? n : 0;
    entry[2] = keep ? k : 0;
  }
};
// Precompute kernel for grouped GEMM: one thread per expert. Each thread
// reads its expert's (m, n, k) from problem_sizes, then applies the offset,
// scale-factor-layout, and the three M-range filter functors for that expert.
// The (m, n, k) triple is captured BEFORE the filters run, since each filter
// writes problem-size entries through its own pointer (presumably distinct
// per-config output arrays -- confirm at the launch site that they do not
// alias this kernel's input problem_sizes).
// NOTE(review): no bounds check on threadIdx.x -- assumes a single-block
// launch with blockDim.x == num_experts; confirm at the launch site.
template <
    typename OffsetFunctor,
    typename ScaleLayoutFunctor,
    typename LowMProblemSizeFilterFunctor,
    typename MiddleMProblemSizeFilterFunctor,
    typename HighMProblemSizeFilterFunctor>
__global__ void groupedGemmPreComputeKernel(
    int* problem_sizes,
    OffsetFunctor offset_functor,
    ScaleLayoutFunctor sf_functor,
    LowMProblemSizeFilterFunctor lm_psf_functor,
    MiddleMProblemSizeFilterFunctor mm_psf_functor,
    HighMProblemSizeFilterFunctor hm_psf_functor) {
  const int64_t expert_id = static_cast<int64_t>(threadIdx.x);
  const int* dims = problem_sizes + expert_id * 3;
  const int m = dims[0];
  const int n = dims[1];
  const int k = dims[2];
  offset_functor(expert_id, m, n, k);
  sf_functor(expert_id, m, n, k);
  lm_psf_functor(expert_id, m, n, k);
  mm_psf_functor(expert_id, m, n, k);
  hm_psf_functor(expert_id, m, n, k);
}
} // namespace expert_specialization