// sglang/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cassert>
#include <iostream>
#include <string>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "es_fp8_blockwise_functor.cuh"
namespace expert_specialization {
using namespace cute;
// Builds the per-expert metadata consumed by the grouped-GEMM launches:
// pointer tables for A/B/D and both scale tensors, scale-factor layouts,
// and the three size-bucketed (low/middle/high-M) problem lists.
// Runs one thread block with one thread per expert, hence the 1024 cap.
// NOTE(review): the scale-factor layout functor is always instantiated from
// PerfConfigMiddleMH20 regardless of device — presumably all perf configs
// share identical LayoutSFA/LayoutSFB types; confirm in
// es_fp8_blockwise_functor.cuh.
template <typename T>
void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
    // Output
    torch::Tensor& out_ptrs,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor& layout_sfa,
    torch::Tensor& layout_sfb,
    torch::Tensor& lm_problem_sizes,
    torch::Tensor& mm_problem_sizes,
    torch::Tensor& hm_problem_sizes,
    // Input
    torch::Tensor& out_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& expert_offsets,
    bool is_h20_device,
    cudaStream_t stream) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // Functor that writes the per-expert scale-factor layouts.
  using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
  using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
  Fp8BlockwiseGroupedGemmSFLayoutFunctor<PerfConfigMiddleMH20> layout_functor(
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));

  const int num_experts = static_cast<int>(expert_offsets.size(0));
  // One thread per expert in a single block, so the expert count is bounded
  // by the hardware maximum of 1024 threads per block.
  TORCH_CHECK(num_experts <= 1024, "Expert more than 1024");

  // Functor that converts base pointers + expert offsets into per-expert
  // pointer arrays for the operands, output, and scale tensors.
  Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, T> offset_functor(
      static_cast<int*>(expert_offsets.data_ptr()),
      static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
      static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
      static_cast<T*>(out_tensors.data_ptr()),
      static_cast<float*>(a_scales.data_ptr()),
      static_cast<float*>(b_scales.data_ptr()),
      static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
      static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
      static_cast<float**>(a_scales_ptrs.data_ptr()),
      static_cast<float**>(b_scales_ptrs.data_ptr()),
      static_cast<T**>(out_ptrs.data_ptr()));

  int* problem_sizes_ptr = static_cast<int*>(problem_sizes.data_ptr());
  // The filter functors bucket each expert's problem shape into the low-M,
  // middle-M, or high-M list; the perf-config set depends on the device.
  if (is_h20_device) {
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> low_filter(
        static_cast<int*>(lm_problem_sizes.data_ptr()));
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> middle_filter(
        static_cast<int*>(mm_problem_sizes.data_ptr()));
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> high_filter(
        static_cast<int*>(hm_problem_sizes.data_ptr()));
    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
        problem_sizes_ptr, offset_functor, layout_functor, low_filter, middle_filter, high_filter);
  } else {
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> low_filter(
        static_cast<int*>(lm_problem_sizes.data_ptr()));
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> middle_filter(
        static_cast<int*>(mm_problem_sizes.data_ptr()));
    Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> high_filter(
        static_cast<int*>(hm_problem_sizes.data_ptr()));
    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
        problem_sizes_ptr, offset_functor, layout_functor, low_filter, middle_filter, high_filter);
  }
}
// Launches one CUTLASS SM90 grouped (per-expert) GEMM configured by GemmTraits.
//
// All tensor arguments carry DEVICE pointers prepared by the pre-compute
// kernel: per-expert operand/output pointer arrays (`*_ptrs`), per-expert
// strides (`stride_*`), and per-expert scale-factor layouts (`layout_sf*`).
// `problem_sizes` holds one problem shape per expert; its first dimension is
// the group (expert) count. `workspace` is CUTLASS scratch space. The launch
// is asynchronous on `stream`; failures surface via TORCH_CHECK.
template <typename GemmTraits>
void launch_sm90_fp8_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs,
const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs,
const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs,
const torch::Tensor& stride_a,
const torch::Tensor& stride_b,
const torch::Tensor& stride_d,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& workspace,
cudaStream_t stream) {
// Pull operand/stride/layout types out of the traits bundle so the casts
// below match the kernel's expected argument types exactly.
using ElementA = typename GemmTraits::ElementA;
using StrideA = typename GemmTraits::StrideA;
using ElementB = typename GemmTraits::ElementB;
using StrideB = typename GemmTraits::StrideB;
using ElementAccumulator = typename GemmTraits::ElementAccumulator;
using LayoutSFA = typename GemmTraits::LayoutSFA;
using LayoutSFB = typename GemmTraits::LayoutSFB;
using ElementD = typename GemmTraits::ElementD;
using StrideD = typename GemmTraits::StrideD;
using UnderlyingProblemShape = typename GemmTraits::ProblemShape::UnderlyingProblemShape;
using Gemm = typename GemmTraits::Gemm;
using GemmKernel = typename GemmTraits::GemmKernel;
int num_experts = (int)problem_sizes.size(0);
Gemm gemm_op;
// Mainloop arguments: interleaved (operand pointer-array, stride-array)
// pairs for A and B, then (scale pointer-array, SF layout-array) pairs.
// The brace order here must match GemmKernel::MainloopArguments exactly.
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// CUTLASS uses the SM count to size persistent-kernel scheduling.
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// Epilogue: default thread params ({}), null C pointer/stride arrays —
// NOTE(review): presumably no source accumulator is read, D is written
// directly; confirm against the traits' epilogue definition.
typename GemmKernel::EpilogueArguments epilogue_args{
{}, nullptr, nullptr, static_cast<ElementD**>(out_ptrs.data_ptr()), static_cast<StrideD*>(stride_d.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Grouped mode: problem shapes live on device, host-side shapes are null.
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream, nullptr, true); // Enable PDL
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
// Dispatches the three size-bucketed grouped-GEMM launches (low-M, middle-M,
// and high-M problem lists, as filled by the pre-compute kernel) with the
// perf configuration matching the device family (H20 vs. other Hopper
// "Hx00" parts).
//
// For traits instantiated with cutlass::layout::ColumnMajor output, the A/B
// operands (pointer arrays, scale arrays, strides, SF layouts) are passed
// swapped — presumably computing the result via the transposed problem;
// confirm against ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits.
// RowMajor traits take operands unswapped.
//
// NOTE: "distpatch" in the name is a pre-existing typo, kept so existing
// callers keep linking.
template <typename OutType>
void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
    torch::Tensor& out_ptrs,
    const torch::Tensor& a_ptrs,
    const torch::Tensor& b_ptrs,
    const torch::Tensor& a_scales_ptrs,
    const torch::Tensor& b_scales_ptrs,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& lm_problem_sizes,
    const torch::Tensor& mm_problem_sizes,
    const torch::Tensor& hm_problem_sizes,
    const torch::Tensor& workspace,
    bool is_h20_device,
    cudaStream_t stream) {
  using LowMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
  using LowMGemmHx00Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMHx00>;
  using MiddleMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigMiddleMH20>;
  using MiddleMGemmHx00Traits = ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<
      OutType,
      cutlass::layout::ColumnMajor,
      PerfConfigMiddleMHx00>;
  using HighMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMH20>;
  using HighMGemmHx00Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;

  // Low-M bucket: ColumnMajor traits on both devices -> operands swapped.
  if (!is_h20_device) {
    launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        lm_problem_sizes,
        workspace,
        stream);
  } else {
    launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        lm_problem_sizes,
        workspace,
        stream);
  }

  // Middle-M bucket: Hx00 traits are ColumnMajor (swap), H20 are RowMajor
  // (no swap).
  if (!is_h20_device) {
    launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmHx00Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        mm_problem_sizes,
        workspace,
        stream);
  } else {
    // BUGFIX: this branch previously launched HighMGemmHx00Traits (a non-H20,
    // high-M config) for the H20 middle-M bucket, leaving MiddleMGemmH20Traits
    // entirely unused even though the pre-compute filter for this bucket on
    // H20 is PerfConfigMiddleMH20. Use the matching middle-M H20 traits.
    launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        mm_problem_sizes,
        workspace,
        stream);
  }

  // High-M bucket: RowMajor traits on both devices -> no operand swap.
  if (!is_h20_device) {
    launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        hm_problem_sizes,
        workspace,
        stream);
  } else {
    launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        hm_problem_sizes,
        workspace,
        stream);
  }
}
} // namespace expert_specialization