[sgl-kernel][2/N]Support Expert Specialization Grouped GEMM (#11534)
This commit is contained in:
@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets);
|
||||
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
at::cuda::CUDAGuard device_guard{(char)a.get_device()};
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
if (output.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else if (output.dtype() == torch::kFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
|
||||
if (output.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
|
||||
out_ptrs,
|
||||
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else if (output.dtype() == torch::kFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
|
||||
out_ptrs,
|
||||
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
@@ -14,6 +15,7 @@ namespace expert_specialization {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
|
||||
// Output
|
||||
torch::Tensor& out_ptrs,
|
||||
@@ -33,15 +35,14 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& expert_offsets) {
|
||||
torch::Tensor const& expert_offsets,
|
||||
bool is_h20_device,
|
||||
cudaStream_t stream) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
|
||||
// Create Scale Factor Layout Functor
|
||||
using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
|
||||
using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
|
||||
@@ -49,74 +50,38 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
// Dispatch
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::bfloat16_t> of(
|
||||
static_cast<int*>(expert_offsets.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
|
||||
static_cast<cutlass::bfloat16_t*>(out_tensors.data_ptr()),
|
||||
static_cast<float*>(a_scales.data_ptr()),
|
||||
static_cast<float*>(b_scales.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()),
|
||||
static_cast<cutlass::bfloat16_t**>(out_ptrs.data_ptr()));
|
||||
if (!is_h20_device) {
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
} else {
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
}
|
||||
} else if (out_tensors.dtype() == torch::kFloat16) {
|
||||
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
|
||||
static_cast<int*>(expert_offsets.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
|
||||
static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
|
||||
static_cast<float*>(a_scales.data_ptr()),
|
||||
static_cast<float*>(b_scales.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()),
|
||||
static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
|
||||
if (!is_h20_device) {
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
} else {
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
}
|
||||
TORCH_CHECK(num_experts <= 1024, "Expert more than 1024"); // Max threads per block is 1024
|
||||
|
||||
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, T> of(
|
||||
static_cast<int*>(expert_offsets.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
|
||||
static_cast<T*>(out_tensors.data_ptr()),
|
||||
static_cast<float*>(a_scales.data_ptr()),
|
||||
static_cast<float*>(b_scales.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()),
|
||||
static_cast<T**>(out_ptrs.data_ptr()));
|
||||
if (!is_h20_device) {
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
|
||||
static_cast<int*>(lm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
|
||||
static_cast<int*>(mm_problem_sizes.data_ptr()));
|
||||
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
|
||||
static_cast<int*>(hm_problem_sizes.data_ptr()));
|
||||
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +97,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
||||
const torch::Tensor& stride_d,
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes) {
|
||||
const torch::Tensor& problem_sizes,
|
||||
cudaStream_t stream) {
|
||||
using ElementA = typename GemmTraits::ElementA;
|
||||
using StrideA = typename GemmTraits::StrideA;
|
||||
using ElementB = typename GemmTraits::ElementB;
|
||||
@@ -174,9 +140,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
||||
|
||||
@@ -205,7 +168,9 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& lm_problem_sizes,
|
||||
const torch::Tensor& mm_problem_sizes,
|
||||
const torch::Tensor& hm_problem_sizes) {
|
||||
const torch::Tensor& hm_problem_sizes,
|
||||
bool is_h20_device,
|
||||
cudaStream_t stream) {
|
||||
using LowMGemmH20Traits =
|
||||
ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
|
||||
using LowMGemmHx00Traits =
|
||||
@@ -221,9 +186,6 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
using HighMGemmHx00Traits =
|
||||
ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;
|
||||
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
|
||||
if (!is_h20_device) {
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
|
||||
out_ptrs,
|
||||
@@ -236,7 +198,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfb,
|
||||
layout_sfa,
|
||||
lm_problem_sizes);
|
||||
lm_problem_sizes,
|
||||
stream);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
|
||||
out_ptrs,
|
||||
@@ -249,7 +212,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfb,
|
||||
layout_sfa,
|
||||
lm_problem_sizes);
|
||||
lm_problem_sizes,
|
||||
stream);
|
||||
}
|
||||
|
||||
if (!is_h20_device) {
|
||||
@@ -264,7 +228,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfb,
|
||||
layout_sfa,
|
||||
mm_problem_sizes);
|
||||
mm_problem_sizes,
|
||||
stream);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
|
||||
out_ptrs,
|
||||
@@ -277,7 +242,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
mm_problem_sizes);
|
||||
mm_problem_sizes,
|
||||
stream);
|
||||
}
|
||||
|
||||
if (!is_h20_device) {
|
||||
@@ -292,7 +258,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
stream);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
|
||||
out_ptrs,
|
||||
@@ -305,7 +272,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
||||
stride_d,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ from sgl_kernel.elementwise import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
|
||||
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
|
||||
Reference in New Issue
Block a user