Sync from v0.13
This commit is contained in:
104
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
Normal file
104
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
Normal file
@@ -0,0 +1,104 @@
|
||||
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
// ElementB is int32 (packed int4)
|
||||
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
|
||||
template <typename ElementA, typename ElementB, typename ElementC,
|
||||
typename ElementAccumulator, typename ElementGroupScale>
|
||||
__global__ void get_group_gemm_starts(
|
||||
int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
|
||||
ElementAccumulator** b_scales_offsets,
|
||||
ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
|
||||
ElementB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scales_base_as_int,
|
||||
ElementAccumulator* b_scales_base_as_int,
|
||||
ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
|
||||
int64_t scale_k) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
int64_t expert_offset = expert_offsets[expert_id];
|
||||
|
||||
// same as w8a8
|
||||
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset;
|
||||
b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id);
|
||||
|
||||
// w4a8 specific
|
||||
constexpr int pack_factor = 8; // pack 8 int4 into int32
|
||||
b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor);
|
||||
b_group_scales_offsets[expert_id] =
|
||||
b_group_scales_base_as_int + (expert_id * scale_k * n);
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
|
||||
cutlass::Array<cutlass::float_e4m3_t, 8>> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<int32_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>( \
|
||||
b_group_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<int32_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>( \
|
||||
b_group_scales.data_ptr()), \
|
||||
n, k, scale_k); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
|
||||
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
|
||||
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_group_scales.dtype() ==
|
||||
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
|
||||
TORCH_CHECK(out_tensors.dtype() ==
|
||||
torch::kBFloat16); // only support bf16 for now
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
// logical k, n
|
||||
int64_t n = out_tensors.size(1);
|
||||
int64_t k = a_tensors.size(1);
|
||||
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
483
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
Normal file
483
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
Normal file
@@ -0,0 +1,483 @@
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
// vllm includes
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "get_group_starts.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "w4a8_utils.cuh"
|
||||
|
||||
namespace vllm::cutlass_w4a8_moe {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Static configuration shared across all instantiations
|
||||
// -------------------------------------------------------------------------------------
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<Shape<int, int, int>>; // <M,N,K> per
|
||||
// group
|
||||
using MmaType = cutlass::float_e4m3_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
static int constexpr PackFactor = 8; // 8 int4 packed into int32
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA =
|
||||
128 /
|
||||
cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of
|
||||
// elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input
|
||||
// layouts
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
|
||||
using StrideA =
|
||||
cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
|
||||
using StrideB =
|
||||
cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
|
||||
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in
|
||||
// contiguous locations in global memory. It specifies the reordering within a
|
||||
// single warp's fragment
|
||||
using LayoutAtomQuant =
|
||||
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(
|
||||
LayoutAtomQuant{}, Layout<Shape<int, int, Int<1>>, StrideB>{}));
|
||||
|
||||
using ElementScale = cutlass::float_e4m3_t;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC =
|
||||
cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC =
|
||||
cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using StageCountType =
|
||||
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
|
||||
// on the tile size
|
||||
|
||||
// per-channel and per-token scales for epilogue
|
||||
using ElementSChannel = float;
|
||||
|
||||
template <class TileShape_MN, class ClusterShape_MNK, class KernelSchedule,
|
||||
class EpilogueSchedule>
|
||||
struct W4A8GroupedGemmKernel {
|
||||
using TileShape =
|
||||
decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = ClusterShape_MNK;
|
||||
|
||||
// per-channel, per-token scales epilogue
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogueArray<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementSChannel, ElementC,
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type*, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*,
|
||||
AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
// =========================================================== MIXED INPUT
|
||||
// WITH SCALES
|
||||
// ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled.
|
||||
// In this example, B is scaled so we make a tuple of B's information and the
|
||||
// scale information.
|
||||
using CollectiveMainloopShuffled =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
|
||||
LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>;
|
||||
|
||||
using GemmShuffled =
|
||||
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelShuffled::InternalStrideC;
|
||||
using StrideD = typename GemmKernelShuffled::InternalStrideD;
|
||||
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// static asserts for passing in strides/layouts
|
||||
// pack to 2x int64
|
||||
static_assert(sizeof(StrideS) == 2 * sizeof(int64_t));
|
||||
// pack to 3xint32,
|
||||
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
|
||||
"LayoutB_Reordered size must be divisible by 4 bytes");
|
||||
|
||||
static void grouped_mm(
|
||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides) {
|
||||
auto device = a_tensors.device();
|
||||
auto device_id = device.index();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device_id);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
int n = static_cast<int>(b_tensors.size(1));
|
||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(device);
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
|
||||
// get the correct offsets to pass to gemm
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs,
|
||||
a_tensors, b_tensors, out_tensors, a_scales,
|
||||
b_scales, b_group_scales, b_group_size);
|
||||
|
||||
// construct args
|
||||
using Args = typename GemmShuffled::Arguments;
|
||||
using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
||||
Args arguments;
|
||||
|
||||
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<ProblemShape::UnderlyingProblemShape*>(
|
||||
problem_sizes_torch.data_ptr());
|
||||
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
||||
|
||||
// SwapAB so B operands come first
|
||||
MainloopArguments mainloop_arguments{
|
||||
static_cast<const QuantType**>(b_ptrs.data_ptr()),
|
||||
static_cast<LayoutB_Reordered*>(b_strides.data_ptr()),
|
||||
static_cast<const MmaType**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const cutlass::Array<ElementScale, 8>**>(
|
||||
b_group_scales_ptrs.data_ptr()),
|
||||
static_cast<StrideS*>(group_scale_strides.data_ptr()),
|
||||
static_cast<int>(b_group_size)};
|
||||
|
||||
EpilogueArguments epilogue_arguments{
|
||||
// since we are doing SwapAB the channel scales comes first, then token
|
||||
// scales
|
||||
ChTokScalesEpilogue::prepare_args( // see ScaledEpilogueArray
|
||||
static_cast<const ElementAccumulator**>(
|
||||
b_scales_ptrs.data_ptr()), // per-channel
|
||||
static_cast<const ElementAccumulator**>(
|
||||
a_scales_ptrs.data_ptr()), // per-token
|
||||
true, true),
|
||||
nullptr, // C
|
||||
static_cast<StrideC*>(c_strides.data_ptr()), // C
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()), // D
|
||||
static_cast<StrideC*>(c_strides.data_ptr()) // D
|
||||
};
|
||||
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id,
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape,
|
||||
mainloop_arguments, epilogue_arguments, hw_info};
|
||||
|
||||
// Allocate workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
|
||||
CUTLASS_CHECK(gemm.run(stream));
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel instantiations and dispatch logic
|
||||
// ----------------------------------------------------------------------------
|
||||
using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
|
||||
// Kernel_TileShape_ClusterShape_Schedule
|
||||
using Kernel_128x16_1x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>, Coop, CoopEpi>;
|
||||
using Kernel_128x16_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_128, _16>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
using Kernel_256x16_1x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>, Coop, CoopEpi>;
|
||||
using Kernel_256x16_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _16>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
using Kernel_256x32_1x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>, Coop, CoopEpi>;
|
||||
using Kernel_256x32_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _32>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
using Kernel_256x64_1x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>, Coop, CoopEpi>;
|
||||
using Kernel_256x64_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _64>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
using Kernel_256x128_1x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>, Coop, CoopEpi>;
|
||||
using Kernel_256x128_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_256, _128>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
using Kernel_128x256_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
void mm_dispatch(
|
||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides, const std::string& schedule) {
|
||||
if (schedule == "Kernel_128x16_1x1x1_Coop") {
|
||||
Kernel_128x16_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_128x16_2x1x1_Coop") {
|
||||
Kernel_128x16_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x16_1x1x1_Coop") {
|
||||
Kernel_256x16_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x16_2x1x1_Coop") {
|
||||
Kernel_256x16_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x32_1x1x1_Coop") {
|
||||
Kernel_256x32_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x32_2x1x1_Coop") {
|
||||
Kernel_256x32_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x64_1x1x1_Coop") {
|
||||
Kernel_256x64_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x64_2x1x1_Coop") {
|
||||
Kernel_256x64_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x128_1x1x1_Coop") {
|
||||
Kernel_256x128_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_256x128_2x1x1_Coop") {
|
||||
Kernel_256x128_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else if (schedule == "Kernel_128x256_2x1x1_Coop") {
|
||||
Kernel_128x256_2x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
|
||||
}
|
||||
}
|
||||
|
||||
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// user has specified a schedule
|
||||
if (maybe_schedule) {
|
||||
mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
b_group_scales, b_group_size, expert_offsets, problem_sizes,
|
||||
a_strides, b_strides, c_strides, group_scale_strides,
|
||||
*maybe_schedule);
|
||||
return;
|
||||
}
|
||||
|
||||
// use heuristic
|
||||
int m_full = a_tensors.size(0);
|
||||
int n = b_tensors.size(1);
|
||||
int k = b_tensors.size(2) * PackFactor; // logical k
|
||||
int num_experts = b_tensors.size(0);
|
||||
// per-expert batch size assuming uniform distribution
|
||||
int m_expert = m_full / num_experts;
|
||||
|
||||
std::string schedule;
|
||||
if (m_expert <= 16) {
|
||||
schedule = "Kernel_128x16_2x1x1_Coop";
|
||||
} else if (m_expert <= 32) {
|
||||
schedule = "Kernel_256x32_1x1x1_Coop";
|
||||
} else if (m_expert <= 64) {
|
||||
schedule = "Kernel_256x64_1x1x1_Coop";
|
||||
} else if (m_expert <= 128) {
|
||||
schedule = "Kernel_256x128_2x1x1_Coop";
|
||||
} else { // m_expert > 128
|
||||
schedule = "Kernel_128x256_2x1x1_Coop";
|
||||
}
|
||||
|
||||
mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
b_group_scales, b_group_size, expert_offsets, problem_sizes,
|
||||
a_strides, b_strides, c_strides, group_scale_strides, schedule);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
||||
torch::Tensor const& b_tensors) {
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
|
||||
TORCH_CHECK(b_tensors.is_contiguous());
|
||||
TORCH_CHECK(b_tensors.is_cuda());
|
||||
|
||||
int n = static_cast<int>(b_tensors.size(1));
|
||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
|
||||
|
||||
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
|
||||
// These misalignments cause silent OOB unless run under Compute Sanitizer.
|
||||
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
|
||||
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
|
||||
|
||||
// we will store the layout to an int32 tensor;
|
||||
// this is the number of elements we need per layout
|
||||
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
|
||||
|
||||
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
|
||||
int num_experts = static_cast<int>(b_tensors.size(0));
|
||||
|
||||
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
|
||||
auto b_packed_ptr = static_cast<QuantType*>(b_tensors_packed.data_ptr());
|
||||
|
||||
// multiply by ull so result does not overflow int32
|
||||
size_t num_int4_elems = 1ull * num_experts * n * k;
|
||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
|
||||
num_int4_elems);
|
||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
|
||||
// construct the layout once; assumes each expert has the same layout
|
||||
using LayoutType = LayoutB_Reordered;
|
||||
std::vector<LayoutType> layout_B_reordered_host(num_experts);
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}});
|
||||
auto shape_B = cute::make_shape(n, k, Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B);
|
||||
LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
// reorder weights for each expert
|
||||
for (int i = 0; i < num_experts; i++) {
|
||||
// since the storage type of int4b is 1 byte but one element is 4 bits
|
||||
// we need to adjust the offset
|
||||
int64_t offset =
|
||||
1ull * i * n * k * cutlass::sizeof_bits<QuantType>::value / 8;
|
||||
cutlass::reorder_tensor(b_packed_ptr + offset, layout_B,
|
||||
layout_B_reordered);
|
||||
}
|
||||
|
||||
// save the packed layout to torch tensor so we can re-use it
|
||||
auto cpu_opts =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
|
||||
torch::Tensor layout_cpu =
|
||||
torch::empty({num_experts, layout_width}, cpu_opts);
|
||||
|
||||
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
std::memcpy(layout_data + i * layout_width, // dst (int32*)
|
||||
&layout_B_reordered, // src (LayoutType*)
|
||||
sizeof(LayoutType)); // number of bytes
|
||||
}
|
||||
|
||||
torch::Tensor packed_layout =
|
||||
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
|
||||
|
||||
return {b_tensors_packed, packed_layout};
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_moe_mm", &mm);
|
||||
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8_moe
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
430
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
Normal file
430
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
Normal file
@@ -0,0 +1,430 @@
|
||||
//
|
||||
// Based off of:
|
||||
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
||||
//
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "w4a8_utils.cuh"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <limits>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace vllm::cutlass_w4a8 {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Static configuration shared across all instantiations
|
||||
// -------------------------------------------------------------------------------------
|
||||
using MmaType = cutlass::float_e4m3_t; // A/scale element type
|
||||
using QuantType = cutlass::int4b_t; // B element type (packed int4)
|
||||
|
||||
static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
static int constexpr ScalePackSize = 8; // pack 8 scale elements together
|
||||
static int constexpr PackFactor = 8; // 8 4-bit packed into int32
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
using LayoutB_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
// Define the CuTe layout for reordered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in
|
||||
// contiguous locations in global memory. It specifies the reordering within a
|
||||
// single warp's fragment
|
||||
using LayoutAtomQuant =
|
||||
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(
|
||||
LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));
|
||||
|
||||
// Group-wise scales
|
||||
using ElementScale = MmaType;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// Per-tok, per-chan scales
|
||||
using ElementSChannel = float;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC =
|
||||
cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC =
|
||||
cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
// based on the default
|
||||
// setting in the
|
||||
// Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel template — Tile/Cluster shapes
|
||||
// ----------------------------------------------------------------------------
|
||||
template <class TileShape_MN, class ClusterShape_MNK>
|
||||
struct W4A8GemmKernel {
|
||||
using TileShape =
|
||||
decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = ClusterShape_MNK;
|
||||
|
||||
// Epilogue per-tok, per-chan scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C
|
||||
// matrix. We can enable this if beta == 0 by changing ElementC to
|
||||
// void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
|
||||
AlignmentC, ElementD,
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule, // This is the only epi supporting the required
|
||||
// swap + transpose.
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
// The Scale information must get paired with the operand that will be scaled.
|
||||
// In this example, B is scaled so we make a tuple of B's information and the
|
||||
// scale information.
|
||||
using CollectiveMainloopShuffled =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
|
||||
LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloopShuffled, CollectiveEpilogue>;
|
||||
using GemmShuffled =
|
||||
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelShuffled::StrideC;
|
||||
using StrideD = typename GemmKernelShuffled::StrideD;
|
||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||
|
||||
static torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type) {
|
||||
// TODO: param validation
|
||||
int m = A.size(0);
|
||||
int k = A.size(1);
|
||||
int n = B.size(1);
|
||||
|
||||
// safely cast group_size to int
|
||||
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
||||
"group_size out of supported range for int: ", group_size);
|
||||
int const group_size_int = static_cast<int>(group_size);
|
||||
|
||||
// Allocate output
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto device = A.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
torch::Tensor D =
|
||||
torch::empty({m, n}, torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<ElementD>)
|
||||
.device(device));
|
||||
// prepare arg pointers
|
||||
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data_ptr());
|
||||
// can we avoid hardcode the 8 here
|
||||
auto S_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
|
||||
group_scales.const_data_ptr());
|
||||
|
||||
// runtime layout for B
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
// strides
|
||||
int const scale_k = cutlass::ceil_div(k, group_size_int);
|
||||
StrideA stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
// Reverse stride here due to swap and transpose
|
||||
StrideD stride_D =
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
|
||||
StrideS stride_S = cutlass::make_cute_packed_stride(
|
||||
StrideS{}, cute::make_shape(n, scale_k, 1));
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an
|
||||
// instance of Gemm auto arguments =
|
||||
// args_from_options<GemmShuffled>(options);
|
||||
/// Populates a Gemm::Arguments structure from the given arguments
|
||||
/// Swap the A and B tensors, as well as problem shapes here.
|
||||
using Args = typename GemmShuffled::Arguments;
|
||||
using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
||||
|
||||
MainloopArguments mainloop_arguments{
|
||||
B_ptr, layout_B_reordered, A_ptr, stride_A,
|
||||
S_ptr, stride_S, group_size_int};
|
||||
|
||||
EpilogueArguments epilogue_arguments{
|
||||
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
|
||||
nullptr,
|
||||
{}, // no C
|
||||
D_ptr,
|
||||
stride_D};
|
||||
|
||||
Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{n, m, k, 1}, // shape
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
|
||||
// Workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
|
||||
CUTLASS_CHECK(gemm.run(stream));
|
||||
|
||||
return D;
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel instantiations and dispatch logic
|
||||
// ----------------------------------------------------------------------------
|
||||
using Kernel_256x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x256_2x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
|
||||
using Kernel_128x256_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
||||
|
||||
torch::Tensor mm_dispatch(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
const std::string& schedule) {
|
||||
if (schedule == "256x128_1x1x1") {
|
||||
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x64_1x1x1") {
|
||||
return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x32_1x1x1") {
|
||||
return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x16_1x1x1") {
|
||||
return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_2x1x1") {
|
||||
return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_1x1x1") {
|
||||
return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x128_1x1x1") {
|
||||
return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x64_1x1x1") {
|
||||
return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x32_1x1x1") {
|
||||
return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x16_1x1x1") {
|
||||
return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// requested a specific schedule
|
||||
if (maybe_schedule) {
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, *maybe_schedule);
|
||||
}
|
||||
std::string schedule;
|
||||
int M = A.size(0);
|
||||
int K = A.size(1);
|
||||
int N = B.size(1);
|
||||
// heuristic
|
||||
if (M <= 16) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1";
|
||||
} else if (M <= 32) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1";
|
||||
} else if (M <= 64) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x64_1x1x1";
|
||||
else if (N <= 8192 && K <= 8192)
|
||||
schedule = "128x32_1x1x1";
|
||||
else
|
||||
schedule = "128x64_1x1x1";
|
||||
} else if (M <= 128) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x128_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x64_1x1x1";
|
||||
else
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 256) {
|
||||
if (N <= 4096)
|
||||
schedule = "128x64_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x128_1x1x1";
|
||||
else
|
||||
schedule = "128x256_1x1x1";
|
||||
} else if (M <= 512 && N <= 4096) {
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 1024) {
|
||||
schedule = "128x256_1x1x1";
|
||||
} else {
|
||||
schedule = "128x256_2x1x1";
|
||||
}
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, schedule);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pre-processing utils
|
||||
// ----------------------------------------------------------------------------
|
||||
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
TORCH_CHECK(scales.is_cuda());
|
||||
|
||||
auto packed_scales = torch::empty(
|
||||
{scales.numel() * ScalePackSize},
|
||||
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
|
||||
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
||||
auto packed_scales_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
||||
packed_scales.data_ptr());
|
||||
|
||||
cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel());
|
||||
|
||||
return packed_scales;
|
||||
}
|
||||
|
||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(B.dim() == 2);
|
||||
|
||||
torch::Tensor B_packed = torch::empty_like(B);
|
||||
|
||||
int k = B.size(0) * PackFactor; // logical k
|
||||
int n = B.size(1);
|
||||
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
|
||||
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
auto layout_B = make_layout(shape_B, LayoutRight{}); // row major
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
|
||||
n * k);
|
||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||
|
||||
return B_packed;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_mm", &mm);
|
||||
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
|
||||
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8
|
||||
90
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
Normal file
90
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
Normal file
@@ -0,0 +1,90 @@
|
||||
#include "w4a8_utils.cuh"
|
||||
|
||||
#include <array>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
|
||||
namespace vllm::cutlass_w4a8_utils {
|
||||
|
||||
/*
|
||||
GPU-accelerated implementation of cutlass::unified_encode_int4b.
|
||||
Constructs a lookup table in constant memory to map 8 bits
|
||||
(two 4-bit values) at a time. Assumes memory is contiguous
|
||||
and pointers are 16-byte aligned.
|
||||
*/
|
||||
__constant__ uint8_t kNibbleLUT[256];
|
||||
|
||||
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
|
||||
size_t nbytes) {
|
||||
constexpr size_t V = sizeof(uint4); // 16 bytes
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
|
||||
const size_t nvec = nbytes / V;
|
||||
|
||||
// 1-D grid-stride loop over 16-byte chunks
|
||||
for (size_t vec = tid; vec < nvec; vec += nthreads) {
|
||||
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
|
||||
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
|
||||
reinterpret_cast<uint4*>(out)[vec] = v;
|
||||
}
|
||||
}
|
||||
|
||||
static bool upload_lut() {
|
||||
std::array<uint8_t, 256> lut{};
|
||||
auto map_nib = [](uint8_t v) -> uint8_t {
|
||||
// 1..7 -> (8 - v); keep 0 and 8..15
|
||||
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
|
||||
};
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
uint8_t lo = b & 0xF;
|
||||
uint8_t hi = (b >> 4) & 0xF;
|
||||
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
|
||||
}
|
||||
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
|
||||
/*offset=*/0, cudaMemcpyHostToDevice);
|
||||
|
||||
return (e == cudaSuccess);
|
||||
}
|
||||
|
||||
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
|
||||
size_t num_int4_elems) {
|
||||
// Build/upload LUT
|
||||
if (!upload_lut()) return false;
|
||||
|
||||
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
|
||||
"int4 storage must be 1 byte");
|
||||
const size_t nbytes = num_int4_elems >> 1;
|
||||
|
||||
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
|
||||
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
|
||||
|
||||
// kernel launch params
|
||||
constexpr int block = 256;
|
||||
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
|
||||
int grid = int((nvec + block - 1) / block);
|
||||
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
|
||||
|
||||
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
|
||||
|
||||
// launch errors
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("unified_encode_int4b_device launch error: %s (%d)\n",
|
||||
cudaGetErrorString(err), err);
|
||||
return false;
|
||||
}
|
||||
|
||||
// runtime errors
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("unified_encode_int4b_device runtime error: %s (%d)\n",
|
||||
cudaGetErrorString(err), err);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8_utils
|
||||
11
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
Normal file
11
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
namespace vllm::cutlass_w4a8_utils {
|
||||
|
||||
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
|
||||
size_t num_int4_elems);
|
||||
|
||||
} // namespace vllm::cutlass_w4a8_utils
|
||||
Reference in New Issue
Block a user