adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
221
sgl-kernel/csrc/gemm/awq_kernel.cu
Normal file
221
sgl-kernel/csrc/gemm/awq_kernel.cu
Normal file
@@ -0,0 +1,221 @@
|
||||
// Adapted from
|
||||
// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/all.h>
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
// Three-input bitwise logic operation.
//
// Emits a single PTX `lop3.b32` instruction whose truth table is the
// compile-time immediate `lut`, applied bitwise to `a`, `b` and `c`, and
// returns the 32-bit result.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int out;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(out)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return out;
}
|
||||
|
||||
// Expands the eight unsigned 4-bit fields packed in `source` into eight fp16
// values, returned as four packed half2 words in a uint4.
//
// Output word k (k = 0..3) holds the nibbles at bit offsets 4*k (low half)
// and 4*k + 16 (high half) of `source`.
//
// Only implemented when compiling device code for __CUDA_ARCH__ >= 750; every
// other compilation pass gets a stub that asserts and returns a
// value-initialized uint4.
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  // lop3 truth table selecting (a & b) | c.
  static constexpr uint32_t kAndOrLut = (0xf0 & 0xcc) | 0xaa;
  // Nibble in bits [3:0] of each 16-bit half.
  static constexpr uint32_t kLowNibbleMask = 0x000f000f;
  // Nibble in bits [7:4] of each 16-bit half.
  static constexpr uint32_t kHighNibbleMask = 0x00f000f0;
  // half2 {1024, 1024}: OR-ing a nibble into its mantissa yields 1024 + n for
  // a low nibble and 1024 + 16 * n for a high nibble.
  static constexpr uint32_t kHalf2_1024 = 0x64006400;
  // half2 {1 / 16, 1 / 16}.
  static constexpr uint32_t kHalf2_OneSixteenth = 0x2c002c00;
  // half2 {-64, -64}.
  static constexpr uint32_t kHalf2_Neg64 = 0xd400d400;

  uint4 unpacked;
  uint32_t* words = reinterpret_cast<uint32_t*>(&unpacked);
  const uint32_t packed = source;

  // A single shift is enough: the upper two nibbles of every byte pair are
  // handled with the high-nibble mask plus an fma instead of another shift.
  // The shift is issued first so its result is ready when it is consumed.
  const uint32_t packed_shr8 = packed >> 8;

  // words[0] = (packed & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(words[0])
               : "r"(packed), "n"(kLowNibbleMask), "n"(kHalf2_1024), "n"(kAndOrLut));
  // words[1] = (packed & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(words[1])
               : "r"(packed), "n"(kHighNibbleMask), "n"(kHalf2_1024), "n"(kAndOrLut));
  // words[2] = (packed_shr8 & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(words[2])
               : "r"(packed_shr8), "n"(kLowNibbleMask), "n"(kHalf2_1024), "n"(kAndOrLut));
  // words[3] = (packed_shr8 & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(words[3])
               : "r"(packed_shr8), "n"(kHighNibbleMask), "n"(kHalf2_1024), "n"(kAndOrLut));

  // Remove the bias. Low-nibble words: (1024 + n) - 1024.
  // High-nibble words: (1024 + 16 * n) * (1 / 16) - 64.
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(words[0]) : "r"(words[0]), "r"(kHalf2_1024));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(words[1])
               : "r"(words[1]), "r"(kHalf2_OneSixteenth), "r"(kHalf2_Neg64));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(words[2]) : "r"(words[2]), "r"(kHalf2_1024));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(words[3])
               : "r"(words[3]), "r"(kHalf2_OneSixteenth), "r"(kHalf2_Neg64));

  return unpacked;
#else
  assert(false);
  return {};
#endif
}
|
||||
|
||||
// Expands the eight unsigned 4-bit fields packed in `source` into eight bf16
// values, returned as four packed nv_bfloat162 words in a uint4.
//
// Output word k (k = 0..3) holds the nibbles at bit offsets 4*k (low half)
// and 4*k + 16 (high half) of `source`.
//
// The real implementation is only compiled for CUDA_VERSION >= 12000 and
// __CUDA_ARCH__ >= 800. Every other configuration (older toolkit, older
// architecture, host pass) now gets the same stub: assert and return a
// value-initialized uint4. Previously a toolkit older than 12.0 compiled the
// body away entirely, so this non-void function could fall off its end
// without returning a value.
__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {
#if CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  uint4 result;
  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = source;

  // Nibble in bits [3:0] of each 16-bit half.
  static constexpr uint32_t MASK = 0x000f000f;
  // bf16x2 {128, 128}: OR-ing a nibble into the mantissa yields 128 + n.
  static constexpr uint32_t EX = 0x43004300;
  // bf16x2 {1, 1}.
  static constexpr uint32_t MUL = 0x3F803F80;
  // bf16x2 {-128, -128}: removes the bias added through EX.
  static constexpr uint32_t ADD = 0xC300C300;

  // (x & MASK) | EX for the source shifted right by 0, 4, 8 and 12 bits.
  int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX);
  int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX);
  int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX);
  int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX);

  nv_bfloat162* res = reinterpret_cast<nv_bfloat162*>(h);
  const nv_bfloat162 mul = *reinterpret_cast<const nv_bfloat162*>(&MUL);
  const nv_bfloat162 add = *reinterpret_cast<const nv_bfloat162*>(&ADD);

  // (128 + n) * 1 - 128 == n for each packed bf16 pair.
  res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0), mul, add);
  res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0), mul, add);
  res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1), mul, add);
  res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1), mul, add);

  return result;
#else
  // Unsupported toolkit/architecture or host compilation pass.
  assert(false);
  return {};
#endif
}
|
||||
|
||||
// AWQ weight dequantization kernel: output = (weight - zero) * scale.
//
// Launch layout: 2-D grid of 2-D blocks (at most 256 threads per block, see
// __launch_bounds__). Thread (col, row) handles one packed 32-bit word
// qweight[row][col] and writes the 8 dequantized values to
// output[row][8 * col .. 8 * col + 7]. Threads outside
// qweight_cols x qweight_rows exit immediately.
//
// Indexing used by this kernel (all buffers are addressed as dense row-major):
//   qweight : qweight_rows x qweight_cols packed words
//   qzeros  : one row of qweight_cols packed words per group
//   scales  : one row of 8 * qweight_cols OutputT values per group
//   output  : qweight_rows x (8 * qweight_cols) OutputT values
// where group = row / group_size. scales and output are accessed with 16-byte
// vector loads/stores, so they must be 16-byte aligned.
//
// Element offsets are computed in 64-bit arithmetic so that large tensors do
// not overflow 32-bit `int` indices.
//
// NOTE: the body is compiled only when CUDA_VERSION >= 12000; with an older
// toolkit the kernel is a no-op and `output` is left untouched.
template <typename OutputT>
__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,
    OutputT* __restrict__ scales,
    int* __restrict__ qzeros,
    OutputT* __restrict__ output,
    int group_size,
    int qweight_cols,
    int qweight_rows) {
#if CUDA_VERSION >= 12000
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (col >= qweight_cols || row >= qweight_rows) return;

  // 64-bit offsets: row * qweight_cols * 8 can exceed INT_MAX.
  const int64_t packed_cols = qweight_cols;
  const int64_t group_idx = row / group_size;
  const int64_t weight_idx = static_cast<int64_t>(row) * packed_cols + col;
  const int64_t zero_idx = group_idx * packed_cols + col;
  const int64_t scale_offset = 8 * zero_idx;
  const int64_t output_offset = 8 * weight_idx;

  // Eight scales (16 bytes) shared by both data-type paths.
  uint4 loaded_scale = *reinterpret_cast<uint4*>(scales + scale_offset);

  if constexpr (std::is_same<OutputT, half>::value) {
    // FP16 path
    static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");
    uint4 zeros = dequantize_s4_to_fp16x2(qzeros[zero_idx]);
    uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[weight_idx]);

    // (weight - zero) * scale on each packed half2 word, via PTX.
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

    *reinterpret_cast<uint4*>(output + output_offset) = weight_fp16;
  } else if constexpr (std::is_same<OutputT, __nv_bfloat16>::value) {
    // BF16 path
    static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");
    uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[weight_idx]);
    uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[zero_idx]);

    // Each uint4 is viewed as four nv_bfloat162 pairs.
    nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);
    nv_bfloat162* zero_vec = reinterpret_cast<nv_bfloat162*>(&zero_raw);
    nv_bfloat162* scale_vec = reinterpret_cast<nv_bfloat162*>(&loaded_scale);

#pragma unroll
    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * nv_bfloat162
      weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);
    }

    *reinterpret_cast<uint4*>(output + output_offset) = weight_raw;
  }
#endif
}
|
||||
|
||||
// Dequantizes AWQ 4-bit packed weights on the GPU.
//
// Arguments (shapes are what the dequantize_weights kernel indexes):
//   qweight : int32 [rows, cols]; every word packs eight 4-bit weights.
//   scales  : fp16 or bf16 [groups, cols * 8]; its dtype selects the output dtype.
//   qzeros  : int32 [groups, cols]; packed 4-bit zero points.
// `rows` must be a multiple of `groups`; group_size = rows / groups.
//
// Returns a new [rows, cols * 8] tensor with the dtype and device of `scales`.
//
// Raises (via TORCH_CHECK) on non-CUDA, mismatched-device, wrongly typed,
// non-contiguous or inconsistently shaped inputs, and on a failed kernel
// launch. The kernel addresses all buffers as dense row-major arrays, so such
// inputs would otherwise read out of bounds or silently produce wrong values.
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
  TORCH_CHECK(
      qweight.is_cuda() && scales.is_cuda() && qzeros.is_cuda(),
      "awq_dequantize: qweight, scales and qzeros must be CUDA tensors");
  TORCH_CHECK(
      scales.device() == qweight.device() && qzeros.device() == qweight.device(),
      "awq_dequantize: qweight, scales and qzeros must be on the same device");
  TORCH_CHECK(
      qweight.dim() == 2 && scales.dim() == 2 && qzeros.dim() == 2,
      "awq_dequantize: qweight, scales and qzeros must be 2D tensors");
  TORCH_CHECK(
      qweight.scalar_type() == at::ScalarType::Int && qzeros.scalar_type() == at::ScalarType::Int,
      "awq_dequantize: qweight and qzeros must be int32 tensors");
  const bool is_half = scales.scalar_type() == at::ScalarType::Half;
  TORCH_CHECK(
      is_half || scales.scalar_type() == at::ScalarType::BFloat16,
      "awq_dequantize: scales must be a float16 or bfloat16 tensor");
  TORCH_CHECK(
      qweight.is_contiguous() && scales.is_contiguous() && qzeros.is_contiguous(),
      "awq_dequantize: qweight, scales and qzeros must be contiguous");

  int qweight_rows = static_cast<int>(qweight.size(0));
  int qweight_cols = static_cast<int>(qweight.size(1));
  const int64_t num_groups = scales.size(0);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

  auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (output.numel() == 0) {
    return output;
  }

  TORCH_CHECK(
      num_groups > 0 && qweight_rows % num_groups == 0,
      "awq_dequantize: qweight rows (",
      qweight_rows,
      ") must be a positive multiple of the number of scale groups (",
      num_groups,
      ")");
  TORCH_CHECK(
      scales.size(1) == static_cast<int64_t>(qweight_cols) * 8,
      "awq_dequantize: scales must have 8 * qweight.size(1) columns");
  TORCH_CHECK(
      qzeros.size(0) == num_groups && qzeros.size(1) == qweight_cols,
      "awq_dequantize: qzeros must have shape [scales.size(0), qweight.size(1)]");
  int group_size = static_cast<int>(qweight_rows / num_groups);

  // One thread per packed word, 16 x 16 threads per block, ceil-div grid.
  int x_num_threads = 16;
  int y_num_threads = 16;
  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_num_threads, y_num_threads);

  int* _qweight = qweight.data_ptr<int>();
  int* _zeros = qzeros.data_ptr<int>();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (is_half) {
    auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
    auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
  } else {
    auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
    auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
  }

  // Kernel launches do not report errors directly; surface bad configurations here.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(
      launch_status == cudaSuccess,
      "awq_dequantize: kernel launch failed: ",
      cudaGetErrorString(launch_status));

  return output;
}
|
||||
76
sgl-kernel/csrc/gemm/bmm_fp8.cu
Normal file
76
sgl-kernel/csrc/gemm/bmm_fp8.cu
Normal file
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2024 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <driver_types.h>
|
||||
|
||||
#include <flashinfer/gemm/bmm_fp8.cuh>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
// Batched FP8 matrix multiply D = A x B through flashinfer's cuBLASLt wrapper.
//
// Arguments:
//   A, B, D          : 3-D CUDA tensors of shapes [b, m, k], [b, k, n] and
//                      [b, m, n]. A and B are dispatched over the FP8 dtypes
//                      and D over the 16-bit float dtypes accepted by the
//                      DISPATCH_* macros.
//   A_scale, B_scale : float32 CUDA tensors; their data pointers are handed to
//                      the backend as `float*`.
//   workspace_buffer : CUDA scratch buffer; its data pointer and numel() are
//                      forwarded to the backend.
//   cublas_handle    : a cublasLtHandle_t passed as an integer.
//   cuda_stream      : a cudaStream_t passed as an integer.
//
// Raises (via TORCH_CHECK) on invalid inputs or when the backend does not
// return CUBLAS_STATUS_SUCCESS.
void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
  TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
  TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
  TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
  TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
  TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape");

  // The scale buffers are reinterpreted as float* below, so any other dtype
  // would be read as garbage; they and the workspace are consumed on the GPU.
  TORCH_CHECK(A_scale.is_cuda(), "A_scale must be a CUDA tensor");
  TORCH_CHECK(B_scale.is_cuda(), "B_scale must be a CUDA tensor");
  TORCH_CHECK(A_scale.scalar_type() == at::ScalarType::Float, "A_scale must be a float32 tensor");
  TORCH_CHECK(B_scale.scalar_type() == at::ScalarType::Float, "B_scale must be a float32 tensor");
  TORCH_CHECK(workspace_buffer.is_cuda(), "workspace_buffer must be a CUDA tensor");

  // PyTorch is row major by default. cuBLASLt is column major by default.
  // We need row major D as expected.
  // A ^ T * B = D, so D ^ T = B ^ T * A
  // Hence B is passed in the first operand slot, A in the second, and the
  // n / m extents are swapped accordingly.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
        auto batch_size = A.size(0);
        auto m = A.size(1);
        auto k = A.size(2);
        auto n = B.size(2);

        auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
        auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

        auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
            workspace_buffer.data_ptr(),
            workspace_buffer.numel(),
            static_cast<b_type*>(B.data_ptr()),
            static_cast<a_type*>(A.data_ptr()),
            static_cast<d_type*>(D.data_ptr()),
            batch_size,
            n,
            m,
            k,
            static_cast<float*>(B_scale.data_ptr()),
            static_cast<float*>(A_scale.data_ptr()),
            lt_handle,
            stream);
        TORCH_CHECK(
            status == CUBLAS_STATUS_SUCCESS, "bmm_fp8_internal_cublaslt failed: ", cublasGetStatusString(status));
        return true;
      });
    });
  });
}
|
||||
673
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu
Normal file
673
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu
Normal file
@@ -0,0 +1,673 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/619709fc33bd5dc268f19d6a741fe7ed51c0f8f5/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3FusedAGemm.cu
|
||||
*
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using bf16_t = __nv_bfloat16;
|
||||
|
||||
// Warp-level tensor-core multiply-accumulate: d = a * b + c, issued as one
// PTX `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32` instruction.
//
//   d_reg : this thread's four fp32 result values (output).
//   a_reg : this thread's eight bf16 A-fragment values, consumed as four
//           32-bit registers.
//   b_reg : this thread's four bf16 B-fragment values, consumed as two
//           32-bit registers.
//   c_reg : this thread's four fp32 accumulator inputs. May alias d_reg for
//           in-place accumulation.
//
// The accumulator operands are now read from c_reg. They used to be read from
// d_reg, which silently ignored c_reg and was only correct when the caller
// passed the same array for both.
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void hmma_16_8_16_f32acc_bf16ab(
    float (&d_reg)[4], const bf16_t (&a_reg)[8], const bf16_t (&b_reg)[4], float const (&c_reg)[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // Pack consecutive bf16 pairs into the 32-bit registers the instruction expects.
  uint32_t a0 = *reinterpret_cast<uint32_t const*>(a_reg + 0);
  uint32_t a1 = *reinterpret_cast<uint32_t const*>(a_reg + 2);
  uint32_t a2 = *reinterpret_cast<uint32_t const*>(a_reg + 4);
  uint32_t a3 = *reinterpret_cast<uint32_t const*>(a_reg + 6);
  uint32_t b0 = *reinterpret_cast<uint32_t const*>(b_reg + 0);
  uint32_t b1 = *reinterpret_cast<uint32_t const*>(b_reg + 2);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7},"
      "{%8, %9},"
      "{%10, %11, %12, %13};\n"
      : "=f"(d_reg[0]), "=f"(d_reg[1]), "=f"(d_reg[2]), "=f"(d_reg[3])
      : "r"(a0),
        "r"(a1),
        "r"(a2),
        "r"(a3),
        "r"(b0),
        "r"(b1),
        "f"(c_reg[0]),
        "f"(c_reg[1]),
        "f"(c_reg[2]),
        "f"(c_reg[3]));
#endif
}
|
||||
|
||||
// Declaration (C linkage) of the helper used throughout this file to turn a
// generic pointer into the 32-bit shared-memory address that the inline PTX
// operands expect. No definition is provided here.
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void*);
|
||||
|
||||
// Issues one asynchronous 16-byte global -> shared copy (PTX
// `cp.async.cg.shared.global.L2::128B`) from gPtr to sPtr.
//
// Nothing is issued when `pred` is zero. The function returns as soon as the
// copy has been issued; it does not wait for completion. Compiles to a no-op
// unless building device code for __CUDA_ARCH__ >= 900.
__device__ void ldgsts_128(void const* gPtr, void* sPtr, uint32_t pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (!pred) {
    return;
  }
  const uint32_t smem_addr = __nvvm_get_smem_pointer(sPtr);
  asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n"
               :
               : "r"(smem_addr), "l"(gPtr), "n"(16));
#endif
}
|
||||
|
||||
// Loads four 8x8 matrices of 16-bit elements from shared memory with PTX
// `ldmatrix.sync.aligned.x4.m8n8.shared.b16`.
//
//   smem_ptr : shared-memory address supplied by the calling thread.
//   reg_ptr  : receives this thread's four 32-bit result registers
//              (reg_ptr[0..3]).
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void ldsm_x4(void* smem_ptr, uint32_t* reg_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t smem_addr = __nvvm_get_smem_pointer(smem_ptr);
  asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(reg_ptr[0]), "=r"(reg_ptr[1]), "=r"(reg_ptr[2]), "=r"(reg_ptr[3])
               : "r"(smem_addr));
#endif
}
|
||||
|
||||
// Returns the swizzled column index for element (row_idx_, col_idx_).
//
// The column is XOR-ed with (row mod 8) * (16 / sizeof(Type)), i.e. with the
// row's position inside a group of eight rows, scaled to the number of `Type`
// elements that fit in 16 bytes. Both indices are treated as unsigned 32-bit
// values during the computation.
template <class Type>
__device__ int apply_swizzle_343_on_elem_row_col(int row_idx_, int col_idx_) {
  const uint32_t row_in_group = static_cast<uint32_t>(row_idx_) % 8;
  const uint32_t xor_mask = row_in_group * (16 / sizeof(Type));
  const uint32_t swizzled_col = static_cast<uint32_t>(col_idx_) ^ xor_mask;
  return static_cast<int>(swizzled_col);
}
|
||||
|
||||
// Initializes a 64-bit mbarrier object living in shared memory
// (PTX `mbarrier.init.shared::cta.b64`).
//
//   smem_barrier : the user-managed barrier word.
//   thread_count : value passed to mbarrier.init as the expected arrival
//                  count (defaults to 1).
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void initialize_barrier(uint64_t* smem_barrier, int thread_count = 1) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;\n"
               :
               : "r"(barrier_addr), "r"(thread_count));
#endif
}
|
||||
|
||||
// Barrier wait
|
||||
// Blocks the calling thread until the mbarrier's phase with parity
// `phase_bit` has completed.
//
// Implemented as a PTX loop that retries
// `mbarrier.try_wait.parity.shared::cta.b64` until its predicate is set.
//
//   smem_barrier : 64-bit user-managed barrier in shared memory.
//   phase_bit    : parity operand handed to try_wait.
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void wait_barrier(uint64_t* smem_barrier, int phase_bit) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra DONE;\n"
      "bra LAB_WAIT;\n"
      "DONE:\n"
      "}\n"
      :
      : "r"(barrier_addr), "r"(phase_bit));
#endif
}
|
||||
|
||||
// Single, non-looping attempt to wait on an mbarrier
// (PTX `mbarrier.try_wait.parity.shared::cta.b64`).
//
//   smem_ptr  : 64-bit user-managed barrier in shared memory.
//   phase_bit : parity operand handed to try_wait.
//
// Returns true when the try_wait predicate was set, false otherwise. Always
// returns false unless building device code for __CUDA_ARCH__ >= 900.
__device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t barrier_addr = __nvvm_get_smem_pointer(smem_ptr);
  uint32_t completed;
  asm volatile(
      "{\n\t"
      ".reg .pred P1; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P1; \n\t"
      "}"
      : "=r"(completed)
      : "r"(barrier_addr), "r"(phase_bit));
  return completed != 0;
#else
  return false;
#endif
}
|
||||
|
||||
// Barrier arrive
|
||||
// Signals one arrival on a 64-bit user-managed mbarrier in shared memory
// (PTX `mbarrier.arrive.shared::cta.b64`). The state token produced by the
// instruction is discarded.
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void arrive_barrier(uint64_t* smem_barrier) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile(
      "{\n"
      ".reg .b64 state; \n"
      "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
      "}\n"
      :
      : "r"(barrier_addr));
#endif
}
|
||||
|
||||
// Ties the calling thread's previously issued cp.async copies to an mbarrier
// in shared memory by issuing PTX
// `cp.async.mbarrier.arrive.noinc.shared.b64` on it.
//
// Compiles to a no-op unless building device code for __CUDA_ARCH__ >= 900.
__device__ void ldgsts_arrive(uint64_t* smem_barrier) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  const uint32_t barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];"
               :
               : "r"(barrier_addr));
#endif
}
|
||||
|
||||
// Copies tiles of matrix A from global memory into a stage_cnt-deep ring of
// shared-memory buffers using asynchronous 16-byte copies.
//
// Template parameters:
//   gemm_k    : row stride of A in elements; gemm_k / tile_k tiles are loaded.
//   tile_m    : rows per shared-memory tile.
//   tile_k    : elements per tile row.
//   stage_cnt : number of shared-memory stages in the ring.
//
// Thread mapping: thread_cnt (64) threads cooperate, identified by
// threadIdx.x % thread_cnt. Per tile each thread issues a_inst_cnt_per_iter
// copies of vec_elems (8) bf16 elements.
//
// Barrier layout: every stage owns two consecutive 64-bit barriers. Before a
// stage is written, barrier [2 * stage + 1] is waited on; after the copies are
// issued, barrier [2 * stage] receives the cp.async arrival.
//
// All device-side work is compiled only for __CUDA_ARCH__ >= 900.
template <int gemm_k, int tile_m, int tile_k, int stage_cnt>
struct GmemLoaderA {
  static constexpr int elem_bytes = 2;
  static constexpr int vec_bytes = 16;
  static constexpr int vec_elems = vec_bytes / elem_bytes;
  static constexpr int thread_cnt = 64;
  static_assert((tile_m * tile_k) % (vec_elems * thread_cnt) == 0);
  static constexpr int a_inst_cnt_per_iter = (tile_m * tile_k) / (vec_elems * thread_cnt);
  static_assert(gemm_k % tile_k == 0);
  static constexpr int k_iter_cnt = gemm_k / tile_k;

  // Extra params to keep the order of k reduction: a tile's k range is split
  // into mma_warp_cnt slices, and slice s is fetched from the s-th contiguous
  // k_each_chunk-sized chunk of the full K dimension.
  static constexpr int mma_warp_cnt = 4;
  static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
  static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;

 private:
  // Maps a k index inside the tile to its k offset in global memory.
  __device__ int k_project(int tile_k_idx) {
    const int slice = tile_k_idx / per_mma_warp_k;
    const int within_slice = tile_k_idx % per_mma_warp_k;
    return slice * k_each_chunk + within_slice;
  }

 public:
  __device__ GmemLoaderA(bf16_t const* gmem_a_local_, bf16_t* smem_a_, uint64_t* smem_barrier_)
      : gmem_a(gmem_a_local_), smem_a(smem_a_), smem_barrier(smem_barrier_), local_tid(threadIdx.x % thread_cnt) {}

  // Precomputes, for each copy this thread issues per tile, the swizzled
  // element offset of its destination inside one shared-memory stage.
  __device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll
    for (int inst = 0; inst < a_inst_cnt_per_iter; inst++) {
      const int flat = local_tid * vec_elems + inst * thread_cnt * vec_elems;
      const int m_idx = flat / tile_k;
      const int swizzled_k = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, flat % tile_k);
      a_smem_offsets[inst] = m_idx * tile_k + swizzled_k;
    }
#endif
  }

  // Streams all k_iter_cnt tiles through the shared-memory ring.
  __device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll 1
    for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
      uint64_t* const stage_barriers = smem_barrier + stage_idx * 2;

      // Block only if the earlier non-blocking probe did not already succeed.
      if (need_wait) {
        wait_barrier(stage_barriers + 1, phase_bit);
      }

      // Advance around the ring; the parity flips on every wrap.
      int next_stage_idx = stage_idx + 1;
      int next_phase_bit = phase_bit;
      if (next_stage_idx == stage_cnt) {
        next_stage_idx = 0;
        next_phase_bit = phase_bit ^ 1;
      }

      // Probe the next stage's barrier now so the next iteration may skip its wait.
      const bool is_last_iter = (loop_idx == k_iter_cnt - 1);
      if (!is_last_iter) {
        need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2, next_phase_bit);
      }

      bf16_t* const stage_base = smem_a + stage_idx * tile_m * tile_k;
#pragma unroll
      for (int inst = 0; inst < a_inst_cnt_per_iter; inst++) {
        const int flat = local_tid * vec_elems + inst * thread_cnt * vec_elems;
        const int m_idx = flat / tile_k;
        const int k_idx = flat % tile_k;
        bf16_t const* src = gmem_a + (m_idx * gemm_k + k_project(k_idx));
        bf16_t* dst = stage_base + a_smem_offsets[inst];
        ldgsts_128(src, dst, true);
      }
      ldgsts_arrive(stage_barriers);

      stage_idx = next_stage_idx;
      phase_bit = next_phase_bit;
      // Each tile consumes per_mma_warp_k elements from every K chunk.
      gmem_a += per_mma_warp_k;
    }
#endif
  }

  bf16_t const* gmem_a;    // current global-memory base; advanced per tile
  bf16_t* smem_a;          // base of the shared-memory stage ring
  uint64_t* smem_barrier;  // base of the per-stage barrier pairs
  int local_tid;           // threadIdx.x % thread_cnt
  int stage_idx = 0;       // stage written by the next iteration
  int phase_bit = 1;       // parity used when waiting on stage_idx
  bool need_wait = true;   // whether the next iteration must block in wait_barrier

  // per smem_stage, store with swizzle information
  int a_smem_offsets[a_inst_cnt_per_iter];
};
|
||||
|
||||
// Copies tiles of matrix B from global memory into a stage_cnt-deep ring of
// shared-memory buffers using asynchronous 16-byte copies.
//
// Template parameters:
//   gemm_k    : row stride of B in elements; gemm_k / tile_k tiles are loaded.
//   tile_n    : rows per shared-memory tile.
//   tile_k    : elements per tile row.
//   stage_cnt : number of shared-memory stages in the ring.
//
// Thread mapping: thread_cnt (64) threads cooperate, identified by
// threadIdx.x % thread_cnt. Per tile each thread issues b_inst_cnt_per_iter
// copies of vec_elems (8) bf16 elements. Copies whose tile row is >= gemm_n
// are predicated off, so those shared-memory rows are never written.
//
// Barrier layout: every stage owns two consecutive 64-bit barriers. Before a
// stage is written, barrier [2 * stage + 1] is waited on; after the copies are
// issued, barrier [2 * stage] receives the cp.async arrival.
//
// All device-side work is compiled only for __CUDA_ARCH__ >= 900.
template <int gemm_k, int tile_n, int tile_k, int stage_cnt>
struct GmemLoaderB {
  static constexpr int elem_bytes = 2;
  static constexpr int vec_bytes = 16;
  static constexpr int vec_elems = vec_bytes / elem_bytes;
  static constexpr int thread_cnt = 64;
  static_assert((tile_n * tile_k) % (vec_elems * thread_cnt) == 0);
  static constexpr int b_inst_cnt_per_iter = (tile_n * tile_k) / (vec_elems * thread_cnt);
  static_assert(gemm_k % tile_k == 0);
  static constexpr int k_iter_cnt = gemm_k / tile_k;

  // Extra params to keep the order of k reduction: a tile's k range is split
  // into mma_warp_cnt slices, and slice s is fetched from the s-th contiguous
  // k_each_chunk-sized chunk of the full K dimension.
  static constexpr int mma_warp_cnt = 4;
  static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
  static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;

 private:
  // Maps a k index inside the tile to its k offset in global memory.
  __device__ int k_project(int tile_k_idx) {
    const int slice = tile_k_idx / per_mma_warp_k;
    const int within_slice = tile_k_idx % per_mma_warp_k;
    return slice * k_each_chunk + within_slice;
  }

 public:
  __device__ GmemLoaderB(bf16_t const* gmem_b_local_, bf16_t* smem_b_, uint64_t* smem_barrier_, int gemm_n_)
      : gmem_b(gmem_b_local_),
        smem_b(smem_b_),
        smem_barrier(smem_barrier_),
        gemm_n(gemm_n_),
        local_tid(threadIdx.x % thread_cnt) {}

  // Precomputes, for each copy this thread issues per tile, the swizzled
  // destination offset inside one stage and whether its row is within gemm_n.
  __device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll
    for (int inst = 0; inst < b_inst_cnt_per_iter; inst++) {
      const int flat = local_tid * vec_elems + inst * thread_cnt * vec_elems;
      const int n_idx = flat / tile_k;
      const int swizzled_k = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, flat % tile_k);
      b_smem_offsets[inst] = n_idx * tile_k + swizzled_k;
      preds[inst] = n_idx < gemm_n;
    }
#endif
  }

  // Streams all k_iter_cnt tiles through the shared-memory ring.
  __device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    // Issued once, before any global-memory read of B.
    asm volatile("griddepcontrol.wait;");
#pragma unroll 1
    for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
      uint64_t* const stage_barriers = smem_barrier + stage_idx * 2;

      // Block only if the earlier non-blocking probe did not already succeed.
      if (need_wait) {
        wait_barrier(stage_barriers + 1, phase_bit);
      }

      // Advance around the ring; the parity flips on every wrap.
      int next_stage_idx = stage_idx + 1;
      int next_phase_bit = phase_bit;
      if (next_stage_idx == stage_cnt) {
        next_stage_idx = 0;
        next_phase_bit = phase_bit ^ 1;
      }

      // Probe the next stage's barrier now so the next iteration may skip its wait.
      const bool is_last_iter = (loop_idx == k_iter_cnt - 1);
      if (!is_last_iter) {
        need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2, next_phase_bit);
      }

      bf16_t* const stage_base = smem_b + stage_idx * tile_n * tile_k;
#pragma unroll
      for (int inst = 0; inst < b_inst_cnt_per_iter; inst++) {
        const int flat = local_tid * vec_elems + inst * thread_cnt * vec_elems;
        const int n_idx = flat / tile_k;
        const int k_idx = flat % tile_k;
        bf16_t const* src = gmem_b + (n_idx * gemm_k + k_project(k_idx));
        bf16_t* dst = stage_base + b_smem_offsets[inst];
        ldgsts_128(src, dst, preds[inst]);
      }
      ldgsts_arrive(stage_barriers);

      stage_idx = next_stage_idx;
      phase_bit = next_phase_bit;
      // Each tile consumes per_mma_warp_k elements from every K chunk.
      gmem_b += per_mma_warp_k;
    }
#endif
  }

  bf16_t const* gmem_b;    // current global-memory base; advanced per tile
  bf16_t* smem_b;          // base of the shared-memory stage ring
  uint64_t* smem_barrier;  // base of the per-stage barrier pairs
  int gemm_n;              // number of valid rows; rows >= gemm_n are not loaded
  int local_tid;           // threadIdx.x % thread_cnt
  int stage_idx = 0;       // stage written by the next iteration
  int phase_bit = 1;       // parity used when waiting on stage_idx
  bool need_wait = true;   // whether the next iteration must block in wait_barrier

  // per smem_stage, store with swizzle information
  int b_smem_offsets[b_inst_cnt_per_iter];
  uint32_t preds[b_inst_cnt_per_iter];
};
|
||||
|
||||
template <int gemm_m, int gemm_k, int tile_m, int tile_n, int tile_k, int stage_cnt>
|
||||
struct MmaComputer {
|
||||
static constexpr int elem_bytes = 2;
|
||||
static constexpr int thread_cnt = 128;
|
||||
static_assert(gemm_k % tile_k == 0);
|
||||
static_assert(tile_k % (thread_cnt / 32) == 0);
|
||||
static constexpr int per_warp_tile_k = tile_k / (thread_cnt / 32);
|
||||
static constexpr int k_iter_cnt = gemm_k / tile_k;
|
||||
static constexpr int k_phase_cnt = per_warp_tile_k / 16;
|
||||
static constexpr int m_iter_cnt = (tile_m + 15) / 16;
|
||||
static constexpr int n_iter_cnt = (tile_n + 7) / 8; // Possible to have non-1 n_iter_cnt for ab_swap m16 case.
|
||||
static_assert(m_iter_cnt == 1);
|
||||
static_assert(n_iter_cnt == 1 || n_iter_cnt == 2);
|
||||
|
||||
__device__ MmaComputer(
|
||||
bf16_t* gmem_c_local_, bf16_t* smem_a_, bf16_t* smem_b_, uint64_t* smem_barrier_, int warp_idx_, int gemm_n_)
|
||||
: gmem_c(gmem_c_local_),
|
||||
smem_a(smem_a_),
|
||||
smem_b(smem_b_),
|
||||
smem_barrier(smem_barrier_),
|
||||
warp_idx(warp_idx_ - (thread_cnt / 32)),
|
||||
gemm_n(gemm_n_) {}
|
||||
|
||||
private:
|
||||
__device__ constexpr int internal_b_atom_func(int tid) {
|
||||
if constexpr (tile_n < 8) {
|
||||
return (tid % tile_n) + ((tid % 8) / tile_n * 0) + tid / 8 * 8 * tile_n;
|
||||
} else {
|
||||
return (tid % 8) + ((tid % 32) / 8 * (tile_n * 8));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ void prepare() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
|
||||
#pragma unroll
|
||||
for (int i = 0; i < k_phase_cnt; i++) {
|
||||
int linear_idx = (lane_idx % 16) + (lane_idx / 16) * 128 + i * 256;
|
||||
int m_idx = linear_idx % tile_m;
|
||||
int k_idx = linear_idx / tile_m + warp_k_offset_in_tile_k;
|
||||
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, k_idx);
|
||||
a_smem_offsets[0][i] = m_idx * tile_k + k_idx;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < k_phase_cnt; i += 2) { // Special i+=2 for B.
|
||||
int linear_idx = internal_b_atom_func(lane_idx) + i * tile_n * 16 + n_iter_idx * 8;
|
||||
int n_idx = linear_idx % tile_n;
|
||||
int k_idx = linear_idx / tile_n + warp_k_offset_in_tile_k;
|
||||
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, k_idx);
|
||||
b_smem_offsets[n_iter_idx][i] = n_idx * tile_k + k_idx;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void issue_mainloop() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
|
||||
#pragma unroll 1
|
||||
for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
|
||||
wait_barrier(smem_barrier + 0 + stage_idx * 2, phase_bit);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < k_phase_cnt; i++) {
|
||||
int smem_offset = a_smem_offsets[0][i];
|
||||
bf16_t* smem_ptr_this_iter = smem_a + stage_idx * tile_m * tile_k + smem_offset;
|
||||
ldsm_x4(smem_ptr_this_iter, reinterpret_cast<uint32_t*>(a_reg[0][i]));
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < k_phase_cnt; i += 2) {
|
||||
int smem_offset = b_smem_offsets[n_iter_idx][i];
|
||||
bf16_t* smem_ptr_this_iter = smem_b + stage_idx * tile_n * tile_k + smem_offset;
|
||||
ldsm_x4(smem_ptr_this_iter, reinterpret_cast<uint32_t*>(b_reg[n_iter_idx][i]));
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k_iter_idx = 0; k_iter_idx < k_phase_cnt; k_iter_idx++) {
|
||||
#pragma unroll
|
||||
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
|
||||
hmma_16_8_16_f32acc_bf16ab(
|
||||
acc_reg[0][n_iter_idx], a_reg[0][k_iter_idx], b_reg[n_iter_idx][k_iter_idx], acc_reg[0][n_iter_idx]);
|
||||
}
|
||||
}
|
||||
::arrive_barrier(smem_barrier + 1 + stage_idx * 2);
|
||||
stage_idx += 1;
|
||||
phase_bit = stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
|
||||
stage_idx = stage_idx == stage_cnt ? 0 : stage_idx;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void epi() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // Epilogue: every warp stages its partial accumulators in shared memory, then
  // the warp whose warp_idx is 0 sums the four per-warp slices and writes the
  // result to gmem_c. Named barrier 1 (thread_cnt participants) fences the
  // shared-memory reuse before and after the staging step.
  asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));

  constexpr int rows_per_thread = 2;
  constexpr int cols_per_thread = 2 * n_iter_cnt;
  constexpr int cta_mma_n = n_iter_cnt * 8;
  // Number of floats in one warp's staging slice (32-float groups padded by 2).
  constexpr int slice_floats = ((tile_m * cta_mma_n) / 32) * (32 + 2);

  // Re-pack acc_reg[0][n_iter][4] into a [row][col] view of this thread's values.
  float frag[rows_per_thread][cols_per_thread];
  for (int r = 0; r < rows_per_thread; r++) {
    for (int c = 0; c < cols_per_thread; c++) {
      frag[r][c] = acc_reg[0][c / 2][(c % 2) + (r * 2)];
    }
  }

  // The staging buffer aliases the A-tile region of shared memory; it holds
  // 4 x slice_floats floats (one slice per warp).
  float* smem_c = reinterpret_cast<float*>(smem_a);

  // Maps an (m, n) coordinate to its offset inside one padded slice.
  auto smem_c_offset = [&](int m_idx, int n_idx) {
    int const rows_per_group = 32 / cta_mma_n;
    int const pad_per_group = 2;
    int const within_group = m_idx % rows_per_group * cta_mma_n;
    int const group_base = m_idx / rows_per_group * (32 + pad_per_group);
    return within_group + group_base + n_idx;
  };

  // This should be optimized to STS.64 but can not be STS.128 due to the bank index.
#pragma unroll
  for (int r = 0; r < rows_per_thread; r++) {
    int const m_idx = (lane_idx / 4) + r * 8;
#pragma unroll
    for (int c = 0; c < cols_per_thread; c++) {
      int const n_idx = ((lane_idx % 4) * 2) + (c % 2) + (c / 2) * 8;
      smem_c[slice_floats * warp_idx + smem_c_offset(m_idx, n_idx)] = frag[r][c];
    }
  }
  asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));

  if (warp_idx == 0) {
    constexpr int outputs_per_lane = (tile_m * tile_n + 31) / 32;
    float acc_final[outputs_per_lane]{};

    // Sum the four per-warp slices; consecutive lanes walk the m dimension.
#pragma unroll
    for (int reg = 0; reg < outputs_per_lane; reg++) {
      int const linear = reg * 32 + lane_idx;
      int const base = smem_c_offset(linear % tile_m, linear / tile_m);
      acc_final[reg] += smem_c[base + 0 * slice_floats] + smem_c[base + 1 * slice_floats] +
                        smem_c[base + 2 * slice_floats] + smem_c[base + 3 * slice_floats];
    }

    // Store with m fastest (stride gemm_m between n indices); columns at or
    // beyond gemm_n are skipped.
#pragma unroll
    for (int reg = 0; reg < outputs_per_lane; reg++) {
      int const linear = reg * 32 + lane_idx;
      int const m_idx = linear % tile_m;
      int const n_idx = linear / tile_m;
      if (m_idx < tile_m && n_idx < gemm_n) {
        gmem_c[n_idx * gemm_m + m_idx] = acc_final[reg];
      }
    }
  }
#endif
}
|
||||
|
||||
bf16_t* gmem_c;
|
||||
bf16_t* smem_a;
|
||||
bf16_t* smem_b;
|
||||
uint64_t* smem_barrier;
|
||||
int warp_idx;
|
||||
int gemm_n;
|
||||
int stage_idx = 0;
|
||||
int phase_bit = 0;
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
int warp_k_offset_in_tile_k = warp_idx * per_warp_tile_k;
|
||||
|
||||
int a_smem_offsets[m_iter_cnt][k_phase_cnt];
|
||||
int b_smem_offsets[n_iter_cnt][k_phase_cnt];
|
||||
|
||||
bf16_t a_reg[m_iter_cnt][k_phase_cnt][8];
|
||||
bf16_t b_reg[n_iter_cnt][k_phase_cnt][4];
|
||||
float acc_reg[m_iter_cnt][n_iter_cnt][4]{};
|
||||
};
|
||||
|
||||
// AB swapped, kernel is k-major, k-major, m-major
//
// Warp-specialised GEMM kernel (body compiled only for SM90+).
// Thread layout: 256 threads = 8 warps. Warps 0-1 run GmemLoaderA, warps 2-3
// run GmemLoaderB, warps 4-7 run MmaComputer (main loop followed by epilogue).
// Grid: blockIdx.x selects the tile along m, blockIdx.y the tile along n.
// Dynamic shared memory: a barrier region (two 64-bit barriers per stage,
// padded to a 1 KiB boundary) followed by stage_cnt A tiles and stage_cnt B tiles.
// `batch_size` is not referenced in the kernel body.
template <int batch_size, int gemm_m, int gemm_k, int tile_m, int tile_n, int tile_k, int stage_cnt>
__global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
    bf16_t* output, bf16_t const* mat_a, bf16_t const* mat_b, int gemm_n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  constexpr int load_thread_cnt = 128;
  constexpr int compute_thread_cnt = 128;
  constexpr int thread_cnt = load_thread_cnt + compute_thread_cnt;
  (void)thread_cnt;

  constexpr int g2s_vec_bytes = 16;
  constexpr int a_elem_bytes = 2;
  constexpr int b_elem_bytes = 2;
  // constexpr int c_elem_bytes = 2;

  // Shape preconditions.
  static_assert(gemm_m % 16 == 0);
  static_assert(gemm_k % tile_k == 0);
  static_assert(gemm_m % tile_m == 0);
  static_assert(
      tile_k == 128 || tile_k == 256 || tile_k == 512 ||
      tile_k == 1024);  // tile_k must be larger than 64 since 4 warp splitK.
  static_assert(tile_m == 16);

  // Shared-memory budget and per-thread vectorised copy divisibility.
  static_assert((tile_m * a_elem_bytes + tile_n * b_elem_bytes) * tile_k * stage_cnt <= 225 * 1024);
  static_assert((tile_m * tile_k * a_elem_bytes) % (load_thread_cnt * g2s_vec_bytes) == 0);
  static_assert((tile_n * tile_k * b_elem_bytes) % (load_thread_cnt * g2s_vec_bytes) == 0);

  extern __shared__ char smem[];
  // Barrier layout: producer,consumer; producer,consumer; ...
  uint64_t* smem_barrier = reinterpret_cast<uint64_t*>(smem);
  constexpr int barrier_region_bytes = (stage_cnt * 8 * 2 + 1024) / 1024 * 1024;
  bf16_t* smem_a = reinterpret_cast<bf16_t*>(smem + barrier_region_bytes);
  bf16_t* smem_b = smem_a + tile_m * tile_k * stage_cnt;

  // Origin of this CTA's tile in each operand.
  int const m_origin = tile_m * blockIdx.x;
  int const n_origin = tile_n * blockIdx.y;
  bf16_t const* gmem_a_local = mat_a + m_origin * gemm_k;
  bf16_t const* gmem_b_local = mat_b + n_origin * gemm_k;
  bf16_t* gmem_c_local = output + n_origin * gemm_m + m_origin;

  // Broadcast lane 0's value so the warp index is uniform across the warp.
  int const warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

  // Warp 4 initialises every stage's barrier pair before anyone uses them.
  if (warp_idx == 4) {
    for (int stage = 0; stage < stage_cnt; stage++) {
      uint64_t* stage_barriers = smem_barrier + stage * 2;
      initialize_barrier(stage_barriers + 0, load_thread_cnt);     // producer
      initialize_barrier(stage_barriers + 1, compute_thread_cnt);  // consumer
    }
  }
  __syncthreads();

  if (warp_idx >= 4) {
    MmaComputer<gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt> mma_computer(
        gmem_c_local, smem_a, smem_b, smem_barrier, warp_idx, gemm_n);
    mma_computer.prepare();
    mma_computer.issue_mainloop();
    mma_computer.epi();
  } else if (warp_idx >= 2) {
    GmemLoaderB<gemm_k, tile_n, tile_k, stage_cnt> b_loader(gmem_b_local, smem_b, smem_barrier, gemm_n);
    b_loader.prepare();
    b_loader.issue_mainloop();
  } else {
    GmemLoaderA<gemm_k, tile_m, tile_k, stage_cnt> a_loader(gmem_a_local, smem_a, smem_barrier);
    a_loader.prepare();
    a_loader.issue_mainloop();
  }

  // Signal that dependent grids may begin launching.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
|
||||
|
||||
// Host launcher for fused_a_gemm_kernel.
//
// Template parameters: kHdIn is the reduction length (gemm_k), kHdOut the
// output feature count (gemm_m), kTileN the token tile size (gemm_n tile).
// The two input pointers are swapped before the launch, so the kernel's
// "A" operand is the caller's mat_b and its "B" operand is the caller's mat_a.
//
// Launch shape: grid (gemm_m / 16, ceil(num_tokens / kTileN)), 256 threads per
// block, dynamic shared memory of `smem_bytes`. Programmatic stream
// serialization is enabled according to getEnvEnablePDL().
//
// Raises (via TORCH_CHECK) if the shared-memory attribute cannot be set or the
// launch is rejected; previously both return codes were silently dropped.
template <typename T, int kHdIn, int kHdOut, int kTileN>
void invokeFusedAGemm(T* output, T const* mat_a, T const* mat_b, int num_tokens, cudaStream_t const stream) {
  constexpr int gemm_m = kHdOut;
  int const gemm_n = num_tokens;
  constexpr int gemm_k = kHdIn;
  constexpr int batch_size = 1;
  std::swap(mat_a, mat_b);

  constexpr int tile_m = 16;
  constexpr int tile_n = kTileN;
  constexpr int tile_k = std::max(256, 1024 / tile_n);
  constexpr int max_stage_cnt = 1024 * 192 / ((tile_m + tile_n) * tile_k * sizeof(bf16_t));
  constexpr int k_iter_cnt = gemm_k / tile_k;
  // Never use more pipeline stages than there are k iterations.
  constexpr int stage_cnt =
      k_iter_cnt > max_stage_cnt ? max_stage_cnt : k_iter_cnt;  // possible tunable for smallK > 1 wave n.

  int const cta_m_cnt = gemm_m / tile_m;
  int const cta_n_cnt = (gemm_n + tile_n - 1) / tile_n;

  // Must match the offset at which the kernel places the A tiles:
  // (stage_cnt * 8 * 2 + 1024) / 1024 * 1024. The previous "+ 1023" rounding
  // gave the same value unless stage_cnt * 16 was an exact multiple of 1024,
  // in which case the host would have reserved 1 KiB less than the kernel uses.
  constexpr int barrier_bytes = (stage_cnt * 16 + 1024) / 1024 * 1024;
  constexpr int smem_bytes = ((tile_m * 2 + tile_n * 2) * tile_k * stage_cnt + barrier_bytes + 1023) / 1024 * 1024;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config{};
  config.gridDim = dim3(cta_m_cnt, cta_n_cnt, 1);
  config.blockDim = dim3(256);
  config.dynamicSmemBytes = smem_bytes;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  auto const kernel = fused_a_gemm_kernel<batch_size, gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt>;

  // Dynamic shared memory above 48 KiB requires an explicit opt-in.
  if (smem_bytes >= (48 * 1024)) {
    cudaError_t const attr_err =
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    TORCH_CHECK(
        attr_err == cudaSuccess,
        "fused_a_gemm: cudaFuncSetAttribute failed: ",
        cudaGetErrorString(attr_err));
  }

  cudaError_t const launch_err = cudaLaunchKernelEx(&config, kernel, output, mat_a, mat_b, gemm_n);
  TORCH_CHECK(
      launch_err == cudaSuccess, "fused_a_gemm: kernel launch failed: ", cudaGetErrorString(launch_err));
}
|
||||
|
||||
// Explicit instantiations of invokeFusedAGemm for bf16 with kHdIn = 7168 and
// kHdOut = 2112: one with an 8-wide token tile, one with a 16-wide token tile.
template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 8>(
    __nv_bfloat16* output,
    __nv_bfloat16 const* mat_a,
    __nv_bfloat16 const* mat_b,
    int num_tokens,
    cudaStream_t const stream);

template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 16>(
    __nv_bfloat16* output,
    __nv_bfloat16 const* mat_a,
    __nv_bfloat16 const* mat_b,
    int num_tokens,
    cudaStream_t const stream);
|
||||
|
||||
// Torch entry point for the fused-A GEMM.
//
// Expected arguments:
//   output : [num_tokens, 2112] bf16, row major
//   mat_a  : [num_tokens, 7168] bf16, row major, 1 <= num_tokens <= 16
//   mat_b  : [7168, 2112]       bf16, column major
// Requires an SM90+ device. Inputs with up to 8 tokens use the kTileN = 8
// instantiation, larger inputs use kTileN = 16. The launch goes to the current
// CUDA stream of mat_a's device.
//
// The kernel receives raw pointers only (no strides), so all three tensors
// must be densely packed in their respective layouts; this is validated here.
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b) {
  TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2);
  TORCH_CHECK(
      mat_a.is_cuda() && mat_b.is_cuda() && output.is_cuda(), "mat_a, mat_b and output must be CUDA tensors");

  int const num_tokens = mat_a.size(0);
  int const hd_in = mat_a.size(1);
  int const hd_out = mat_b.size(1);

  constexpr int kHdIn = 7168;
  constexpr int kHdOut = 2112;
  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, "required 1 <= mat_a.shape[0] <= 16");
  TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168");
  TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112");
  // The reduction dimension of mat_b was previously unchecked; a shorter mat_b
  // would be read out of bounds by the kernel.
  TORCH_CHECK(mat_b.size(0) == hd_in, "required mat_b.shape[0] == mat_a.shape[1]");
  TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]");
  TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]");

  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");    // Row-major
  TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor");  // Row-major
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");  // Column-major
  // Leading strides must equal the logical extent. A single-row tensor may
  // carry an arbitrary row stride, so it is exempt.
  TORCH_CHECK(num_tokens == 1 || mat_a.stride(0) == hd_in, "mat_a rows must be densely packed");
  TORCH_CHECK(num_tokens == 1 || output.stride(0) == hd_out, "output rows must be densely packed");
  TORCH_CHECK(mat_b.stride(1) == hd_in, "mat_b columns must be densely packed");

  TORCH_CHECK(
      mat_a.scalar_type() == torch::kBFloat16 && mat_b.scalar_type() == torch::kBFloat16,
      "Only BFloat16 input dtype is supported");
  TORCH_CHECK(output.scalar_type() == torch::kBFloat16, "Only BFloat16 output dtype is supported");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto* out_ptr = reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr());
  auto const* a_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
  auto const* b_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr());

  if (num_tokens <= 8) {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 8>(out_ptr, a_ptr, b_ptr, num_tokens, stream);
  } else {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 16>(out_ptr, a_ptr, b_ptr, num_tokens, stream);
  }
}
|
||||
284
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
Normal file
284
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
Normal file
@@ -0,0 +1,284 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
|
||||
*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include "cuda_runtime.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Packed two-lane fp32 fused multiply-add issued through inline PTX
// (`fma.rn.f32x2`, round-to-nearest). Each float2 is passed to the instruction
// as a single 64-bit register; the result is written to `d`.
// NOTE(review): the f32x2 form is only available on recent architectures -
// confirm the minimum SM against the PTX ISA before calling this helper.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
  uint64_t& d_bits = reinterpret_cast<uint64_t&>(d);
  uint64_t const& a_bits = reinterpret_cast<uint64_t const&>(a);
  uint64_t const& b_bits = reinterpret_cast<uint64_t const&>(b);
  uint64_t const& c_bits = reinterpret_cast<uint64_t const&>(c);
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" : "=l"(d_bits) : "l"(a_bits), "l"(b_bits), "l"(c_bits));
}
|
||||
|
||||
// Widens the first VPT bfloat16 values packed in a 16-byte uint4 to fp32 and
// stores them in dst[0 .. VPT-1]. A uint4 holds at most 8 bf16 values, so
// callers are expected to use VPT <= 8.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  // View the 16 bytes as an array of bf16 elements (read-only).
  __nv_bfloat16 const* packed = reinterpret_cast<__nv_bfloat16 const*>(&vec);

#pragma unroll
  for (int elem = 0; elem < VPT; ++elem) {
    dst[elem] = __bfloat162float(packed[elem]);
  }
}
|
||||
|
||||
// Router GEMM with bf16 output: out[m, n] = sum_k mat_a[m, k] * mat_b[n, k].
//
// Launch layout: one block per expert (blockIdx.x = n), kBlockSize threads per
// block (blockDim.x must equal kBlockSize). Each thread loads VPT consecutive
// elements per K iteration with a single 16-byte vector load, so mat_a and
// mat_b must be 16-byte aligned and densely packed with row length kHiddenDim.
// Uses kNumTokens * kNumWarps floats of static shared memory for the
// cross-warp reduction. Accumulation is in fp32; the result is rounded to bf16.
// On SM90+ the kernel waits on / signals programmatic dependent launch.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__
__launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(__nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Constants for this kernel
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Compile-time preconditions that the loops below silently relied on.
  static_assert(kBlockSize % kWarpSize == 0, "kBlockSize must be a whole number of warps");
  static_assert(kBlockSize <= 128, "kBlockSize exceeds the kernel's __launch_bounds__");
  static_assert(sizeof(T) == sizeof(__nv_bfloat16), "elements are reinterpreted as bf16");
  static_assert(VPT * sizeof(T) == sizeof(uint4), "each thread must load exactly one 16-byte vector");
  static_assert(
      kHiddenDim % k_elems_per_k_iteration == 0, "kHiddenDim must be a multiple of VPT * kBlockSize");

  // Initialize accumulators for all M rows
  float acc[kNumTokens] = {};

  // Shared memory for warp-level reduction
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // B matrix is in column-major order, so we can directly load a column for the n_idx expert
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Wait for the preceding grid in the stream before reading any inputs.
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load B matrix values using vector load (8 bf16 values)
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    // Convert B values to float
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

    // Process each token
#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load this token's slice of the A matrix using a vector load
      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);

      // Convert A values to float
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

      // Process elements in this chunk
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Warp and lane of this thread. kWarpSize is used instead of a local named
  // `warpSize`, which would shadow the CUDA built-in of the same name.
  int const warp_id = tid / kWarpSize;
  int const lane_id = tid % kWarpSize;

  // Perform warp-level reduction using a butterfly (xor) shuffle pattern
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];

    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    // Only the first thread in each warp stores to shared memory
    if (lane_id == 0) {
      sm_reduction[m][warp_id] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;

      // Sum across the kNumWarps
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }

      // Write final result
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
|
||||
|
||||
// Host launcher for router_gemm_kernel_bf16_output.
//
// Launches kNumExperts blocks of 128 threads on `stream` with no dynamic
// shared memory. Each thread handles VPT = 16 / sizeof(T) elements per vector
// load. Programmatic stream serialization is enabled according to
// getEnvEnablePDL().
//
// Raises (via TORCH_CHECK) if the launch is rejected; the return code of
// cudaLaunchKernelEx was previously ignored.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config{};
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  cudaError_t const launch_err = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
      output,
      mat_a,
      mat_b);
  TORCH_CHECK(
      launch_err == cudaSuccess, "router_gemm (bf16 output): kernel launch failed: ", cudaGetErrorString(launch_err));
}
|
||||
|
||||
// Explicit instantiations of invokeRouterGemmBf16Output for bf16 inputs with
// hidden dimension 7168, for every token count from 1 to 16 and for both
// supported expert counts. The helper macros are local to this section and
// are undefined again immediately afterwards.
#define SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(kTokens, kExperts)                        \
  template void invokeRouterGemmBf16Output<__nv_bfloat16, kTokens, kExperts, 7168>(   \
      __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t)

#define SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT_1_TO_16(kExperts) \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(1, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(2, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(3, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(4, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(5, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(6, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(7, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(8, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(9, kExperts);           \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(10, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(11, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(12, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(13, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(14, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(15, kExperts);          \
  SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT(16, kExperts)

// Template instantiations for DEFAULT_NUM_EXPERTS experts
SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT_1_TO_16(256);

// Template instantiations for KIMI_K2_NUM_EXPERTS experts
SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT_1_TO_16(384);

#undef SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT_1_TO_16
#undef SGL_INSTANTIATE_ROUTER_GEMM_BF16_OUT
|
||||
161
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
Normal file
161
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
Normal file
@@ -0,0 +1,161 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
|
||||
*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include "cuda_runtime.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Problem sizes accepted by dsv3_router_gemm: the hidden dimension is fixed,
// and two expert counts are supported.
static constexpr int DEFAULT_HIDDEN_DIM = 7168;
static constexpr int DEFAULT_NUM_EXPERTS = 256;
static constexpr int KIMI_K2_NUM_EXPERTS = 384;

// Launchers declared here and defined (with explicit instantiations) in other
// translation units - presumably dsv3_router_gemm_float_out.cu and
// dsv3_router_gemm_bf16_out.cu respectively.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, const T* mat_a, const T* mat_b, cudaStream_t stream);

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, const T* mat_a, const T* mat_b, cudaStream_t stream);
|
||||
|
||||
// Maps a runtime token count onto the matching compile-time kNumTokens
// instantiation. Each level handles num_tokens == kBegin and otherwise defers
// to the next level (kBegin + 1); the kBegin == kEnd specialization ends the
// chain.
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  // Dispatch to the fp32-output launcher.
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens != kBegin) {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_float_output(
          num_tokens, output, input, weights, stream);
      return;
    }
    invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
  }

  // Dispatch to the bf16-output launcher.
  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens != kBegin) {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_bf16_output(
          num_tokens, output, input, weights, stream);
      return;
    }
    invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
  }
};
|
||||
|
||||
// Terminal level of the dispatch chain (kBegin == kEnd): launches when
// num_tokens == kEnd and throws std::invalid_argument for any other value.
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  // fp32-output variant.
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens != kEnd) {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
    invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
  }

  // bf16-output variant.
  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens != kEnd) {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
    invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
  }
};
|
||||
|
||||
// Torch entry point for the router GEMM: output[m, n] = sum_k mat_a[m, k] * mat_b[n, k].
//
// Supported configuration: hidden_dim == 7168, num_experts in {256, 384},
// 1 <= num_tokens <= 16, bf16 inputs, fp32 or bf16 output, SM90+ device.
// The kernels receive raw pointers and index them with fixed row lengths, so
// all three tensors must be contiguous CUDA tensors and `output` must have
// shape [num_tokens, num_experts]; these conditions are validated here.
// The launch goes to the current CUDA stream.
void dsv3_router_gemm(
    torch::Tensor& output,       // [num_tokens, num_experts]
    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);

  const int num_tokens = mat_a.size(0);
  const int num_experts = mat_b.size(0);
  const int hidden_dim = mat_a.size(1);

  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(
      hidden_dim == DEFAULT_HIDDEN_DIM,
      "Expected hidden_dim=",
      DEFAULT_HIDDEN_DIM,
      ", but got hidden_dim=",
      hidden_dim);
  TORCH_CHECK(
      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
      "Expected num_experts=",
      DEFAULT_NUM_EXPERTS,
      " or num_experts=",
      KIMI_K2_NUM_EXPERTS,
      ", but got num_experts=",
      num_experts);
  TORCH_CHECK(
      num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm");
  // Without this check an undersized output would be written out of bounds.
  TORCH_CHECK(
      output.size(0) == num_tokens && output.size(1) == num_experts,
      "output must have shape [num_tokens, num_experts]");
  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(
      output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16, "output must be float32 or bf16");
  TORCH_CHECK(
      mat_a.is_cuda() && mat_b.is_cuda() && output.is_cuda(), "mat_a, mat_b and output must be CUDA tensors");
  TORCH_CHECK(
      mat_a.is_contiguous() && mat_b.is_contiguous() && output.is_contiguous(),
      "mat_a, mat_b and output must be contiguous");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto const* a_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
  auto const* b_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr());
  // num_experts was validated above, so "not the default" means KIMI_K2_NUM_EXPERTS.
  bool const default_experts = (num_experts == DEFAULT_NUM_EXPERTS);

  if (output.dtype() == torch::kFloat32) {
    auto* out_ptr = reinterpret_cast<float*>(output.mutable_data_ptr());
    if (default_experts) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
          num_tokens, out_ptr, a_ptr, b_ptr, stream);
    } else {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
          num_tokens, out_ptr, a_ptr, b_ptr, stream);
    }
  } else {
    // The dtype check above leaves bf16 as the only other possibility.
    auto* out_ptr = reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr());
    if (default_experts) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
          num_tokens, out_ptr, a_ptr, b_ptr, stream);
    } else {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
          num_tokens, out_ptr, a_ptr, b_ptr, stream);
    }
  }
}
|
||||
283
sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
Normal file
283
sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
Normal file
@@ -0,0 +1,283 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
|
||||
*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include "cuda_runtime.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Packed two-lane fused multiply-add: d = a * b + c, evaluated per component
// with round-to-nearest.
//
// On sm_100 and newer this issues the packed PTX instruction `fma.rn.f32x2`,
// which takes each float2 as one 64-bit register (hence the "l" constraints
// and the uint64_t reinterpret_casts). The instruction does not exist on
// older architectures, so every other target falls back to two scalar fmaf
// calls.
//
//   d: output pair, overwritten.
//   a, b: multiplicand pairs.
//   c: addend pair.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
#else
  // Scalar fallback: read all inputs before writing so that d may alias a, b or c.
  float const x = fmaf(a.x, b.x, c.x);
  float const y = fmaf(a.y, b.y, c.y);
  d.x = x;
  d.y = y;
#endif
}
|
||||
|
||||
// Converts the first VPT bfloat16 values packed in a 16-byte uint4 to float.
//
//   VPT: number of values to convert; a uint4 holds at most 8 bfloat16 values.
//   vec: the packed source vector (only read).
//   dst: destination array; elements dst[0..VPT-1] are written.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  // Guard against reading past the end of the 16-byte source vector.
  static_assert(VPT * sizeof(__nv_bfloat16) <= sizeof(uint4), "VPT bfloat16 values must fit in one uint4");

  // Reinterpret the vector as packed bfloat16 lanes; the source is never modified.
  __nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&vec);

#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}
|
||||
|
||||
// Router GEMM with float output:
//   out[m * kNumExperts + n] = sum_k mat_a[m * kHiddenDim + k] * mat_b[n * kHiddenDim + k]
// for m in [0, kNumTokens) and n = blockIdx.x, accumulated in fp32.
//
// Launch layout: one block per expert (blockIdx.x is the expert index) and
// kBlockSize threads per block. Only static shared memory is used
// (kNumTokens * kNumWarps floats).
//
// Preconditions:
//   - Every thread loads 16-byte vectors, so mat_a and mat_b must be 16-byte
//     aligned and VPT * sizeof(T) must equal 16.
//   - kHiddenDim must be a multiple of VPT * kBlockSize; there is no tail loop.
//   - NOTE(review): the loaded bits are converted with bf16_uint4_to_float8,
//     i.e. they are interpreted as bfloat16 whatever T is — confirm before
//     instantiating with another element type.
//   - The warp shuffles use a hard-coded warp width of 32 lanes.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Elements of the hidden dimension consumed by the whole block per iteration.
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  static_assert(
      kBlockSize > 0 && kBlockSize <= 128 && kBlockSize % kWarpSize == 0,
      "kBlockSize must be a multiple of the warp size and fit __launch_bounds__(128, 1)");
  static_assert(VPT * sizeof(T) == sizeof(uint4), "each thread must load exactly one 16-byte vector per step");
  static_assert(
      kHiddenDim % k_elems_per_k_iteration == 0, "kHiddenDim must be a multiple of VPT * kBlockSize (no tail loop)");

  // Each block handles one expert column.
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;

  // Per-thread partial dot products, one per token row.
  float acc[kNumTokens] = {};

  // One slot per (token, warp) for the cross-warp reduction.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Each expert's weights are contiguous along the hidden dimension, so the
  // n_idx-th expert starts at offset n_idx * kHiddenDim.
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute this thread's starting offset for every K iteration to help
  // the compiler optimize.
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Programmatic dependent launch: wait for the prerequisite grid before
  // reading global memory.
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks of k_elems_per_k_iteration elements.
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Vector-load VPT weight values and widen them to float once per chunk.
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

    // Reuse the converted weights for every token row.
#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Vector-load the matching slice of this token's activations.
      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

  // Warp-level butterfly reduction of every token's partial sum.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];

#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
      sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }

    // Only the first lane of each warp publishes the warp total.
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  // All warps must have published before thread 0 reads the totals.
  __syncthreads();

  // Final reduction across warps (first thread only).
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;

#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }

      out[m * kNumExperts + n_idx] = final_sum;
    }
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Programmatic dependent launch: allow dependent grids to start.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
|
||||
|
||||
// Launches router_gemm_kernel_float_output on `stream`.
//
// Grid: kNumExperts blocks (one per expert); block: 128 threads; no dynamic
// shared memory. Each thread loads 16-byte vectors, i.e. VPT = 16 / sizeof(T)
// elements at a time. The programmatic-stream-serialization launch attribute
// is set from getEnvEnablePDL().
//
//   output: device pointer indexed by the kernel as output[m * kNumExperts + n].
//   mat_a:  device pointer indexed by the kernel as mat_a[m * kHiddenDim + k].
//   mat_b:  device pointer indexed by the kernel as mat_b[n * kHiddenDim + k].
//
// Throws (via TORCH_CHECK) if the launch itself is rejected by the runtime.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config;
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  // Launch-configuration errors are only reported through the return value,
  // so it must not be dropped.
  cudaError_t const launch_status = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
      output,
      mat_a,
      mat_b);
  TORCH_CHECK(
      launch_status == cudaSuccess,
      "router_gemm_kernel_float_output launch failed: ",
      cudaGetErrorString(launch_status));
}
|
||||
|
||||
// Explicit instantiations of invokeRouterGemmFloatOutput for bfloat16 inputs
// with a hidden dimension of 7168: token counts 1..16 for each supported
// expert count.
#define SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(NUM_TOKENS, NUM_EXPERTS)                    \
  template void invokeRouterGemmFloatOutput<__nv_bfloat16, NUM_TOKENS, NUM_EXPERTS, 7168>(   \
      float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

#define SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT_1_TO_16(NUM_EXPERTS) \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(1, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(2, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(3, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(4, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(5, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(6, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(7, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(8, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(9, NUM_EXPERTS)            \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(10, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(11, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(12, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(13, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(14, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(15, NUM_EXPERTS)           \
  SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT(16, NUM_EXPERTS)

// Template instantiations for DEFAULT_NUM_EXPERTS experts
SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT_1_TO_16(256)

// Template instantiations for KIMI_K2_NUM_EXPERTS experts
SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT_1_TO_16(384)

#undef SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT_1_TO_16
#undef SGL_INSTANTIATE_ROUTER_GEMM_FLOAT_OUTPUT
|
||||
280
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
280
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
@@ -0,0 +1,280 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||
#include <cutlass/gemm/thread/mma.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include <cutlass/util/host_tensor.h>
|
||||
#include <cutlass/util/tensor_view_io.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Builds and runs the SM100 CUTLASS blockwise-scaled FP8 GEMM for one tile
// configuration.
//
//   out:      output tensor, written as a row-major (m, n) matrix of OutType.
//   a:        e4m3 matrix read as row-major (m, k); m = a.size(0), k = a.size(1).
//   b:        e4m3 matrix read as column-major; n = b.size(1).
//   scales_a: float32 block scales for a.
//   scales_b: float32 block scales for b.
//
// The scale granularity along each axis is the MMA tile extent divided by the
// corresponding entry of ScalesPerTile. The kernel is enqueued on the current
// CUDA stream of a's device. Throws (TORCH_CHECK) if CUTLASS rejects the
// problem or if initialization/launch fails.
template <
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename EpilogueTileShape,
    typename ScalesPerTile,
    int TileSizeM_ = 128,
    class ClusterShape = Shape<_1, _1, _1>>
void launch_sm100_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});

  using ElementAB = cutlass::float_e4m3_t;
  using ElementA = ElementAB;
  using ElementB = ElementAB;
  using ElementC = void;
  using ElementD = OutType;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutC = LayoutD;
  // This means both SFA and SFB are column-major.
  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
      ScaleGranularityM,
      ScaleGranularityN,
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  // Alignments are expressed in elements per 128-bit access.
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      PerSmTileShape,
      ClusterShape,
      EpilogueTileShape,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp;

  // The mainloop stage count is derived from the shared memory left over
  // after the epilogue's storage has been carved out.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto scales_a_ptr = static_cast<float*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<float*>(scales_b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideD = typename GemmKernel::StrideD;
  using StrideC = typename GemmKernel::StrideD;

  StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};

  // ElementC is void, so the C operand slot just reuses the output pointer/stride.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
  epilogue_args.thread.alpha = 1.0f;

  typename GemmKernel::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args};

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  // Take the workspace from the torch allocator rather than a raw
  // cudaMalloc/cudaFree pair on every call.
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto workspace = torch::empty(
      {static_cast<int64_t>(workspace_size)}, torch::TensorOptions().dtype(torch::kUInt8).device(a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  // Initialize on the same stream the kernel runs on so that any workspace
  // setup is ordered before the GEMM.
  auto init_status = gemm_op.initialize(args, workspace.data_ptr(), stream);
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}
|
||||
|
||||
// Picks the SM100 tile configuration from the row count of `a` and forwards
// all tensors to launch_sm100_fp8_blockwise_scaled_mm:
//   a.size(0) <= 128 -> 64x128x128 MMA tile, 64x64 epilogue tile
//   otherwise        -> 128x128x128 MMA tile, 128x64 epilogue tile
// In both cases ScalesPerTile has one entry per tile row and one along N and K.
template <typename OutType>
void sm100_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  const bool use_large_m_tile = a.size(0) > 128;

  if (use_large_m_tile) {
    using MmaTileShape = Shape<_128, _128, _128>;
    using PerSmTileShape = Shape<_128, _128, _128>;
    using EpilogueTileShape = Shape<_128, _64>;
    using ScalesPerTile = Shape<_128, _1, _1>;
    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
    return;
  }

  using MmaTileShape = Shape<_64, _128, _128>;
  using PerSmTileShape = Shape<_64, _128, _128>;
  using EpilogueTileShape = Shape<_64, _64>;
  using ScalesPerTile = Shape<_64, _1, _1>;
  launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
      out, a, b, scales_a, scales_b);
}
|
||||
|
||||
// Blockwise-scaled FP8 matrix multiply: validates the operands, pads the row
// dimension of mat_a / scales_a with pad_tensor(…, /*alignment=*/4), and
// dispatches to the SM90 or SM100/SM103 CUTLASS implementation matching the
// running device.
//
//   mat_a:     2-D CUDA Float8_e4m3fn tensor with stride(1) == 1.
//   mat_b:     2-D CUDA Float8_e4m3fn tensor with stride(0) == 1.
//   scales_a:  Float32 scales, sized (mat_a.size(0), mat_a.size(1) / 128).
//   scales_b:  Float32 scales, sized (mat_b.size(0) / 128, mat_b.size(1) / 128).
//   out_dtype: torch::kHalf or torch::kBFloat16.
//
// Returns a (mat_a.size(0), mat_b.size(1)) tensor of out_dtype; it is a slice
// of the padded result buffer.
// Throws via TORCH_CHECK on invalid arguments and via
// TORCH_CHECK_NOT_IMPLEMENTED when no kernel was compiled for the current
// compute capability.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // A contiguous tensor that is 1-D, or 2-D with one dimension of size 1.
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  // Validate the output row size without allocating a throw-away tensor; the
  // only buffer that is ever returned is the padded one created below.
  const auto out_options = mat_a.options().dtype(out_dtype);
  const int64_t out_row_bytes = mat_b.size(1) * static_cast<int64_t>(c10::elementSize(out_dtype));
  TORCH_CHECK(out_row_bytes % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

  // Pad the row dimension; the result is sliced back to original_rows on return.
  int64_t original_rows = mat_a.size(0);
  torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
  torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
  torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out_options);

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version == 90) {
    torch::Tensor scales_b_contiguous = scales_b.contiguous();
    if (out_dtype == torch::kBFloat16) {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
    } else {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif

#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  if (sm_version == 100
#if CUDA_VERSION >= 12090
      || sm_version == 103
#endif
  ) {
    if (out_dtype == torch::kBFloat16) {
      sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    } else {
      sm100_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
|
||||
1252
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
1252
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
File diff suppressed because it is too large
Load Diff
62
sgl-kernel/csrc/gemm/gptq/compat.cuh
Normal file
62
sgl-kernel/csrc/gemm/gptq/compat.cuh
Normal file
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace sglang {
namespace gptq {

// Software atomic add for a single half, built from a compare-and-swap loop on
// the aligned 32-bit word that contains it (to support CC < 7.x).
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
  // Bit 1 of the address says whether the target is the upper or the lower
  // 16 bits of its enclosing 4-byte word.
  const bool in_upper_half = ((size_t)address & 2) != 0;
  unsigned int* word_ptr = (unsigned int*)((char*)address - ((size_t)address & 2));

  unsigned int observed = *word_ptr;
  unsigned int expected;

  do {
    expected = observed;

    // Extract the current half, add, and splice the sum back into the word
    // while leaving the neighbouring 16 bits untouched.
    __half_raw bits;
    bits.x = in_upper_half ? (observed >> 16) : (observed & 0xffff);
    half sum = __hadd(bits, val);
    bits = __half_raw(sum);
    const unsigned int desired =
        in_upper_half ? (observed & 0xffff) | (bits.x << 16) : (observed & 0xffff0000) | bits.x;

    // atomicCAS returns the word it found; retry if another thread got in first.
    observed = atomicCAS(word_ptr, expected, desired);
  } while (expected != observed);
}

// Software atomic add for a half2, using a compare-and-swap loop on its 32-bit
// representation.
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
  unsigned int* word_ptr = (unsigned int*)address;
  unsigned int observed = *word_ptr;
  unsigned int expected;

  do {
    expected = observed;
    half2 current = *((half2*)&observed);
    half2 updated = __hadd2(current, val);
    observed = atomicCAS(word_ptr, expected, *((unsigned int*)&updated));
  } while (expected != observed);
}

// atomicAdd overloads that forward to the software implementations above; they
// are only declared for device passes below the given compute capabilities or
// when USE_ROCM is defined.

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) {
  atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
  atomicAdd_half2(address, val);
}
#endif

#endif
#endif

}  // namespace gptq
}  // namespace sglang
|
||||
#endif
|
||||
1950
sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu
Normal file
1950
sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu
Normal file
File diff suppressed because it is too large
Load Diff
269
sgl-kernel/csrc/gemm/gptq/matrix_view.cuh
Normal file
269
sgl-kernel/csrc/gemm/gptq/matrix_view.cuh
Normal file
@@ -0,0 +1,269 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and
|
||||
https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
// Read-only view over a row-major matrix of half values; element (row, column)
// lives at data[row * width + column].
class MatrixView_half {
 public:
  const half* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // Single element at (row, column).
  __device__ __forceinline__ half item(int row, int column) const {
    const int offset = row * width + column;
    return data[offset];
  }

  // The half2 pair containing element (row, column); the flat offset is halved
  // to index the buffer as half2.
  __device__ __forceinline__ half2 item_half2(int row, int column) const {
    const int pair_index = (row * width + column) / 2;
    return ((half2*)data)[pair_index];
  }

  // Element (row, column) broadcast to both lanes of a half2.
  __device__ __forceinline__ half2 item_half2half2(int row, int column) const {
    const int offset = row * width + column;
    return __half2half2(data[offset]);
  }

  // Address of element (row, column).
  __device__ __forceinline__ const half* item_ptr(int row, int column) const {
    return data + (row * width + column);
  }

  // Reads four consecutive elements starting at (row, column) as two half2 loads.
  __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const {
    half2* pairs = (half2*)item_ptr(row, column);
    const half2 first = pairs[0];
    const half2 second = pairs[1];
    items[0] = __low2half(first);
    items[1] = __high2half(first);
    items[2] = __low2half(second);
    items[3] = __high2half(second);
  }

  // Same as item4, widening each element to float.
  __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const {
    half2* pairs = (half2*)item_ptr(row, column);
    const half2 first = pairs[0];
    const half2 second = pairs[1];
    items[0] = __half2float(__low2half(first));
    items[1] = __half2float(__high2half(first));
    items[2] = __half2float(__low2half(second));
    items[3] = __half2float(__high2half(second));
  }

  // Same as item4, broadcasting each element to both lanes of a half2.
  __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const {
    half2* pairs = (half2*)item_ptr(row, column);
    const half2 first = pairs[0];
    const half2 second = pairs[1];
    items[0] = __half2half2(__low2half(first));
    items[1] = __half2half2(__high2half(first));
    items[2] = __half2half2(__low2half(second));
    items[3] = __half2half2(__high2half(second));
  }
};
|
||||
|
||||
// Read/write view over a row-major matrix of half values; element
// (row, column) lives at data[row * width + column].
class MatrixView_half_rw {
 public:
  half* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // Single element at (row, column).
  __device__ __forceinline__ half item(int row, int column) const {
    const int offset = row * width + column;
    return data[offset];
  }

  // The half2 pair containing element (row, column).
  __device__ __forceinline__ half2 item_half2(int row, int column) const {
    const int pair_index = (row * width + column) / 2;
    return ((half2*)data)[pair_index];
  }

  // Writable address of element (row, column).
  __device__ __forceinline__ half* item_ptr(int row, int column) {
    return data + (row * width + column);
  }

  // Stores one element at (row, column).
  __device__ __forceinline__ void set(int row, int column, half value) {
    const int offset = row * width + column;
    data[offset] = value;
  }

  // Stores a half2 into the pair slot containing element (row, column).
  __device__ __forceinline__ void set_half2(int row, int column, half2 value) {
    const int pair_index = (row * width + column) / 2;
    ((half2*)data)[pair_index] = value;
  }

  // Stores four consecutive elements starting at (row, column) as two half2 writes.
  __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) {
    const half2 first = __halves2half2(v0, v1);
    const half2 second = __halves2half2(v2, v3);
    half2* pairs = (half2*)item_ptr(row, column);
    pairs[0] = first;
    pairs[1] = second;
  }
};
|
||||
|
||||
// Read-only view over 4-bit values packed eight per uint32_t along each row:
// the word for (row, column) is data[row * width / 8 + column / 8] and the
// value sits (column % 8) * 4 bits above the word's least significant bit.
class MatrixView_q4_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // The 4-bit value at (row, column).
  __device__ __forceinline__ int item(int row, int column) const {
    const int bit_offset = (column & 0x07) * 4;
    const uint32_t word = data[row * width / 8 + column / 8];
    return (word >> bit_offset) & 0x0f;
  }

  // Two consecutive 4-bit values starting at (row, column), taken from one word.
  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
    const int bit_offset = (column & 0x07) * 4;
    const uint32_t bits = data[row * width / 8 + column / 8] >> bit_offset;
#pragma unroll
    for (int i = 0; i < 2; ++i) {
      items[i] = (bits >> (4 * i)) & 0x0f;
    }
  }

  // Four consecutive 4-bit values starting at (row, column), taken from one word.
  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
    const int bit_offset = (column & 0x07) * 4;
    const uint32_t bits = data[row * width / 8 + column / 8] >> bit_offset;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      items[i] = (bits >> (4 * i)) & 0x0f;
    }
  }
};
|
||||
|
||||
// Read-only view over 4-bit values packed eight per uint32_t along each
// column: the word for (row, column) is data[row / 8 * width + column] and the
// value sits (row % 8) * 4 bits above the word's least significant bit.
class MatrixView_q4_column {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // The 4-bit value at (row, column).
  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (row & 0x07) * 4;
    return (data[row / 8 * width + column] >> shift) & 0x0f;
  }

  // The whole packed word that contains (row, column).
  // Marked const (it only reads) so it is usable on const views, like item().
  __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const {
    return data[row / 8 * width + column];
  }

  // Address of the packed word that contains (row, column).
  // Marked const for the same reason as item_uint32_t.
  __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const {
    return &data[row / 8 * width + column];
  }
};
|
||||
|
||||
// Read-only view over 2-bit values packed sixteen per uint32_t word along each row
// (value `column` lives in word column / 16, at bit offset (column % 16) * 2).
class MatrixView_q2_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // Returns the single 2-bit value at (row, column).
  __device__ __forceinline__ int item(int row, int column) const {
    const uint32_t word = data[row * width / 16 + column / 16];
    const int bit_offset = (column & 0x0f) * 2;
    return (word >> bit_offset) & 0x03;
  }

  // Extracts two consecutive 2-bit values starting at (row, column) from one word.
  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
    const int bit_offset = (column & 0x0f) * 2;
    const uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
#pragma unroll
    for (int i = 0; i < 2; ++i) {
      items[i] = (bits >> (i * 2)) & 0x03;
    }
  }

  // Extracts four consecutive 2-bit values starting at (row, column) from one word.
  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
    const int bit_offset = (column & 0x0f) * 2;
    const uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      items[i] = (bits >> (i * 2)) & 0x03;
    }
  }
};
|
||||
|
||||
// Read-only view over 3-bit values packed contiguously along each row: 32 values occupy
// three consecutive uint32_t words, so values 10 and 21 of each group straddle a word boundary.
class MatrixView_q3_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // Returns the single 3-bit value at (row, column).
  __device__ __forceinline__ int item(int row, int column) const {
    const int pos = column & 0x1f;  // position inside the 32-value group
    const int word = row * width * 3 / 32 + column * 3 / 32;

    if (pos < 10) {
      // Entirely inside the first word of the group.
      return (data[word] >> (pos * 3)) & 0x07;
    }
    if (pos == 10) {
      // Two low bits at the top of this word, one high bit at the bottom of the next.
      return (data[word] >> 30) | ((data[word + 1] << 2) & 0x4);
    }
    if (pos < 21) {
      // Entirely inside the second word of the group.
      return (data[word] >> (pos * 3 - 32)) & 0x07;
    }
    if (pos == 21) {
      // One low bit at the top of this word, two high bits at the bottom of the next.
      return (data[word] >> 31) | ((data[word + 1] << 1) & 0x6);
    }
    // Entirely inside the third word of the group.
    return (data[word] >> (pos * 3 - 64)) & 0x07;
  }

  // Extracts four consecutive 3-bit values starting at (row, column).
  // The branch structure only covers starts at multiples of 4 within a 32-value
  // group (0, 4, ..., 28); other starting columns are not handled correctly.
  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
    const int pos = (column & 0x1f);
    const int word = row * width / 32 * 3 + column * 3 / 32;
    uint32_t bits;
    if (pos <= 4) {
      bits = data[word] >> (pos * 3);
    } else if (pos == 8) {
      // 8 bits from the top of this word plus 4 bits from the next.
      bits = (data[word] >> 24) | ((data[word + 1] & 0x0f) << 8);
    } else if (pos <= 16) {
      bits = data[word] >> (pos * 3 - 32);
    } else if (pos == 20) {
      // 4 bits from the top of this word plus 8 bits from the next.
      bits = (data[word] >> 28) | ((data[word + 1] & 0xff) << 4);
    } else {
      bits = data[word] >> (pos * 3 - 64);
    }
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      items[i] = (bits >> (i * 3)) & 0x07;
    }
  }
};
|
||||
|
||||
// Read-only view over 8-bit values packed four per uint32_t word along each row
// (value `column` lives in word column / 4, at bit offset (column % 4) * 8).
class MatrixView_q8_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
      : data(data), height(height), width(width) {}

  // Returns the single 8-bit value at (row, column).
  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (column & 0x03) * 8;
    return (data[row * width / 4 + column / 4] >> shift) & 0xff;
  }

  // Extracts two consecutive 8-bit values starting at (row, column) from one word.
  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
    int shift = (column & 0x03) * 8;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
  }

  // Extracts four consecutive 8-bit values starting at (row, column) from one word.
  // Only one word is read, so a full set of four values requires column % 4 == 0;
  // for other columns the lanes past the end of the word come back as 0.
  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
    // Each lane is 8 bits wide, so the in-word offset is lane * 8 (matching item/item2).
    // The previous "* 2" only gave the right answer when column was a multiple of 4.
    int shift = (column & 0x03) * 8;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
    items[2] = (d >> 16) & 0xff;
    items[3] = (d >> 24) & 0xff;
  }
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
74
sgl-kernel/csrc/gemm/gptq/qdq_2.cuh
Normal file
74
sgl-kernel/csrc/gemm/gptq/qdq_2.cuh
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
// Rearranges the sixteen 2-bit values in q[0] in place: even-indexed values are packed
// into the low 16 bits and odd-indexed values into the high 16 bits, giving the layout
// ffddbb99 77553311 eeccaa88 66442200. `stride` is unused.
__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
  const uint32_t packed = q[0];
  uint32_t even_half = 0;
  uint32_t odd_half = 0;

#pragma unroll
  for (int pair = 0; pair < 8; ++pair) {
    // Each nibble holds one even-indexed value (low 2 bits) and one odd-indexed value (high 2 bits).
    const uint32_t nibble = (packed >> (pair * 4)) & 0x0f;
    even_half |= (nibble & 0x03) << (pair * 2);
    odd_half |= (nibble >> 2) << (pair * 2);
  }
  q[0] = even_half | (odd_half << 16);
}
|
||||
|
||||
// Dequantizes sixteen shuffled 2-bit values from q_0 into eight half2 pairs, subtracting
// `zero` from each. Each field is OR-ed into the mantissa of fp16 1024.0 (0x6400) without
// being shifted down, so fields that sit at bit 2, 4 or 6 are rescaled by 1/4, 1/16 or
// 1/64 in the fused multiply-add that also removes the 1024 bias. `stride` is unused.
__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) {
  const uint32_t kMagic = 0x64006400;

  // Per-field rescale factors.
  const half2 scale4 = __half2half2(__float2half_rn(1.0f / 4.0f));
  const half2 scale16 = __half2half2(__float2half_rn(1.0f / 16.0f));
  const half2 scale64 = __half2half2(__float2half_rn(1.0f / 64.0f));

  // Per-field offsets: -(1024 / factor) - zero.
  const half_uint16 bias1_bits(0xe400 | zero);  // half(-1024.0f - zero);
  const half2 bias1 = __half2half2(bias1_bits.as_half);
  const half2 bias4 = __half2half2(__hsub(__int2half_rn(-256), __int2half_rn(zero)));
  const half2 bias16 = __half2half2(__hsub(__int2half_rn(-64), __int2half_rn(zero)));
  const half2 bias64 = __half2half2(__hsub(__int2half_rn(-16), __int2half_rn(zero)));

  const uint32_t lo = q_0;       // fields for q[0..7]
  const uint32_t hi = q_0 >> 8;  // fields for q[8..15]

  half2_uint32 p0((lo & 0x00030003) | kMagic);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 p1((lo & 0x000c000c) | kMagic);  // half2(q[ 2], q[ 3]) *  4 + 1024
  half2_uint32 p2((lo & 0x00300030) | kMagic);  // half2(q[ 4], q[ 5]) * 16 + 1024
  half2_uint32 p3((lo & 0x00c000c0) | kMagic);  // half2(q[ 6], q[ 7]) * 64 + 1024
  half2_uint32 p4((hi & 0x00030003) | kMagic);  // half2(q[ 8], q[ 9])      + 1024
  half2_uint32 p5((hi & 0x000c000c) | kMagic);  // half2(q[10], q[11]) *  4 + 1024
  half2_uint32 p6((hi & 0x00300030) | kMagic);  // half2(q[12], q[13]) * 16 + 1024
  half2_uint32 p7((hi & 0x00c000c0) | kMagic);  // half2(q[14], q[15]) * 64 + 1024

  dq[0] = __hadd2(p0.as_half2, bias1);
  dq[1] = __hfma2(p1.as_half2, scale4, bias4);
  dq[2] = __hfma2(p2.as_half2, scale16, bias16);
  dq[3] = __hfma2(p3.as_half2, scale64, bias64);
  dq[4] = __hadd2(p4.as_half2, bias1);
  dq[5] = __hfma2(p5.as_half2, scale4, bias4);
  dq[6] = __hfma2(p6.as_half2, scale16, bias16);
  dq[7] = __hfma2(p7.as_half2, scale64, bias64);
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
146
sgl-kernel/csrc/gemm/gptq/qdq_3.cuh
Normal file
146
sgl-kernel/csrc/gemm/gptq/qdq_3.cuh
Normal file
@@ -0,0 +1,146 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
// Rearranges thirty-two 3-bit values stored contiguously in q[0], q[stride], q[2 * stride]
// in place. Afterwards each word holds ten values (even-indexed ones in the low half,
// odd-indexed ones in the high half) and bits 15 and 31 of the three words together carry
// the two remaining values (u and v), least-significant bit first:
//   v9997775 55333111 u8886664 44222000
//   vjjjhhhf ffdddbbb uiiiggge eecccaaa
//   vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
  uint32_t wa = q[0 * stride];
  uint32_t wb = q[1 * stride];
  uint32_t wc = q[2 * stride];

  // Input layout:
  //   wa: aa999888 77766655 54443332 22111000
  //   wb: lkkkjjji iihhhggg fffeeedd dcccbbba
  //   wc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
  // Re-align so every word starts on a value boundary and holds ten whole values;
  // the last two values (u, v) are kept aside in `tail`.
  const uint32_t tail = wc >> 26;
  wc = (wc << 4) | (wb >> 28);
  wb = (wb << 2) | (wa >> 30);

  uint32_t za = 0;
  uint32_t zb = 0;
  uint32_t zc = 0;

  // Take values two at a time: the first goes to the low half, the second to the high half.
  for (int i = 0; i < 5; ++i) {
    const int lo = i * 3;
    const int hi = lo + 16;
    za |= ((wa & 0x07) << lo) | (((wa >> 3) & 0x07) << hi);
    zb |= ((wb & 0x07) << lo) | (((wb >> 3) & 0x07) << hi);
    zc |= ((wc & 0x07) << lo) | (((wc >> 3) & 0x07) << hi);
    wa >>= 6;
    wb >>= 6;
    wc >>= 6;
  }

  // Scatter the six leftover bits: u across bit 15, v across bit 31 of the three words.
  za |= ((tail >> 0) & 0x01) << 15;
  zb |= ((tail >> 1) & 0x01) << 15;
  zc |= ((tail >> 2) & 0x01) << 15;
  za |= ((tail >> 3) & 0x01) << 31;
  zb |= ((tail >> 4) & 0x01) << 31;
  zc |= ((tail >> 5) & 0x01) << 31;

  q[0 * stride] = za;
  q[1 * stride] = zb;
  q[2 * stride] = zc;
}
|
||||
|
||||
// Dequantizes thirty-two shuffled 3-bit values from q_0, q_1, q_2 into sixteen half2
// pairs, subtracting `zero` from each. Fields are OR-ed into the mantissa of fp16
// 1024.0 (0x6400); fields left at bit 3 or bit 6 are rescaled by 1/8 or 1/64 in the fused
// multiply-add that also removes the 1024 bias. The final pair is rebuilt from bits 15
// and 31 of the three words. `stride` is unused.
__forceinline__ __device__ void dequant_3bit_32(
    const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) {
  const uint32_t kMagic = 0x64006400;

  const half2 scale8 = __half2half2(__float2half_rn(1.0f / 8.0f));
  const half2 scale64 = __half2half2(__float2half_rn(1.0f / 64.0f));

  const half_uint16 bias1_bits(0xe400 | zero);  // half(-1024.0f - zero);
  const half2 bias1 = __half2half2(bias1_bits.as_half);
  const half2 bias8 = __half2half2(__hsub(__int2half_rn(-128), __int2half_rn(zero)));
  const half2 bias64 = __half2half2(__hsub(__int2half_rn(-16), __int2half_rn(zero)));

  // Each word is consumed at shift 0 and at shift 6.
  const uint32_t a0 = q_0;
  const uint32_t a6 = q_0 >> 6;
  const uint32_t b0 = q_1;
  const uint32_t b6 = q_1 >> 6;
  const uint32_t c0 = q_2;
  const uint32_t c6 = q_2 >> 6;

  half2_uint32 p0((a0 & 0x00070007) | kMagic);   // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 p1((a0 & 0x00380038) | kMagic);   // half2(q[ 2], q[ 3]) *  8 + 1024
  half2_uint32 p2((a6 & 0x00070007) | kMagic);   // half2(q[ 4], q[ 5])      + 1024
  half2_uint32 p3((a6 & 0x00380038) | kMagic);   // half2(q[ 6], q[ 7]) *  8 + 1024
  half2_uint32 p4((a6 & 0x01c001c0) | kMagic);   // half2(q[ 8], q[ 9]) * 64 + 1024
  half2_uint32 p5((b0 & 0x00070007) | kMagic);   // half2(q[10], q[11])      + 1024
  half2_uint32 p6((b0 & 0x00380038) | kMagic);   // half2(q[12], q[13]) *  8 + 1024
  half2_uint32 p7((b6 & 0x00070007) | kMagic);   // half2(q[14], q[15])      + 1024
  half2_uint32 p8((b6 & 0x00380038) | kMagic);   // half2(q[16], q[17]) *  8 + 1024
  half2_uint32 p9((b6 & 0x01c001c0) | kMagic);   // half2(q[18], q[19]) * 64 + 1024
  half2_uint32 p10((c0 & 0x00070007) | kMagic);  // half2(q[20], q[21])      + 1024
  half2_uint32 p11((c0 & 0x00380038) | kMagic);  // half2(q[22], q[23]) *  8 + 1024
  half2_uint32 p12((c6 & 0x00070007) | kMagic);  // half2(q[24], q[25])      + 1024
  half2_uint32 p13((c6 & 0x00380038) | kMagic);  // half2(q[26], q[27]) *  8 + 1024
  half2_uint32 p14((c6 & 0x01c001c0) | kMagic);  // half2(q[28], q[29]) * 64 + 1024

  // Bits 15/31 of each word are moved to bit 0, 1 or 2 of the matching 16-bit lane and
  // merged into the last pair.
  const uint32_t bit0 = (a6 >> 9) & 0x00010001;
  const uint32_t bit1 = (b6 >> 8) & 0x00020002;
  const uint32_t bit2 = (c6 >> 7) & 0x00040004;
  half2_uint32 p15((bit0 | bit1 | bit2) | kMagic);  // half2(q[30], q[31]) + 1024

  dq[0] = __hadd2(p0.as_half2, bias1);
  dq[1] = __hfma2(p1.as_half2, scale8, bias8);
  dq[2] = __hadd2(p2.as_half2, bias1);
  dq[3] = __hfma2(p3.as_half2, scale8, bias8);
  dq[4] = __hfma2(p4.as_half2, scale64, bias64);
  dq[5] = __hadd2(p5.as_half2, bias1);
  dq[6] = __hfma2(p6.as_half2, scale8, bias8);
  dq[7] = __hadd2(p7.as_half2, bias1);
  dq[8] = __hfma2(p8.as_half2, scale8, bias8);
  dq[9] = __hfma2(p9.as_half2, scale64, bias64);
  dq[10] = __hadd2(p10.as_half2, bias1);
  dq[11] = __hfma2(p11.as_half2, scale8, bias8);
  dq[12] = __hadd2(p12.as_half2, bias1);
  dq[13] = __hfma2(p13.as_half2, scale8, bias8);
  dq[14] = __hfma2(p14.as_half2, scale64, bias64);
  dq[15] = __hadd2(p15.as_half2, bias1);
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
114
sgl-kernel/csrc/gemm/gptq/qdq_4.cuh
Normal file
114
sgl-kernel/csrc/gemm/gptq/qdq_4.cuh
Normal file
@@ -0,0 +1,114 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
// Rearranges the eight 4-bit values in q[0] in place: even-indexed values are packed into
// the low 16 bits and odd-indexed values into the high 16 bits, giving the layout
// 77775555 33331111 66664444 22220000. `stride` is unused.
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
  const uint32_t packed = q[0];
  uint32_t even_half = 0;
  uint32_t odd_half = 0;

#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    // Each byte holds one even-indexed value (low nibble) and one odd-indexed value (high nibble).
    const uint32_t byte = (packed >> (pair * 8)) & 0xff;
    even_half |= (byte & 0x0f) << (pair * 4);
    odd_half |= (byte >> 4) << (pair * 4);
  }
  q[0] = even_half | (odd_half << 16);
}
|
||||
|
||||
// Dequantizes eight shuffled 4-bit values from q_0 into four half2 pairs, subtracting
// `zero` from each. Fields are OR-ed into the mantissa of fp16 1024.0 (0x6400); fields left
// at bit 4 are rescaled by 1/16 in the fused multiply-add that also removes the 1024 bias.
// `stride` is unused.
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) {
  const uint32_t kMagic = 0x64006400;

  const half2 scale16 = __half2half2(__float2half_rn(1.0f / 16.0f));

  const half_uint16 bias1_bits(0xe400 | zero);  // half(-1024.0f - zero);
  const half2 bias1 = __half2half2(bias1_bits.as_half);
  const half2 bias16 = __half2half2(__hsub(__int2half_rn(-64), __int2half_rn(zero)));

  const uint32_t lo = q_0;       // fields for q[0..3]
  const uint32_t hi = q_0 >> 8;  // fields for q[4..7]

  half2_uint32 p0((lo & 0x000f000f) | kMagic);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 p1((lo & 0x00f000f0) | kMagic);  // half2(q[ 2], q[ 3]) * 16 + 1024
  half2_uint32 p2((hi & 0x000f000f) | kMagic);  // half2(q[ 4], q[ 5])      + 1024
  half2_uint32 p3((hi & 0x00f000f0) | kMagic);  // half2(q[ 6], q[ 7]) * 16 + 1024

  dq[0] = __hadd2(p0.as_half2, bias1);
  dq[1] = __hfma2(p1.as_half2, scale16, bias16);
  dq[2] = __hadd2(p2.as_half2, bias1);
  dq[3] = __hfma2(p3.as_half2, scale16, bias16);
}
|
||||
|
||||
// Precomputes the scaled offsets (z1z16) and scaled multipliers (y1y16) consumed by
// dequant_4bit_8_gptq when `scaled` is true. Index 0 is for fields at bit 0, index 1 for
// fields at bit 4 (which need the extra 1/16 factor).
__forceinline__ __device__ void
dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
  const half2 scale_pair = __half2half2(scale);

  // Offsets: scale * (-1024 - zero) and scale * (-64 - zero).
  half_uint16 bias1_bits(0xe400 | zero);  // half(-1024.0f - zero);
  const half bias16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  z1z16[0] = __hmul2(scale_pair, __half2half2(bias1_bits.as_half));
  z1z16[1] = __hmul2(scale_pair, __half2half2(bias16));

  // Multipliers: scale * 1 and scale * 1/16.
  const half one = __float2half_rn(1.0f);
  const half sixteenth = __float2half_rn(1.0f / 16.0f);
  y1y16[0] = __hmul2(scale_pair, __half2half2(one));
  y1y16[1] = __hmul2(scale_pair, __half2half2(sixteenth));
}
|
||||
|
||||
// Precomputes the unscaled offsets (z1z16) and multipliers (y1y16) consumed by
// dequant_4bit_8_gptq. Index 0 is for fields at bit 0, index 1 for fields at bit 4.
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
  // Offsets: (-1024 - zero) and (-64 - zero).
  half_uint16 bias1_bits(0xe400 | zero);  // half(-1024.0f - zero);
  const half bias16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  z1z16[0] = __half2half2(bias1_bits.as_half);
  z1z16[1] = __half2half2(bias16);

  // Multipliers: 1 and 1/16.
  const half one = __float2half_rn(1.0f);
  const half sixteenth = __float2half_rn(1.0f / 16.0f);
  y1y16[0] = __half2half2(one);
  y1y16[1] = __half2half2(sixteenth);
}
|
||||
|
||||
// Dequantizes eight shuffled 4-bit values from q_0 into four half2 pairs using the
// precomputed offsets z1z16 and multipliers y1y16. With `scaled` set, every pair goes
// through a fused multiply-add; otherwise the bit-0 pairs only need the offset added.
// `stride` is unused.
__forceinline__ __device__ void
dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) {
  const uint32_t kMagic = 0x64006400;

  const uint32_t lo = q_0;
  const uint32_t hi = q_0 >> 8;

  half2_uint32 p0((lo & 0x000f000f) | kMagic);  // half2( q[0]      + 1024, q[1]      + 1024 )
  half2_uint32 p1((lo & 0x00f000f0) | kMagic);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
  half2_uint32 p2((hi & 0x000f000f) | kMagic);  // half2( q[4]      + 1024, q[5]      + 1024 )
  half2_uint32 p3((hi & 0x00f000f0) | kMagic);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

  if (!scaled) {
    dq[0] = __hadd2(p0.as_half2, z1z16[0]);            // half2( q[0] - z, q[1] - z )
    dq[1] = __hfma2(p1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] - z, q[3] - z )
    dq[2] = __hadd2(p2.as_half2, z1z16[0]);            // half2( q[4] - z, q[5] - z )
    dq[3] = __hfma2(p3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )
    return;
  }

  dq[0] = __hfma2(p0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s )
  dq[1] = __hfma2(p1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s )
  dq[2] = __hfma2(p2.as_half2, y1y16[0], z1z16[0]);
  dq[3] = __hfma2(p3.as_half2, y1y16[1], z1z16[1]);
}
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
30
sgl-kernel/csrc/gemm/gptq/qdq_8.cuh
Normal file
30
sgl-kernel/csrc/gemm/gptq/qdq_8.cuh
Normal file
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_8_cuh
|
||||
#define _qdq_8_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
// Intentionally a no-op: the body is empty, so the words behind `q` are left as they are.
// Both `q` and `stride` are unused.
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
|
||||
|
||||
// Dequantizes eight 8-bit values (four bytes from q_0, then four from q_1) into four half2
// pairs, subtracting `zero` from each. No scale is applied here. `stride` is unused.
__forceinline__ __device__ void
dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) {
  for (int pair = 0; pair < 2; pair++) {
    const int lo_shift = pair * 16;
    const int hi_shift = lo_shift + 8;
    // dq[0..1] come from q_0, dq[2..3] from q_1; lower byte of each pair goes first.
    dq[pair] = __halves2half2(dq_ns(exb(q_0, lo_shift, 0xff), zero), dq_ns(exb(q_0, hi_shift, 0xff), zero));
    dq[pair + 2] = __halves2half2(dq_ns(exb(q_1, lo_shift, 0xff), zero), dq_ns(exb(q_1, hi_shift, 0xff), zero));
  }
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
53
sgl-kernel/csrc/gemm/gptq/qdq_util.cuh
Normal file
53
sgl-kernel/csrc/gemm/gptq/qdq_util.cuh
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
// Overlays a half2 and a uint32_t so a bit pattern built with integer ops can be read
// back as a pair of fp16 values (and vice versa).
union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;

  // Initialize from a raw 32-bit pattern.
  __device__ half2_uint32(uint32_t bits) : as_uint32(bits) {}

  // Initialize from an existing half2 value.
  __device__ half2_uint32(half2 value) : as_half2(value) {}
};
|
||||
|
||||
// Overlays a half and a uint16_t so a bit pattern built with integer ops can be read back
// as an fp16 value (and vice versa).
union half_uint16 {
  uint16_t as_uint16;
  half as_half;

  // Initialize from a raw 16-bit pattern.
  __device__ half_uint16(uint16_t bits) : as_uint16(bits) {}

  // Initialize from an existing half value.
  __device__ half_uint16(half value) : as_half(value) {}
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
// Decodes a quantized scale: returns (qs + 1)^2 * max_scale as a half.
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  const int level = qs + 1;
  const half squared = __int2half_rn(level * level);
  return __hmul(squared, max_scale);
}
|
||||
|
||||
// Dequantizes one value: (q - qzero) * scale, with the subtraction done in integers.
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) {
  const int centered = q - qzero;
  return __hmul(__int2half_rn(centered), scale);
}
|
||||
|
||||
// Dequantizes one value without a scale: converts (q - qzero) to half.
// The subtraction is done in integers before the conversion.
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  const int centered = q - qzero;
  return __int2half_rn(centered);
}
|
||||
|
||||
// Extracts a bit field from a single word: shifts q right by `shift` and applies `mask`.
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) {
  const uint32_t shifted = q >> shift;
  return (int)(shifted & mask);
}
|
||||
|
||||
// Extracts a bit field that may span two words: funnel-shifts the q1:q0 pair right by
// `shift` (q0 supplies the low word) and applies `mask` to the result.
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) {
  const uint32_t shifted = __funnelshift_rc(q0, q1, shift);
  return (int)(shifted & mask);
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
747
sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
Normal file
747
sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
Normal file
@@ -0,0 +1,747 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
||||
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
|
||||
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Runs one CUTLASS int8 x int8 -> int32 tensor-op GEMM whose epilogue visitor
// (EpilogueVisitorPerRowPerCol) receives scales_a and scales_b, and writes the result to `out`.
//
// Template parameters select the output element type, the target architecture tag, the
// threadblock / warp / instruction tile shapes and the number of pipeline stages.
//
// Arguments:
//   out      - output tensor; its row stride is taken from out.stride(0).
//   mat_a    - int8 A operand, read as row-major with leading dimension mat_a.stride(0).
//   mat_b    - int8 B operand, read as column-major with leading dimension mat_b.stride(1).
//   scales_a - float scales for A, scales_b - float scales for B (both passed with stride 0).
//              NOTE(review): presumably one scale per row of A / per column of B - confirm
//              against the callers.
//   bias     - optional bias; when absent a null pointer is handed to the kernel.
//
// Raises (via TORCH_CHECK) if CUTLASS reports that the problem cannot be implemented with
// this configuration, or if the launch returns a non-success status. No dtype, shape or
// contiguity validation is performed here.
template <
    typename ElementOutput,
    typename ArchTag,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    int NumStages>
void cutlass_int8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

  using DefaultGemmConf = cutlass::gemm::device::
      DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  // Baseline kernel; only its Mma and Epilogue pieces are reused below.
  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA,
      cutlass::layout::RowMajor,
      DefaultGemmConf::kAlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      DefaultGemmConf::kAlignmentB,
      ElementOutput,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOutputOp,
      ThreadblockSwizzle,
      NumStages,
      true,
      typename DefaultGemmConf::Operator>::GemmKernel;

  // Tile iterator over ElementCompute values that reuses the output iterator's thread map.
  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
          cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape,
      GemmKernel_::kThreadCount,
      AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator,
      ElementAccumulator,
      ElementCompute,
      EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::
      EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  // Problem size: A is (m, k), B is (k, n).
  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  // Bias is passed with a leading dimension of 0; the pointer stays null when no bias is given.
  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  // Note the order of the scale arguments: scales_b first, then scales_a.
  typename Gemm::Arguments args{
      {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

  // Scratch buffer sized by CUTLASS, allocated on the same device as mat_a.
  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
|
||||
|
||||
// Picks the threadblock / warp tile configuration for the int8 scaled GEMM on the SM75
// path from the number of rows of mat_a, then forwards all arguments to
// cutlass_int8_scaled_mm. Every configuration uses 2 pipeline stages.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Warp32x64x64 = cutlass::gemm::GemmShape<32, 64, 64>;
  using Warp64x64x64 = cutlass::gemm::GemmShape<64, 64, 64>;

  const int rows = mat_a.size(0);

  if (rows <= 32) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<32, 128, 64>,
        Warp32x64x64,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  if (rows <= 64) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        Warp64x64x64,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  if (rows <= 256) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 128>,
        Warp64x64x64,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  // More than 256 rows.
  cutlass_int8_scaled_mm<
      ElementOutput,
      ArchTag,
      cutlass::gemm::GemmShape<128, 128, 64>,
      Warp64x64x64,
      InstructionShape,
      2>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Picks the threadblock / warp tile configuration and stage count for the int8 scaled
// GEMM on the SM80 path from the number of rows of mat_a (m) and the number of columns
// of mat_b (n), then forwards all arguments to cutlass_int8_scaled_mm.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Warp16x64x64 = cutlass::gemm::GemmShape<16, 64, 64>;
  using Warp32x64x64 = cutlass::gemm::GemmShape<32, 64, 64>;
  using Warp64x64x64 = cutlass::gemm::GemmShape<64, 64, 64>;

  const int rows = mat_a.size(0);
  const int cols = mat_b.size(1);
  const bool narrow = cols <= 4096;

  if (rows <= 16) {
    // Same tiles for both widths; only the stage count differs.
    if (narrow) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          Warp16x64x64,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          Warp16x64x64,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return;
  }

  if (rows <= 32) {
    // Same tiles for both widths; only the stage count differs.
    if (narrow) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          Warp32x64x64,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          Warp32x64x64,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return;
  }

  if (rows <= 64) {
    if (narrow) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          Warp32x64x64,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          Warp64x64x64,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return;
  }

  if (rows <= 128 && cols < 8192) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        Warp64x64x64,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  // Everything else: more than 128 rows, or 65..128 rows with n >= 8192.
  cutlass_int8_scaled_mm<
      ElementOutput,
      ArchTag,
      cutlass::gemm::GemmShape<128, 128, 64>,
      Warp64x64x64,
      InstructionShape,
      5>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
//
// Picks the threadblock tile, warp tile and stage count passed to
// cutlass_int8_scaled_mm from m = mat_a.size(0) and n = mat_b.size(1), then
// forwards all tensors unchanged. The table is one flat if / else-if chain;
// rows are tested top to bottom, so each row only states the upper bound of
// its m bucket (the lower bound is implied by the rows above it).
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm89_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using cutlass::gemm::GemmShape;

  const int m = mat_a.size(0);
  const int n = mat_b.size(1);

  if (m <= 16 && n <= 8192) {
    // m in [.., 16], n <= 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<16, 64, 128>, GemmShape<16, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 16) {
    // m in [.., 16], n > 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<16, 128, 128>, GemmShape<16, 64, 64>, InstructionShape, 4>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 32 && n <= 8192) {
    // m in (16, 32], n <= 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 32) {
    // m in (16, 32], n > 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<32, 128, 128>, GemmShape<32, 64, 64>, InstructionShape, 4>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64 && n <= 8192) {
    // m in (32, 64], n <= 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<64, 64, 128>, GemmShape<32, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64) {
    // m in (32, 64], n > 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<64, 128, 128>, GemmShape<64, 64, 64>, InstructionShape, 3>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 128 && n <= 8192) {
    // m in (64, 128], n <= 8192
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<64, 128, 128>, GemmShape<32, 64, 64>, InstructionShape, 3>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 128 && n <= 16384) {
    // m in (64, 128], n in (8192, 16384]
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<128, 128, 64>, GemmShape<64, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 128) {
    // m in (64, 128], n > 16384
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<64, 64, 128>, GemmShape<32, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256 && n <= 4096) {
    // m in (128, 256], n <= 4096
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<64, 128, 128>, GemmShape<64, 64, 64>, InstructionShape, 3>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256 && n <= 8192) {
    // m in (128, 256], n in (4096, 8192]
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<128, 128, 64>, GemmShape<64, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256 && n <= 16384) {
    // m in (128, 256], n in (8192, 16384]
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<256, 128, 64>, GemmShape<64, 64, 64>, InstructionShape, 3>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (128, 256], n > 16384
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<128, 128, 64>, GemmShape<64, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    // m > 256, any n
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, GemmShape<32, 64, 128>, GemmShape<16, 64, 64>, InstructionShape, 5>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
|
||||
|
||||
// Runs the SM90 CUTLASS 3.x int8 x int8 -> int32 GEMM with a fused scaling
// epilogue and writes the result into `out`.
//
// Epilogue, as composed by the EVT aliases below (acc is the int32 accumulator):
//   WithBias == false:  out = scales_a * (scales_b * acc)
//   WithBias == true:   out = scales_a * (scales_b * acc) + bias
// scales_a is fed through a column broadcast (XScale), scales_b and bias
// through row broadcasts (WScale / Bias).
//
// Template parameters:
//   ElementOutput        - element type of `out` (and of `bias` when present).
//   TileShape            - CTA tile shape handed to both collective builders.
//   ClusterShape         - thread block cluster shape.
//   MainloopScheduleType - mainloop kernel schedule tag.
//   WithBias             - selects the epilogue tree; `bias` is only read when true.
//
// Arguments:
//   out      - destination, addressed as a packed (m, n) matrix.
//   mat_a    - int8 data addressed as packed row-major (m, k).
//   mat_b    - int8 data addressed as packed column-major (k, n).
//   scales_a - float scale pointer consumed by XScale.
//   scales_b - float scale pointer consumed by WScale.
//   bias     - optional bias; must hold a value when WithBias is true.
//
// Raises (via TORCH_CHECK) when the bias is missing for WithBias, when CUTLASS
// reports the problem cannot be implemented, or when the launch fails.
template <
    typename ElementOutput,
    typename TileShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    bool WithBias>
void cutlass_int8_scaled_mm_sm90(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ArchTag = cutlass::arch::Sm90;

  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  // All alignments are expressed in elements per 128-bit access.
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
  using TileSchedulerType = cutlass::gemm::PersistentScheduler;

  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;

  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;

  using Bias = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // Scale
  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  using Compute1 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentC,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  // The mainloop stage count is derived automatically after reserving the
  // epilogue's shared storage.
  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
      sizeof(typename CollectiveEpilogue::SharedStorage))>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementInputA,
      cutlass::layout::RowMajor,
      AlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      Stages,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  // No source (C) tensor is supplied (nullptr below), so its stride stays default-constructed.
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {a_ptr, stride_a, b_ptr, stride_b},
      {{},  // epilogue.thread
       nullptr,
       stride_c,
       o_ptr,
       stride_d}};

  // The nested initializers mirror the EVT tree: {XScale, {WScale, Accum, Compute0}, [Bias,] Compute}.
  if constexpr (WithBias) {
    // Dereferencing an empty optional is undefined behaviour; fail loudly instead.
    TORCH_CHECK(bias.has_value(), "bias must be provided when WithBias is true");
    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {bias_ptr},
        {},
    };
  } else {
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {},
    };
  }

  // Scratch buffer for CUTLASS, allocated on the same device as mat_a.
  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
|
||||
|
||||
// Turns the runtime presence of `bias` into the compile-time WithBias flag of
// cutlass_int8_scaled_mm_sm90 and forwards every argument unchanged.
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  if (!bias) {
    // No bias supplied: instantiate the scale-only epilogue.
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  // Bias supplied: instantiate the scale-and-add epilogue.
  cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
      out, mat_a, mat_b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Chooses the SM90 tile shape, cluster shape and mainloop schedule from
// m = mat_a.size(0) and n = mat_b.size(1), then hands off to
// sm90_dispatch_bias. Rows of the table are tested top to bottom, so each row
// only states the upper bound of its m bucket.
template <typename ElementOutput>
void sm90_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  // Mainloop schedules used by the table.
  using WarpSpecialized = cutlass::gemm::KernelTmaWarpSpecialized;
  using PingPong = cutlass::gemm::KernelTmaWarpSpecializedPingpong;

  // CTA tile shapes used by the table.
  using Tile64x64x128 = Shape<_64, _64, _128>;
  using Tile64x64x256 = Shape<_64, _64, _256>;
  using Tile64x128x128 = Shape<_64, _128, _128>;
  using Tile128x128x128 = Shape<_128, _128, _128>;

  // Cluster shapes used by the table.
  using Cluster1x8 = Shape<_1, _8, _1>;
  using Cluster1x4 = Shape<_1, _4, _1>;
  using Cluster1x1 = Shape<_1, _1, _1>;
  using Cluster2x1 = Shape<_2, _1, _1>;

  const int m = mat_a.size(0);
  const int n = mat_b.size(1);

  if (m <= 32 && n < 8192) {
    sm90_dispatch_bias<ElementOutput, Tile64x64x128, Cluster1x8, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 32) {
    // m <= 32, n >= 8192
    sm90_dispatch_bias<ElementOutput, Tile64x128x128, Cluster1x8, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64 && n < 8192) {
    sm90_dispatch_bias<ElementOutput, Tile64x64x128, Cluster1x4, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64) {
    // 32 < m <= 64, n >= 8192
    sm90_dispatch_bias<ElementOutput, Tile64x64x256, Cluster1x1, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 128 && n <= 4096) {
    sm90_dispatch_bias<ElementOutput, Tile64x64x128, Cluster2x1, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 128) {
    // 64 < m <= 128, n > 4096
    sm90_dispatch_bias<ElementOutput, Tile64x128x128, Cluster2x1, WarpSpecialized>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    // m > 128: the only row that uses the ping-pong schedule.
    sm90_dispatch_bias<ElementOutput, Tile128x128x128, Cluster2x1, PingPong>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
|
||||
|
||||
// Int8 GEMM with fused dequantization: validates the inputs, allocates the
// output and dispatches to a CUTLASS kernel chosen by the SM version reported
// by getSMVersion().
//
// Arguments:
//   mat_a     - 2D int8 CUDA tensor, row major (stride(1) == 1); size(1) % 16 == 0.
//   mat_b     - 2D int8 CUDA tensor, column major (stride(0) == 1);
//               size(0) == mat_a.size(1), size(0) % 16 == 0, size(1) % 8 == 0.
//   scales_a  - contiguous float32 tensor with mat_a.size(0) elements.
//   scales_b  - contiguous float32 tensor with mat_b.size(1) elements.
//   out_dtype - torch::kHalf or torch::kBFloat16 (only kHalf for 75 <= SM < 80).
//   bias      - optional contiguous tensor of dtype out_dtype with mat_b.size(1) elements.
//
// Returns a new {mat_a.size(0), mat_b.size(1)} tensor of dtype out_dtype on
// mat_a's device.
//
// Raises (via TORCH_CHECK) on any violated precondition, and
// TORCH_CHECK_NOT_IMPLEMENTED when no kernel exists for the SM version.
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  // Raw data pointers of every operand are handed to one kernel launch, so
  // they must all live on the same device.
  TORCH_CHECK(mat_b.device() == mat_a.device(), "mat_b must be on the same device as mat_a");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");  // out.stride(0)
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.device() == mat_a.device(), "scales_a must be on the same device as mat_a");
  TORCH_CHECK(scales_b.device() == mat_a.device(), "scales_b must be on the same device as mat_a");
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->device() == mat_a.device(), "bias must be on the same device as mat_a");
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  // An output with no elements needs no kernel launch.
  if (out.numel() == 0) {
    return out;
  }

  auto sm_version = getSMVersion();

  if (sm_version >= 75 && sm_version < 80) {
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (sm_version >= 80 && sm_version < 90) {
    // sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
    if (sm_version == 86 || sm_version == 89) {
      if (out_dtype == torch::kBFloat16) {
        sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
            out, mat_a, mat_b, scales_a, scales_b, bias);
      } else {
        sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
            out, mat_a, mat_b, scales_a, scales_b, bias);
      }
    } else {
      if (out_dtype == torch::kBFloat16) {
        sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
            out, mat_a, mat_b, scales_a, scales_b, bias);
      } else {
        sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
            out, mat_a, mat_b, scales_a, scales_b, bias);
      }
    }
  } else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    // cutlass 3.x
    if (out_dtype == torch::kBFloat16) {
      sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#else
    // fallback to cutlass 2.x
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#endif
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
  }

  return out;
}
|
||||
253
sgl-kernel/csrc/gemm/marlin/awq_marlin_repack.cu
Normal file
253
sgl-kernel/csrc/gemm/marlin/awq_marlin_repack.cu
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
namespace marlin {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

// Stub compiled for device targets below compute capability 8.0: it keeps the
// symbol available but performs no work and leaves out_ptr untouched.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
  return;
}

#else

// Repacks AWQ-packed quantized weights into the tile-ordered layout written to
// out_ptr.
//
// Template parameters:
//   num_threads - not referenced inside the body.
//   num_bits    - bits per quantized value; 32 / num_bits values share one uint32.
//
// Arguments:
//   b_q_weight_ptr - source words, indexed as [k][size_n / pack_factor].
//   out_ptr        - destination words, one tile_k_size x tile_n_size tile after another.
//   size_k, size_n - logical weight dimensions, in unpacked elements.
//
// Launch contract visible in this body:
//   - blockIdx.x selects a contiguous range of k tiles; blocks whose range
//     starts past the last k tile exit immediately.
//   - dynamic shared memory `sh` holds the staged tiles, stage_size int4
//     entries per pipeline slot.
//   - only warps 0..3 write output (warp_id >= 4 skips the repack step), but
//     every thread still takes part in the barrier after each stage.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;
  // One row of a tile, measured in packed uint32 words.
  constexpr int tile_n_ints = tile_n_size / pack_factor;
  // Each staging thread moves one int4 (four packed words).
  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int k_tiles_per_block = div_ceil(k_tiles, gridDim.x);

  auto first_k_tile = blockIdx.x * k_tiles_per_block;
  if (first_k_tile >= k_tiles) {
    return;
  }
  int last_k_tile = min(first_k_tile + k_tiles_per_block, k_tiles);

  extern __shared__ int4 sh[];

  // Wait until the next thread tile has been loaded to shared memory.
  // We only have `stages - 2` active fetches since we are double buffering
  // and can only issue the next fetch when it is guaranteed that the previous
  // shared memory load is fully complete (as it may otherwise be
  // overwritten).
  auto wait_stage = [&]() {
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  // Issue the asynchronous copy of tile (k_tile_id, n_tile_id) into pipeline
  // slot `pipe`. A fence is committed on every call, including calls for
  // tiles past the end and for threads that copy nothing.
  auto stage_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id < n_tiles && threadIdx.x < stage_size) {
      int n_begin = n_tile_id * tile_n_size;
      int n_begin_packed = n_begin / pack_factor;
      int k_begin = k_tile_id * tile_k_size;

      int4* slot = sh + stage_size * pipe;

      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      cp_async4(
          &slot[k_id * stage_n_threads + n_id],
          reinterpret_cast<int4 const*>(
              &(b_q_weight_ptr[(k_begin + k_id) * (size_n / pack_factor) + n_begin_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  // Read the staged tile in slot `pipe`, undo the source interleaving and
  // write the re-interleaved words for tile (k_tile_id, n_tile_id).
  auto emit_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    auto warp_id = threadIdx.x / 32;
    auto th_id = threadIdx.x % 32;

    // Tiles past the end and warps beyond the first four produce no output.
    if (n_tile_id >= n_tiles || warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    // Column handled by this thread, split into its packed word and its slot
    // inside that word.
    int col = warp_id * 16 + tc_col;
    int col_word = col / pack_factor;
    int col_slot = col % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    uint32_t* stage_words = reinterpret_cast<uint32_t*>(sh + stage_size * pipe);

    // Undo interleaving
    int slot_in_word;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      slot_in_word = undo_pack[col_slot];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      slot_in_word = undo_pack[col_slot];
    }
    int shift = slot_in_word * num_bits;

    // vals[0..3] come from this column, vals[4..7] from the column 8 further on.
    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int row = tc_row + tc_offsets[i];

      int word_lo = stage_words[col_word + sh_stride * row];
      int word_hi = stage_words[col_word + (8 / pack_factor) + sh_stride * row];

      vals[i] = (word_lo >> shift) & mask;
      vals[4 + i] = (word_hi >> shift) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t packed = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        packed |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = packed;
    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t packed_a = 0;
      uint32_t packed_b = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        packed_a |= vals[pack_idx[i]] << (i * 8);
        packed_b |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = packed_a;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = packed_b;
    }
  };

#pragma unroll
  for (int k_tile_id = first_k_tile; k_tile_id < last_k_tile; k_tile_id++) {
    // Prologue: fill the first repack_stages - 1 slots, then wait for the
    // first of them.
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      stage_tile(pipe, k_tile_id, pipe);
    }
    wait_stage();

    // Steady state: prefetch the tile repack_stages - 1 ahead, repack the
    // current one, then wait for the next slot to become readable.
    for (int n_tile_id = 0; n_tile_id < n_tiles; n_tile_id += repack_stages) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        stage_tile((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
        emit_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_stage();
      }
    }
  }
}

#endif

}  // namespace marlin
|
||||
|
||||
// Branch of the num_bits dispatch chain inside awq_marlin_repack. For a
// matching NUM_BITS it raises the kernel's dynamic shared memory limit to
// max_shared_mem, launches the repack kernel on `stream`, and turns a failed
// attribute update or a failed launch into a TORCH_CHECK error instead of
// silently continuing.
#define CALL_IF(NUM_BITS)                                                                                      \
  else if (num_bits == NUM_BITS) {                                                                             \
    cudaError_t attr_status = cudaFuncSetAttribute(                                                            \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,                                    \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                           \
        max_shared_mem);                                                                                       \
    TORCH_CHECK(                                                                                               \
        attr_status == cudaSuccess,                                                                            \
        "awq_marlin_repack: cudaFuncSetAttribute failed: ",                                                    \
        cudaGetErrorString(attr_status));                                                                      \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>                                         \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
    cudaError_t launch_status = cudaGetLastError();                                                            \
    TORCH_CHECK(                                                                                               \
        launch_status == cudaSuccess,                                                                          \
        "awq_marlin_repack: kernel launch failed: ",                                                           \
        cudaGetErrorString(launch_status));                                                                    \
  }

// Repacks an AWQ-packed weight tensor into the marlin tile layout.
//
// Arguments:
//   b_q_weight - contiguous CUDA tensor of dtype kInt with shape
//                {size_k, size_n / (32 / num_bits)}.
//   size_k     - must be divisible by marlin::tile_k_size.
//   size_n     - must be divisible by marlin::tile_n_size.
//   num_bits   - 4 or 8.
//
// Returns a new tensor of the same dtype and device with shape
// {size_k / marlin::tile_size, size_n * marlin::tile_size / (32 / num_bits)}.
// The kernel is enqueued on the current CUDA stream of b_q_weight's device;
// no synchronization is performed here.
//
// Raises (via TORCH_CHECK) on invalid arguments, on a failed device attribute
// query, and on a failed kernel configuration or launch.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(
      size_k % marlin::tile_k_size == 0,
      "size_k = ",
      size_k,
      " is not divisible by tile_k_size = ",
      marlin::tile_k_size);
  TORCH_CHECK(
      size_n % marlin::tile_n_size == 0,
      "size_n = ",
      size_n,
      " is not divisible by tile_n_size = ",
      marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k, "b_q_weight.size(0) = ", b_q_weight.size(0), " is not size_k = ", size_k);
  TORCH_CHECK(
      (size_n / pack_factor) == b_q_weight.size(1),
      "Shape mismatch: b_q_weight.size(1) = ",
      b_q_weight.size(1),
      ", size_n = ",
      size_n,
      ", pack_factor = ",
      pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  // One block per multiprocessor; a failed query would otherwise leave the
  // grid size undefined.
  int blocks = 0;
  cudaError_t query_status = cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(
      query_status == cudaSuccess,
      "awq_marlin_repack: failed to query multiprocessor count: ",
      cudaGetErrorString(query_status));
  TORCH_CHECK(blocks > 0, "awq_marlin_repack: device reports no multiprocessors");

  int max_shared_mem = 0;
  query_status = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(
      query_status == cudaSuccess,
      "awq_marlin_repack: failed to query shared memory size: ",
      cudaGetErrorString(query_status));
  TORCH_CHECK(max_shared_mem > 0);

  // The empty leading branch lets every CALL_IF expand to an `else if`.
  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}
|
||||
459
sgl-kernel/csrc/gemm/marlin/dequant.h
Normal file
459
sgl-kernel/csrc/gemm/marlin/dequant.h
Normal file
@@ -0,0 +1,459 @@
|
||||
/*
|
||||
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
|
||||
|
||||
The process of fast dequantization can be summarized as a combination
|
||||
of bitwise operations and floating-point computations:
|
||||
|
||||
weight =>(bit_op / bitwise operations)=>
|
||||
f16_value =>(flop / floating-point computation)=>
|
||||
dequantized_weight
|
||||
|
||||
Since the dequantized weights typically require subtracting the zero point and
|
||||
applying a scale factor, the floating-point computation step can be fused with
|
||||
the zero-point subtraction and scaling operations.
|
||||
|
||||
The following are the parts that need to be modified for the fused operation
|
||||
of zero-point subtraction and scaling.
|
||||
|
||||
## INT4 => FP16/BF16 or INT8 => FP16
|
||||
|
||||
The floating-point computation is `__hsub2`
|
||||
|
||||
If there are zero points:
|
||||
|
||||
flop(bit_op(weight)) - flop(bit_op(zp))
|
||||
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
|
||||
= bit_op(weight) - bit_op(zp)
|
||||
|
||||
so we don't need additional modification.
|
||||
|
||||
If there are float zero points:
|
||||
|
||||
flop(bit_op(weight)) - fzp
|
||||
= sub(bit_op(weight), bias) - fzp
|
||||
= bit_op(weight) - (fzp + bias)
|
||||
|
||||
where the `fzp + bias` can be computed at weight loading. But this
|
||||
may have accuracy issue, so we should not use this in most cases.
|
||||
|
||||
If there are no zero points:
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(sub(bit_op(weight), bias))
|
||||
= scale(bit_op(weight)) - scale(bias)
|
||||
= fma(bit_op(weight), scale_factor, -scale(bias))
|
||||
|
||||
where the `scale(bias)` can be cached. But this may have accuracy issue,
|
||||
so we should not use this in most cases.
|
||||
|
||||
|
||||
## INT8 => BF16
|
||||
|
||||
INT8 => BF16 is a special case: it uses byte_perm instead of flop.
|
||||
We cannot fuse byte_perm with scaling.
|
||||
|
||||
|
||||
## FP4/FP8 => FP16/BF16
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(mul(bit_op(weight), multiplier))
|
||||
= mul(bit_op(weight), scale_factor * multiplier)
|
||||
|
||||
where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
*/
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
// Three-input bitwise operation selected by the compile-time lookup table
// `lut`, issued as an explicit `lop3.b32` instruction. It is spelled out in
// inline assembly for dequantization because the compiler does not seem to
// recognize the pattern automatically in all cases.
//
//   a, b, c - the three 32-bit operands.
//   returns - the 32-bit result of applying `lut` to (a, b, c).
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int combined;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(combined) : "r"(a), "r"(b), "r"(c), "n"(lut));
  return combined;
}
|
||||
|
||||
// Builds the destination register by taking bytes from 2 sources (based on
// mask): issues `prmt.b32` with `a` as the register operand and the
// compile-time immediates `start_byte` and `mask` as the remaining operands.
//
//   a       - 32-bit register source.
//   returns - the permuted 32-bit value.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t permuted;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(permuted) : "r"(a), "n"(start_byte), "n"(mask));
  return permuted;
}
|
||||
|
||||
template <typename scalar_t2, sglang::ScalarTypeId w_type_id, bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
// skip_flop == true variant: only the bitwise step is performed. Each selected
// nibble is OR-ed into the mantissa of the fp16 pattern 0x6400 and stored
// as-is; no subtraction or scaling follows in this function.
//   frag_b[0] <- nibbles at bits 0-3 and 16-19 of q
//   frag_b[1] <- nibbles at bits 4-7 and 20-23 of q
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), true>(int q, half2* frag_b) {
  const int NIBBLE_MASK = 0x000f000f;
  const int FP16_EXPONENT = 0x64006400;

  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int low_pair = lop3<(0xf0 & 0xcc) | 0xaa>(q, NIBBLE_MASK, FP16_EXPONENT);
  int high_pair = lop3<(0xf0 & 0xcc) | 0xaa>(q >> 4, NIBBLE_MASK, FP16_EXPONENT);

  frag_b[0] = *reinterpret_cast<half2*>(&low_pair);
  frag_b[1] = *reinterpret_cast<half2*>(&high_pair);
}
|
||||
|
||||
// Full dequantization of 4-bit values with the symmetric zero point of 8 fused
// into the floating-point step.
//   frag_b[0] <- nibbles at bits 0-3 and 16-19 of q, via one __hsub2
//   frag_b[1] <- nibbles at bits 4-7 and 20-23 of q, via one __hfma2; these
//                are masked in place (not shifted down), and the multiply by
//                MUL compensates for that.
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), false>(int q, half2* frag_b) {
  const int LOW_NIBBLES = 0x000f000f;
  const int HIGH_NIBBLES = 0x00f000f0;
  const int FP16_EXPONENT = 0x64006400;

  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int low_pair  = lop3<(0xf0 & 0xcc) | 0xaa>(q, LOW_NIBBLES, FP16_EXPONENT);
  int high_pair = lop3<(0xf0 & 0xcc) | 0xaa>(q, HIGH_NIBBLES, FP16_EXPONENT);
  // clang-format on

  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  const half2 low_biased = *reinterpret_cast<half2*>(&low_pair);
  const half2 high_biased = *reinterpret_cast<half2*>(&high_pair);

  frag_b[0] = __hsub2(low_biased, *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(high_biased, *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
}
|
||||
|
||||
// U4 -> fp16 without the arithmetic correction (skip_flop == true).
// Forwards to the U4B8 skip_flop variant: the raw nibble extraction is the
// same for both weight types, and no offset is subtracted on this path.
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
// U4 -> fp16: plain unsigned nibbles, no zero point is applied.
//   frag_b[0] = nibbles at bits 0-3 / 16-19 of q
//   frag_b[1] = nibbles at bits 4-7 / 20-23 of q
template <>
__device__ inline void dequant<half2, sglang::kU4.id(), false>(int q, half2* frag_b) {
  const int low_nibbles = 0x000f000f;
  const int high_nibbles = 0x00f000f0;
  const int fp16_1024 = 0x64006400;
  // Immediate LUT selecting `(a & b) | c`, so each extraction is one LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  int lo = lop3<and_or_lut>(q, low_nibbles, fp16_1024);   // lanes hold 1024 + n
  int hi = lop3<and_or_lut>(q, high_nibbles, fp16_1024);  // lanes hold 1024 + 16 * n

  // Only the 1024 bias of the magic number is removed; unlike the U4B8
  // variant there is no extra 8 folded into these constants.
  const int lo_offset = 0x64006400;  // fp16 1024
  const int hi_scale = 0x2c002c00;   // fp16 1/16
  const int hi_offset = 0xd400d400;  // fp16 -64 = -(1024 / 16)

  const half2 lo_h2 = *reinterpret_cast<half2*>(&lo);
  const half2 hi_h2 = *reinterpret_cast<half2*>(&hi);

  frag_b[0] = __hsub2(lo_h2, *reinterpret_cast<const half2*>(&lo_offset));
  frag_b[1] = __hfma2(
      hi_h2, *reinterpret_cast<const half2*>(&hi_scale), *reinterpret_cast<const half2*>(&hi_offset));
}
|
||||
|
||||
// U4B8 -> bf16 without the zero-point correction (skip_flop == true).
// Every 16-bit lane of the output is `0x4300 | nibble`, i.e. the bf16 value
// 128 + nibble. frag_b[0] uses bits 0-3 / 16-19 of `q`, frag_b[1] uses bits
// 4-7 / 20-23.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t nibble_mask = 0x000f000f;
  static constexpr uint32_t bf16_128 = 0x43004300;
  // Immediate LUT selecting `(a & b) | c`, so each extraction is one LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  int first = lop3<and_or_lut>(q, nibble_mask, bf16_128);
  int second = lop3<and_or_lut>(q >> 4, nibble_mask, bf16_128);

  frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&first);
  frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&second);
}
|
||||
|
||||
// U4B8 -> bf16 with the zero point removed: runs the skip_flop extraction
// (lanes hold 128 + nibble) and then subtracts 136 = 128 + 8 from every lane.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);

  static constexpr uint32_t bias_plus_zero_point = 0x43084308;  // bf16 136 in both lanes
  const nv_bfloat162 offset = *reinterpret_cast<const nv_bfloat162*>(&bias_plus_zero_point);

  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
|
||||
|
||||
// U4 -> bf16 without the arithmetic correction (skip_flop == true).
// Forwards to the U4B8 skip_flop variant, which performs the same raw nibble
// extraction and subtracts nothing.
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
// U4 -> bf16: runs the skip_flop extraction (lanes hold 128 + nibble) and
// subtracts the 128 bias of the magic number, leaving the unsigned nibble.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), false>(int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, sglang::kU4.id(), true>(q, frag_b);

  static constexpr uint32_t magic_bias = 0x43004300;  // bf16 128 in both lanes
  const nv_bfloat162 offset = *reinterpret_cast<const nv_bfloat162*>(&magic_bias);

  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
|
||||
|
||||
//
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
||||
// bf16 Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
// U8B128 -> fp16 without the offset correction (skip_flop == true).
// Each output lane is the byte pattern 0x64xx, i.e. fp16 1024 + byte.
// With the selectors below, frag_b[0] pairs bytes 0 and 2 of `q` and
// frag_b[1] pairs bytes 1 and 3.
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), true>(int q, half2* frag_b) {
  static constexpr uint32_t fp16_high_bytes = 0x64646464;
  static constexpr uint32_t select_bytes_0_2 = 0x5250;
  static constexpr uint32_t select_bytes_1_3 = 0x5351;

  uint32_t first = prmt<fp16_high_bytes, select_bytes_0_2>(q);
  uint32_t second = prmt<fp16_high_bytes, select_bytes_1_3>(q);

  frag_b[0] = *reinterpret_cast<half2*>(&first);
  frag_b[1] = *reinterpret_cast<half2*>(&second);
}
|
||||
|
||||
// U8B128 -> fp16 with the 128 offset removed: runs the skip_flop extraction
// (lanes hold 1024 + byte) and subtracts 1152 = 1024 + 128 from every lane.
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), false>(int q, half2* frag_b) {
  dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);

  static constexpr uint32_t bias_plus_zero_point = 0x64806480;  // fp16 1152 in both lanes
  const half2 offset = *reinterpret_cast<const half2*>(&bias_plus_zero_point);

  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
|
||||
|
||||
// U8 -> fp16 without the arithmetic correction (skip_flop == true).
// Forwards to the U8B128 skip_flop variant, which performs the same byte
// extraction and subtracts nothing.
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
// U8 -> fp16: runs the skip_flop extraction (lanes hold 1024 + byte) and
// subtracts the 1024 bias of the magic number, leaving the unsigned byte.
template <>
__device__ inline void dequant<half2, sglang::kU8.id(), false>(int q, half2* frag_b) {
  dequant<half2, sglang::kU8.id(), true>(q, frag_b);

  static constexpr uint32_t magic_bias = 0x64006400;  // fp16 1024 in both lanes
  const half2 offset = *reinterpret_cast<const half2*>(&magic_bias);

  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
|
||||
|
||||
// U8B128 -> bf16 via an fp32 detour.
// Each byte of `q` is placed in the low byte of the fp32 pattern 0x4B000000
// (2^23), giving 8388608 + byte. Subtracting 8388736 (= 2^23 + 128) yields
// byte - 128, and the upper 16 bits of each fp32 result are then packed as
// bf16. frag_b[0] receives bytes 0 and 2 of `q`, frag_b[1] bytes 1 and 3.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t fp32_base = 0x4B000000;
  // Selector i inserts one byte of `q` under the upper three bytes of fp32_base.
  constexpr uint32_t byte_selectors[4] = {0x7650, 0x7652, 0x7651, 0x7653};

  float as_float[4];
  uint32_t* as_bits = reinterpret_cast<uint32_t*>(as_float);

#pragma unroll
  for (int i = 0; i < 4; i++) {
    as_bits[i] = __byte_perm(q, fp32_base, byte_selectors[i]);
  }

#pragma unroll
  for (int i = 0; i < 4; i++) {
    as_float[i] -= 8388736.f;
  }

  // Keep the upper half of each fp32 value; two of them form one bf16 pair.
  uint32_t* packed_out = reinterpret_cast<uint32_t*>(frag_b);
  packed_out[0] = __byte_perm(as_bits[0], as_bits[1], 0x7632);
  packed_out[1] = __byte_perm(as_bits[2], as_bits[3], 0x7632);
}
|
||||
|
||||
// U8 -> bf16 via an fp32 detour.
// Each byte of `q` is placed in the low byte of the fp32 pattern 0x4B000000
// (2^23), giving 8388608 + byte. Subtracting 8388608 leaves the unsigned
// byte, and the upper 16 bits of each fp32 result are then packed as bf16.
// frag_b[0] receives bytes 0 and 2 of `q`, frag_b[1] bytes 1 and 3.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8.id(), false>(int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t fp32_base = 0x4B000000;
  // Selector i inserts one byte of `q` under the upper three bytes of fp32_base.
  constexpr uint32_t byte_selectors[4] = {0x7650, 0x7652, 0x7651, 0x7653};

  float as_float[4];
  uint32_t* as_bits = reinterpret_cast<uint32_t*>(as_float);

#pragma unroll
  for (int i = 0; i < 4; i++) {
    as_bits[i] = __byte_perm(q, fp32_base, byte_selectors[i]);
  }

#pragma unroll
  for (int i = 0; i < 4; i++) {
    as_float[i] -= 8388608.f;
  }

  // Keep the upper half of each fp32 value; two of them form one bf16 pair.
  uint32_t* packed_out = reinterpret_cast<uint32_t*>(frag_b);
  packed_out[0] = __byte_perm(as_bits[0], as_bits[1], 0x7632);
  packed_out[1] = __byte_perm(as_bits[2], as_bits[3], 0x7632);
}
|
||||
|
||||
// FP8 (E4M3) -> fp16 bit re-layout without exponent-bias correction
// (skip_flop == true). The sign bit of each byte stays at bit 15 of its lane;
// the remaining 7 bits are shifted right by the exponent-width difference so
// they line up with the fp16 exponent/mantissa fields.
// frag_b[1] comes from the upper byte of each 16-bit lane of `q`,
// frag_b[0] from the lower byte.
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), true>(int q, half2* frag_b) {
  constexpr int kFp8ExpBits = 4, kFp16ExpBits = 5;
  constexpr int kShift = kFp16ExpBits - kFp8ExpBits;
  constexpr int kMagnitudeMask = 0x7F007F00;

  const int low_bytes_up = q << 8;  // lower byte of each lane moved into the upper byte

  int from_upper = (q & 0x80008000) | ((q & kMagnitudeMask) >> kShift);
  int from_lower = (low_bytes_up & 0x80008000) | ((low_bytes_up & kMagnitudeMask) >> kShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&from_upper);
  frag_b[0] = *reinterpret_cast<const half2*>(&from_lower);
}
|
||||
|
||||
// FP8 (E4M3) -> fp16 including the exponent-bias correction: performs the
// skip_flop bit re-layout and then multiplies both outputs by
// 2^(fp16 bias term - fp8 bias term) = 2^(16 - 8).
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), false>(int q, half2* frag_b) {
  dequant<half2, sglang::kFE4M3fn.id(), true>(q, frag_b);

  constexpr int kFp8ExpBits = 4, kFp16ExpBits = 5;
  constexpr int kBiasGap = (1 << (kFp16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));

  const half2 rescale = __float2half2_rn(static_cast<float>(1 << kBiasGap));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
|
||||
|
||||
// FP8 (E4M3) -> bf16 bit re-layout without exponent-bias correction
// (skip_flop == true). The sign bit of each byte stays at bit 15 of its lane;
// the remaining 7 bits are shifted right by the exponent-width difference so
// they line up with the bf16 exponent/mantissa fields.
// frag_b[1] comes from the upper byte of each 16-bit lane of `q`,
// frag_b[0] from the lower byte.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {
  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kShift = kBf16ExpBits - kFp8ExpBits;
  constexpr int kMagnitudeMask = 0x7F007F00;

  const int low_bytes_up = q << 8;  // lower byte of each lane moved into the upper byte

  int from_upper = (q & 0x80008000) | ((q & kMagnitudeMask) >> kShift);
  int from_lower = (low_bytes_up & 0x80008000) | ((low_bytes_up & kMagnitudeMask) >> kShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_upper);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_lower);
}
|
||||
|
||||
// FP8 (E4M3) -> bf16 including the exponent-bias correction: performs the
// skip_flop bit re-layout and then multiplies both outputs by
// 2^(bf16 bias term - fp8 bias term) = 2^(128 - 8).
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(q, frag_b);

  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kBiasGap = (1 << (kBf16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));

  // 2^kBiasGap does not fit an int shift, so the scale is assembled directly
  // as an fp32 bit pattern: biased exponent (kBiasGap + 127) at bit 23.
  constexpr uint32_t kScaleBits = (kBiasGap + 127) << 23;
  const nv_bfloat162 rescale = __float2bfloat162_rn(*reinterpret_cast<const float*>(&kScaleBits));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
|
||||
|
||||
// FP4 (E2M1) -> fp16 bit re-layout without exponent-bias correction
// (skip_flop == true). The sign bit stays at bit 15 of its lane; the three
// magnitude bits are shifted right by the exponent-width difference so they
// line up with the fp16 exponent/mantissa fields.
// frag_b[1] comes from the nibble at bits 12-15 of each 16-bit lane of `q`,
// frag_b[0] from the nibble at bits 8-11.
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), true>(int q, half2* frag_b) {
  constexpr int kFp4ExpBits = 2, kFp16ExpBits = 5;
  constexpr int kShift = kFp16ExpBits - kFp4ExpBits;
  constexpr int kMagnitudeMask = 0x70007000;

  const int next_nibble_up = q << 4;  // nibble at bits 8-11 moved to bits 12-15

  int from_top = (q & 0x80008000) | ((q & kMagnitudeMask) >> kShift);
  int from_next = (next_nibble_up & 0x80008000) | ((next_nibble_up & kMagnitudeMask) >> kShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&from_top);
  frag_b[0] = *reinterpret_cast<const half2*>(&from_next);
}
|
||||
|
||||
// FP4 (E2M1) -> fp16 including the exponent-bias correction: performs the
// skip_flop bit re-layout and then multiplies both outputs by
// 2^(fp16 bias term - fp4 bias term) = 2^(16 - 2).
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), false>(int q, half2* frag_b) {
  dequant<half2, sglang::kFE2M1f.id(), true>(q, frag_b);

  constexpr int kFp4ExpBits = 2, kFp16ExpBits = 5;
  constexpr int kBiasGap = (1 << (kFp16ExpBits - 1)) - (1 << (kFp4ExpBits - 1));

  const half2 rescale = __float2half2_rn(static_cast<float>(1 << kBiasGap));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
|
||||
|
||||
// FP4 (E2M1) -> bf16 bit re-layout without exponent-bias correction
// (skip_flop == true). The sign bit stays at bit 15 of its lane; the three
// magnitude bits are shifted right by the exponent-width difference so they
// line up with the bf16 exponent/mantissa fields.
// frag_b[1] comes from the nibble at bits 12-15 of each 16-bit lane of `q`,
// frag_b[0] from the nibble at bits 8-11.
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {
  constexpr int kFp4ExpBits = 2, kBf16ExpBits = 8;
  constexpr int kShift = kBf16ExpBits - kFp4ExpBits;
  constexpr int kMagnitudeMask = 0x70007000;

  const int next_nibble_up = q << 4;  // nibble at bits 8-11 moved to bits 12-15

  int from_top = (q & 0x80008000) | ((q & kMagnitudeMask) >> kShift);
  int from_next = (next_nibble_up & 0x80008000) | ((next_nibble_up & kMagnitudeMask) >> kShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_top);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_next);
}
|
||||
|
||||
// FP4 (E2M1) -> bf16 including the exponent-bias correction: performs the
// skip_flop bit re-layout and then multiplies both outputs by
// 2^(bf16 bias term - fp4 bias term) = 2^(128 - 2).
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(q, frag_b);

  constexpr int kFp4ExpBits = 2, kBf16ExpBits = 8;
  constexpr int kBiasGap = (1 << (kBf16ExpBits - 1)) - (1 << (kFp4ExpBits - 1));

  // 2^kBiasGap does not fit an int shift, so the scale is assembled directly
  // as an fp32 bit pattern: biased exponent (kBiasGap + 127) at bit 23.
  constexpr uint32_t kScaleBits = (kBiasGap + 127) << 23;
  const nv_bfloat162 rescale = __float2bfloat162_rn(*reinterpret_cast<const float*>(&kScaleBits));

  // Scale the bfloat162 pairs in place.
  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
|
||||
|
||||
// Primary template for converting packed 8-bit scale values to the compute
// dtype; only the half2 and nv_bfloat162 specializations below are defined.
// Each specialization unpacks the four bytes of `q` into frag_b[0] and
// frag_b[1].
template <typename scalar_t2>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
// Packed 8-bit scales -> half2 by pure bit movement: the upper byte of every
// 16-bit lane is isolated and shifted right by one bit before the word is
// reinterpreted as half2. frag_b[1] uses the upper bytes of `q` as stored,
// frag_b[0] uses the lower bytes (moved up by 8 bits first).
// NOTE(review): the 8-bit scale encoding this layout targets is not stated
// in this file - confirm against the caller.
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
  const int low_bytes_up = q << 8;

  int from_upper = (q & 0xFF00FF00) >> 1;
  int from_lower = (low_bytes_up & 0xFF00FF00) >> 1;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&from_upper);
  frag_b[0] = *reinterpret_cast<const half2*>(&from_lower);
}
|
||||
|
||||
// Packed 8-bit scales -> nv_bfloat162 by pure bit movement. For the upper
// byte of every 16-bit lane, the top bit is shifted right by one and the
// remaining 7 bits are shifted right by the bf16/fp8 exponent-width
// difference; the two parts are OR-ed together. frag_b[1] uses the upper
// bytes of `q` as stored, frag_b[0] the lower bytes (moved up by 8 bits).
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {
  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kShift = kBf16ExpBits - kFp8ExpBits;
  constexpr int kLowSevenMask = 0x7F007F00;

  const int low_bytes_up = q << 8;

  int from_upper = ((q & 0x80008000) >> 1) | ((q & kLowSevenMask) >> kShift);
  int from_lower = ((low_bytes_up & 0x80008000) >> 1) | ((low_bytes_up & kLowSevenMask) >> kShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_upper);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_lower);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
1120
sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu
Normal file
1120
sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu
Normal file
File diff suppressed because it is too large
Load Diff
329
sgl-kernel/csrc/gemm/marlin/gptq_marlin_repack.cu
Normal file
329
sgl-kernel/csrc/gemm/marlin/gptq_marlin_repack.cu
Normal file
@@ -0,0 +1,329 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
namespace marlin {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// Stub compiled for device architectures below SM80, where the cp.async
// helpers used by the real kernel are not available. It keeps the symbol
// defined for every target but performs no work and never writes `out_ptr`.
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr,
    uint32_t* __restrict__ out_ptr,
    int size_k,
    int size_n) {
  // Intentionally empty.
}
|
||||
#else
|
||||
// Repacks a GPTQ weight matrix into Marlin's tiled layout.
//
// Inputs
//   b_q_weight_ptr: packed weights, indexed as [packed_k * size_n + n] with
//                   32 / num_bits values per uint32 along k.
//   perm_ptr:       act_order row permutation; only read when has_perm.
//   out_ptr:        output, written one (tile_k_size x tile_n_size) tile at a
//                   time at offset (k_tile * n_tiles + n_tile) * tile_size.
//   size_k, size_n: logical matrix dimensions (the host wrapper checks that
//                   they are multiples of the tile sizes).
//
// Work split: the k tiles are divided evenly across gridDim.x blocks; each
// block walks all n tiles of its k tiles through a cp.async pipeline of
// `repack_stages` shared-memory stages. Only the first four warps of a block
// take part in the actual repacking.
//
// Dynamic shared memory: `perm_size` int4 for the permutation (has_perm
// only), followed by the pipeline stages.
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr,
    uint32_t* __restrict__ out_ptr,
    int size_k,
    int size_n) {
  // Compile-time geometry.
  constexpr int pack_factor = 32 / num_bits;
  constexpr int perm_size = tile_k_size / 4;
  constexpr int tile_ints = tile_k_size / pack_factor;
  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  const int k_tiles = size_k / tile_k_size;
  const int n_tiles = size_n / tile_n_size;
  const int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  // Range of k tiles owned by this block; surplus blocks exit immediately.
  const auto start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }
  const int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  extern __shared__ int4 sh[];
  int4* sh_perm_ptr = sh;
  // The pipeline stages start right after the permutation slot, if any.
  int4* sh_pipe_ptr = has_perm ? sh_perm_ptr + perm_size : sh_perm_ptr;

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  // Copies the tile_k_size permutation entries of one k tile into shared memory.
  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;
    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  // Issues the async copy of one (k tile, n tile) into pipeline stage `pipe`.
  // A commit group is created on every call, even past the last n tile, so
  // the wait_group accounting in wait_for_stage stays uniform.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id < n_tiles && threadIdx.x < stage_size) {
      int first_n = n_tile_id * tile_n_size;
      int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      if constexpr (has_perm) {
        // One shared-memory row per logical k row; the source row is looked
        // up through the permutation.
        uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      } else {
        // One shared-memory row per packed k row of the tile.
        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  // Re-orders the values of one staged tile and writes them to global memory.
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    auto warp_id = threadIdx.x / 32;
    auto th_id = threadIdx.x % 32;

    // Nothing to do past the last n tile; warps 4 and up do not repack.
    if (n_tile_id >= n_tiles || warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // vals[0..3]: column cur_n; vals[4..7]: column cur_n + 8.
    uint32_t vals[8];

    if constexpr (has_perm) {
      uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        // Position of the wanted value inside its packed source word.
        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];

        vals[i] = (b1_val >> (src_k_pos * num_bits)) & mask;
        vals[4 + i] = (b2_val >> (src_k_pos * num_bits)) & mask;
      }
    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      // Eight 4-bit values are interleaved into a single output word.
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;
    } else {
      // Eight 8-bit values are interleaved into two output words.
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  // Primes the pipeline with the first `repack_stages - 1` fetches.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, 0);

    // Steady state: per stage, issue the fetch that is `repack_stages - 1`
    // tiles ahead, repack the current tile, then wait for the next stage.
    for (int n_tile_id = 0; n_tile_id < n_tiles; n_tile_id += repack_stages) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
    }
  }
}
|
||||
#endif
|
||||
} // namespace marlin
|
||||
|
||||
// Dispatch arm for gptq_marlin_repack: when the runtime (num_bits, has_perm)
// pair matches the compile-time (NUM_BITS, HAS_PERM) pair, raises the dynamic
// shared-memory limit of the matching kernel instantiation and launches it.
// Expands to an `else if` arm, so it must follow an `if` / `else if`.
// Relies on these names in the enclosing scope: num_bits, has_perm, blocks,
// max_shared_mem, stream, b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n.
// The cudaFuncSetAttribute status is checked so that a rejected shared-memory
// opt-in is reported instead of surfacing later as a failed launch.
#define CALL_IF(NUM_BITS, HAS_PERM)                                                                  \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                                           \
    cudaError_t call_if_attr_status_ = cudaFuncSetAttribute(                                         \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>,               \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                 \
        max_shared_mem);                                                                             \
    TORCH_CHECK(                                                                                     \
        call_if_attr_status_ == cudaSuccess,                                                         \
        "gptq_marlin_repack: cudaFuncSetAttribute failed: ",                                         \
        cudaGetErrorString(call_if_attr_status_));                                                   \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>                    \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(                                \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);                                      \
  }
|
||||
|
||||
// Repacks a GPTQ-packed weight tensor into the tiled layout written by
// marlin::gptq_marlin_repack_kernel.
//
// Arguments
//   b_q_weight: contiguous int32 CUDA tensor of shape
//               [size_k / pack_factor, size_n], pack_factor = 32 / num_bits.
//   perm:       contiguous int32 CUDA tensor. size(0) == 0 disables act_order
//               handling; otherwise it must hold at least size_k entries,
//               because the kernel reads tile_k_size of them per k tile.
//   size_k:     must be a multiple of marlin::tile_k_size.
//   size_n:     must be a multiple of marlin::tile_n_size.
//   num_bits:   4 or 8.
//
// Returns a new int32 tensor of shape
// [size_k / tile_size, size_n * tile_size / pack_factor] on b_q_weight's
// device. The kernel is enqueued on the current CUDA stream and is not
// synchronized here.
//
// Raises (TORCH_CHECK) on any violated precondition, on a failed CUDA
// attribute query, or on a kernel launch error.
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(
      size_k % marlin::tile_k_size == 0,
      "size_k = ",
      size_k,
      " is not divisible by tile_k_size = ",
      marlin::tile_k_size);
  TORCH_CHECK(
      size_n % marlin::tile_n_size == 0,
      "size_n = ",
      size_n,
      " is not divisible by tile_n_size = ",
      marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(
      (size_k / pack_factor) == b_q_weight.size(0),
      "Shape mismatch: b_q_weight.size(0) = ",
      b_q_weight.size(0),
      ", size_k = ",
      size_k,
      ", pack_factor = ",
      pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // With act_order the kernel loads tile_k_size permutation entries for every
  // k tile, i.e. size_k entries overall; a shorter perm would be read out of
  // bounds.
  TORCH_CHECK(
      !has_perm || perm.numel() >= size_k,
      "perm has ",
      perm.numel(),
      " entries but at least size_k = ",
      size_k,
      " are required");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  // One block per multiprocessor. Initialized and checked so that a failed
  // query can never launch with an indeterminate grid size.
  int blocks = 0;
  cudaError_t query_status = cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(
      query_status == cudaSuccess,
      "Failed to query the multiprocessor count: ",
      cudaGetErrorString(query_status));
  TORCH_CHECK(blocks > 0, "Invalid multiprocessor count: ", blocks);

  int max_shared_mem = 0;
  query_status = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(
      query_status == cudaSuccess,
      "Failed to query the opt-in shared memory limit: ",
      cudaGetErrorString(query_status));
  TORCH_CHECK(max_shared_mem > 0);

  // Dispatch to the kernel instantiation matching (num_bits, has_perm).
  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm);
  }

  // Launch-configuration errors are only reported through the last-error
  // state, so check it right after the launch.
  cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(
      launch_status == cudaSuccess,
      "gptq_marlin_repack kernel launch failed: ",
      cudaGetErrorString(launch_status));

  return out;
}
|
||||
36
sgl-kernel/csrc/gemm/marlin/kernel.h
Normal file
36
sgl-kernel/csrc/gemm/marlin/kernel.h
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
// Parameter list shared by every declaration of the `Marlin` kernel template;
// it is substituted verbatim between the parentheses of the kernel signature.
// NOTE(review): the parameter names suggest A/B/C are the GEMM operands and
// output, C_tmp a reduction scratch buffer, and prob_m/prob_n/prob_k the
// problem shape - confirm against the kernel definition, which is not part of
// this header.
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
// Declaration of the Marlin GEMM kernel template. Only the declaration lives
// in this header; the runtime arguments are given by MARLIN_KERNEL_PARAMS.
template <
|
||||
typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
96
sgl-kernel/csrc/gemm/marlin/marlin.cuh
Normal file
96
sgl-kernel/csrc/gemm/marlin/marlin.cuh
Normal file
@@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
// Marlin params
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||
// we want relatively few warps to have many registers per warp and small tiles.
|
||||
// Default number of threads per block (256 threads = 8 warps of 32).
static constexpr int default_threads = 256;
|
||||
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
// NOTE(review): presumably bounds on the per-threadblock tile extent along n
// and k used when picking a launch configuration - not used in this header,
// confirm against the GEMM launcher.
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
static constexpr int max_thread_n = 256;
|
||||
|
||||
// Base tile edge; tile_k_size and tile_n_size below are derived from it.
static constexpr int tile_size = 16;
|
||||
// NOTE(review): not used in this header; meaning to be confirmed at the use site.
static constexpr int max_par = 16;
|
||||
|
||||
// Repack params
|
||||
// Number of shared-memory pipeline stages used by the repack kernel.
static constexpr int repack_stages = 8;
|
||||
|
||||
// Threads per block used when launching the repack kernel.
static constexpr int repack_threads = 256;
|
||||
|
||||
// Repack tile shape: tile_k_size x tile_n_size = 16 x 64.
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
|
||||
// Helpers
|
||||
// Minimal fixed-size array wrapper for device code: `n` elements of type `T`
// stored inline and accessed with operator[]. No bounds checking is done.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable element access.
  __device__ T& operator[](int i) {
    return elems[i];
  }

  // Read-only element access, so that const Vec objects and const references
  // can be indexed as well.
  __device__ const T& operator[](int i) const {
    return elems[i];
  }
};

// Four packed ints.
using I4 = Vec<int, 4>;
|
||||
|
||||
// Integer division of `a` by `b`, rounded up (intended for a >= 0, b > 0).
constexpr int div_ceil(int a, int b) {
  const int biased_numerator = a + b - 1;
  return biased_numerator / b;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
// Predicated 16-byte asynchronous copy from global to shared memory
// (PTX cp.async.cg.shared.global). The copy is only issued when `pred` is
// true. `smem_ptr` is converted to a shared-space address before use.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  const int pred_flag = (int)pred;
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"(pred_flag),
      "r"(smem_addr),
      "l"(glob_ptr),
      "n"(kCopyBytes));
}
|
||||
|
||||
// Unconditional 16-byte asynchronous copy from global to shared memory
// (PTX cp.async.cg.shared.global). `smem_ptr` is converted to a shared-space
// address before use.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(glob_ptr),
      "n"(kCopyBytes));
}
|
||||
|
||||
// Closes the current batch of cp.async copies issued by this thread into one
// commit group (PTX cp.async.commit_group), which cp_async_wait can wait on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
|
||||
|
||||
// Issues the PTX `cp.async.wait_group` instruction with the template
// argument `n` passed as a compile-time immediate operand.
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
82
sgl-kernel/csrc/gemm/marlin/marlin_dtypes.cuh
Normal file
82
sgl-kernel/csrc/gemm/marlin/marlin_dtypes.cuh
Normal file
@@ -0,0 +1,82 @@
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "marlin.cuh"
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Primary template: deliberately empty. Only the explicit specializations
// carry type aliases and conversion helpers.
template <typename scalar_t>
class ScalarType {};

// Specialization for `half`.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  // half -> float.
  static __device__ inline float num2float(const half value) {
    return __half2float(value);
  }

  // Broadcasts one half into both lanes of a half2.
  static __device__ inline half2 num2num2(const half value) {
    return __half2half2(value);
  }

  // Packs two halves into a half2 via __halves2half2(first, second).
  static __device__ inline half2 nums2num2(const half first, const half second) {
    return __halves2half2(first, second);
  }

  // float -> half; the only helper here that is also callable from host code.
  static __host__ __device__ inline half float2num(const float value) {
    return __float2half(value);
  }
};
|
||||
|
||||
// Specialization for `nv_bfloat16`. The fragment aliases are always
// available; the conversion helpers are compiled only for host passes or
// device passes targeting __CUDA_ARCH__ >= 800.
template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // bfloat16 -> float.
  static __device__ inline float num2float(const nv_bfloat16 value) {
    return __bfloat162float(value);
  }

  // Broadcasts one bfloat16 into both lanes of a bfloat162.
  static __device__ inline nv_bfloat162 num2num2(const nv_bfloat16 value) {
    return __bfloat162bfloat162(value);
  }

  // Packs two bfloat16 values via __halves2bfloat162(first, second).
  static __device__ inline nv_bfloat162 nums2num2(const nv_bfloat16 first, const nv_bfloat16 second) {
    return __halves2bfloat162(first, second);
  }

  // float -> bfloat16; also callable from host code.
  static __host__ __device__ inline nv_bfloat16 float2num(const float value) {
    return __float2bfloat16(value);
  }
#endif
};
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
1629
sgl-kernel/csrc/gemm/marlin/marlin_template.h
Normal file
1629
sgl-kernel/csrc/gemm/marlin/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
// Returns the smallest power of two that is >= `num`. Inputs 0 and 1 are
// returned unchanged.
//
// Precondition: num <= 2^31. A larger value has no representable result in
// uint32_t and would require a shift by the full bit width.
inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  // Shift an unsigned 32-bit one: a plain `1` is a signed int, and inputs in
  // (2^30, 2^31] would shift it into the sign bit.
  return uint32_t{1} << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
|
||||
|
||||
// Division of `a` by `b`, rounded up by biasing the numerator with `b - 1`.
// The result type is whatever `(a + b - 1) / b` yields for A and B.
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  const auto biased = a + b - 1;
  return biased / b;
}
|
||||
|
||||
// Round a down to the nearest multiple of b (a itself when it already is
// one). The caller is responsible for making sure that b is non-zero.
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  if (a % b == 0) {
    return a;
  }
  return (a / b) * b;
}
|
||||
|
||||
// Round a up to the nearest multiple of b (a itself when it already is one).
// The caller is responsible for making sure that b is non-zero.
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  if (a % b == 0) {
    return a;
  }
  return ((a / b) + 1) * b;
}
|
||||
728
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
Normal file
728
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
Normal file
@@ -0,0 +1,728 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "nvfp4_quant.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output.
//
// The absolute maximum is reduced over this thread's values and those of the
// lane whose index differs in bit 0 (shuffle-xor with 1). From it a scale
// factor is derived, stored through `SFout` when that pointer is non-null,
// and used to rescale the values before fp32_vec_to_e2m1 packs them.
// UE8M0_SF selects an exponent-only scale encoding; otherwise the scale is
// encoded as __nv_fp8_e4m3.
//
// Compiled to `return 0;` unless __CUDA_ARCH__ >= 1000.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr int kPairsPerThread = CVT_FP4_ELTS_PER_THREAD / 2;

  // Element-wise absolute maximum over this thread's packed pairs.
  auto pairMax = __habs2(vec.elts[0]);
#pragma unroll
  for (int p = 1; p < kPairsPerThread; p++) {
    pairMax = __hmax2(pairMax, __habs2(vec.elts[p]));
  }

  // Combine with the neighbouring lane, then collapse the two halves of the
  // pair into one scalar maximum.
  pairMax = __hmax2(__shfl_xor_sync(uint32_t(-1), pairMax, 1), pairMax);
  float groupAbsMax = float(__hmax(pairMax.x, pairMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float scaleFactor = SFScaleVal * (groupAbsMax * reciprocal_approximate_ftz(6.0f));

  // 8-bit representation of the SF.
  uint8_t encodedSF;
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t exponentBits = reinterpret_cast<uint32_t&>(scaleFactor) >> 23;
    encodedSF = exponentBits & 0xff;
    // Convert back to fp32 so later math sees the rounded scale.
    reinterpret_cast<uint32_t&>(scaleFactor) = exponentBits << 23;
  } else {
    // Here the scale is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 asFp8 = __nv_fp8_e4m3(scaleFactor);
    reinterpret_cast<__nv_fp8_e4m3&>(encodedSF) = asFp8;
    // Convert back to fp32 so later math sees the rounded scale.
    scaleFactor = float(asFp8);
  }

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = encodedSF;
  }

  // Output scale = reciprocal(rounded_scale * reciprocal(SFScaleVal)), or
  // zero when the rounded scale is zero.
  float outputScale = 0.0f;
  if (scaleFactor != 0) {
    outputScale = reciprocal_approximate_ftz(scaleFactor * reciprocal_approximate_ftz(SFScaleVal));
  }

  // Convert the input to float and apply the output scale.
  float2 scaledVals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
  for (int p = 0; p < kPairsPerThread; p++) {
    if constexpr (std::is_same_v<Type, half>) {
      scaledVals[p] = __half22float2(vec.elts[p]);
    } else {
      scaledVals[p] = __bfloat1622float2(vec.elts[p]);
    }
    scaledVals[p].x *= outputScale;
    scaledVals[p].y *= outputScale;
  }

  // Convert to e2m1 values and hand them back to the caller.
  return fp32_vec_to_e2m1(scaledVals);
#else
  return 0;
#endif
}
|
||||
|
||||
// Computes val / (1 + exp(-val)) using the fast `__expf` intrinsic.
__device__ __forceinline__ float silu(const float& val) {
  const float denom = 1.0f + __expf(-val);
  return val / denom;
}
|
||||
|
||||
// In-place x_vec[i] = silu(x_vec[i]) * y_vec[i] for every packed pair.
// The math is done in float; the result is rounded back to the packed
// half2 / bfloat162 representation with the *_rn conversions.
template <class Type>
inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
  constexpr int kPairsPerThread = CVT_FP4_ELTS_PER_THREAD / 2;

#pragma unroll
  for (int p = 0; p < kPairsPerThread; p++) {
    float2 lhs;
    float2 rhs;

    // Widen both operands to float2.
    if constexpr (std::is_same_v<Type, half>) {
      lhs = __half22float2(x_vec.elts[p]);
      rhs = __half22float2(y_vec.elts[p]);
    } else {
      lhs = __bfloat1622float2(x_vec.elts[p]);
      rhs = __bfloat1622float2(y_vec.elts[p]);
    }

    lhs.x = silu(lhs.x) * rhs.x;
    lhs.y = silu(lhs.y) * rhs.y;

    // Narrow the product back into x_vec.
    if constexpr (std::is_same_v<Type, half>) {
      x_vec.elts[p] = __float22half2_rn(lhs);
    } else {
      x_vec.elts[p] = __float22bfloat162_rn(lhs);
    }
  }
}
|
||||
|
||||
// Expert-aware FP4 quantization kernel, low-latency overload (the overload
// with the trailing `low_latency` parameter). Use UE4M3 by default.
//
// Launch layout: 1-D grid of 1-D blocks. Every thread walks packed-vector
// indices tid, tid + gridDim.x * blockDim.x, ... over
// numRows * (numCols / CVT_FP4_ELTS_PER_THREAD) items. No shared memory.
//
// For each item the owning expert is found by scanning
// `input_offset_by_experts`, the packed input vector is loaded (and, when
// `mask` is non-null, combined with its partner vector via silu_and_mul), and
// the quantized word plus its scale factor are written to `out` / `SFout`.
//
// `low_latency` is not read by the body.
// The body is compiled only for __CUDA_ARCH__ >= 1000; otherwise the kernel
// is a no-op.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows,
    int32_t numCols,
    Type const* in,
    float const* SFScale,
    uint32_t* out,
    uint32_t* SFout,
    uint32_t* input_offset_by_experts,
    uint32_t* output_scale_offset_by_experts,
    int32_t* mask,
    int n_experts,
    bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

  // Input tensor row/col loops.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  // TODO(kaixih@nvidia): For now, we assume mask is used together with
  // silu_and_mul. Maybe we want a more general behavior of mask later. In the
  // silu case, the input last dim doubles.
  bool use_mask = mask != nullptr;
  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;

  // Scale-factor row width. It depends only on numCols, so compute it once
  // instead of once per item. The actual output_scales dim is computed from
  // the padded numCols.
  int const factor = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const numCols_padded = (numCols + factor - 1) / factor * factor;
  int const numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;

  // Each global thread processes one packed vector per iteration.
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
    // Calculate which row and column this global thread should process
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    // Find index within the experts using different strategies based on expert
    // count
    int rowIdx_in_expert = 0;
    int expert_idx = 0;

    if constexpr (SMALL_NUM_EXPERTS) {
      // Linear scan, two scalar loads per candidate expert.
      for (int i = 0; i < n_experts; i++) {
        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
        if (rowIdx >= current_offset && rowIdx < next_offset) {
          rowIdx_in_expert = rowIdx - current_offset;
          expert_idx = i;
          break;
        }
      }
    } else {
      // Load input offsets into registers first, then do the computation.
      // Local array size set to 17 because of register limit.
      // NOTE(review): every chunk unconditionally loads 17 consecutive
      // offsets with int4 accesses, so this path appears to assume the offset
      // array is 16-byte aligned and readable through index chunk_start + 16
      // -- confirm for n_experts that is not a multiple of 16.
      uint32_t local_offsets[17];
      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
        *reinterpret_cast<int4*>(local_offsets) =
            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start]));
        *reinterpret_cast<int4*>(local_offsets + 4) =
            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 4]));
        *reinterpret_cast<int4*>(local_offsets + 8) =
            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 8]));
        *reinterpret_cast<int4*>(local_offsets + 12) =
            __ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 12]));
        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);

        // Check against the 16 loaded offsets
#pragma unroll
        for (int i = 0; i < 16; i++) {
          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
            rowIdx_in_expert = rowIdx - local_offsets[i];
            expert_idx = chunk_start + i;
            break;
          }
        }
      }
    }

    // Early exit when using masks.
    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
      continue;
    }

    // Widen before multiplying: the product of two ints is otherwise
    // evaluated in 32 bits and can overflow before it reaches the int64_t.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * actualColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    if (use_mask) {
      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      silu_and_mul(in_vec, in_vec_mul);
    }

    // Get the output tensor offset.
    // Each packed input vector produces exactly one uint32_t of output.
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    // Get the global scaling factor, which will be applied to the SF.
    // Note SFScale is the same as next GEMM's alpha, which is
    // (448.f / (Alpha_A / 6.f)).
    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

    // Start of this expert's scale-factor region (64-bit pointer offset).
    uint32_t* SFout_in_expert =
        SFout + static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) * numCols_SFout;

    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
  }
#endif
}
|
||||
|
||||
// Masked per-expert FP4 quantization kernel. Use UE4M3 by default.
//
// Launch layout: 1-D grid of 1-D blocks. The gridDim.x * blockDim.x threads
// are split into `n_experts` contiguous groups (the first `remainder` groups
// get one extra thread). Each group strides over the
// m * (numCols / CVT_FP4_ELTS_PER_THREAD) packed vectors of its expert, where
// m = numRows / n_experts. No shared memory.
//
// When `mask` is non-null, a thread stops as soon as it reaches a row index
// (within its expert) that is >= mask[expert_idx].
// When `use_silu_and_mul` is set, each input row is read as two halves of
// colsPerRow packed vectors and combined with silu_and_mul before
// quantization.
//
// The body is compiled only for __CUDA_ARCH__ >= 1000; otherwise the kernel
// is a no-op.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
#else
cvt_fp16_to_fp4_expert(
#endif
    int32_t numRows,
    int32_t numCols,
    Type const* in,
    float const* SFScale,
    uint32_t* out,
    uint32_t* SFout,
    int32_t* mask,
    bool use_silu_and_mul,
    int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

  // Assign this thread to an expert and a slot within that expert's group.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = (gridDim.x * blockDim.x) / n_experts;
  int remainder = (gridDim.x * blockDim.x) % n_experts;
  int expert_idx;
  int tid_in_expert;
  int actual_stride;
  if (remainder > 0) {
    // The first `remainder` experts own stride + 1 threads each.
    int bound = remainder * (stride + 1);
    if (tid < bound) {
      expert_idx = tid / (stride + 1);
      tid_in_expert = tid % (stride + 1);
      actual_stride = stride + 1;
    } else {
      expert_idx = remainder + (tid - bound) / stride;
      tid_in_expert = (tid - bound) % stride;
      actual_stride = stride;
    }
  } else {
    expert_idx = tid / stride;
    tid_in_expert = tid % stride;
    actual_stride = stride;
  }
  int m = numRows / n_experts;
  // Rows per expert in the scale-factor buffer, rounded up to 128.
  int padded_m = (m + (128 - 1)) / 128 * 128;

  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  // TODO(kaixih@nvidia): For now, we assume mask is used together with
  // silu_and_mul. Maybe we want a more general behavior of mask later. In the
  // silu case, the input last dim doubles.
  bool use_mask = mask != nullptr;
  int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;

  // Per-thread invariants, computed once instead of once per item.
  // The actual output_scales dim is computed from the padded numCols.
  int const factor = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const numCols_padded = (numCols + factor - 1) / factor * factor;
  int const numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
  // Widen before multiplying so the pointer offset is formed in 64 bits.
  uint32_t* const SFout_in_expert = SFout + static_cast<int64_t>(expert_idx) * padded_m * numCols_SFout;
  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

  // Each thread processes one packed vector per iteration.
  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
       globalIdx += actual_stride) {
    // Calculate which row and column this global thread should process
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    // Find index within the experts
    int rowIdx_in_expert = rowIdx - expert_idx * m;

    // Early exit when using masks: row indices only grow along this loop, so
    // nothing further can be in range.
    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
      break;
    }

    // Widen before multiplying: the product of two ints is otherwise
    // evaluated in 32 bits and can overflow before it reaches the int64_t.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * actualColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    if (use_silu_and_mul) {
      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      silu_and_mul(in_vec, in_vec_mul);
    }

    // Get the output tensor offset.
    // Each packed input vector produces exactly one uint32_t of output.
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
  }
#endif
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
//
// Launch layout: 1-D grid of 1-D blocks; every thread walks packed-vector
// indices tid, tid + gridDim.x * blockDim.x, ... over
// numRows * (numCols / CVT_FP4_ELTS_PER_THREAD) items.
// Dynamic shared memory: at least (n_experts + 1) * sizeof(uint32_t) bytes,
// holding a block-local copy of `input_offset_by_experts`.
//
// For each item the owning expert is located with a binary search over the
// shared offsets, then the item is quantized exactly like in the low-latency
// overload.
//
// The body is compiled only for __CUDA_ARCH__ >= 1000; otherwise the kernel
// is a no-op.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows,
    int32_t numCols,
    Type const* in,
    float const* SFScale,
    uint32_t* out,
    uint32_t* SFout,
    uint32_t* input_offset_by_experts,
    uint32_t* output_scale_offset_by_experts,
    int32_t* mask,
    int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
  extern __shared__ uint32_t shared_input_offsets[];

  // Load the n_experts + 1 input offsets into shared memory.
  if constexpr (SMALL_NUM_EXPERTS) {
    // Few experts: plain scalar copy.
    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
      shared_input_offsets[i] = input_offset_by_experts[i];
    }
  } else {
    // Vectorized int4 copy for the leading full groups of four offsets only.
    // Copying a partial trailing group with int4 would read the global array
    // and write the shared buffer past their n_experts + 1 entries whenever
    // n_experts is not a multiple of 4.
    // NOTE(review): the int4 accesses assume input_offset_by_experts is
    // 16-byte aligned -- confirm against the allocator.
    int const n_vectorized = n_experts / 4 * 4;
    for (int i = threadIdx.x * 4; i < n_vectorized; i += blockDim.x * 4) {
      *reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
    }
    // Scalar tail: the remaining offsets up to and including index n_experts.
    for (int i = n_vectorized + threadIdx.x; i <= n_experts; i += blockDim.x) {
      shared_input_offsets[i] = input_offset_by_experts[i];
    }
  }

  // Every thread reaches this barrier; the offsets are visible afterwards.
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  // A non-null mask also means each input row holds two halves of
  // colsPerRow packed vectors that are combined with silu_and_mul.
  bool use_mask = mask != nullptr;
  int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;

  // Scale-factor row width. It depends only on numCols, so compute it once
  // instead of once per item.
  int const factor = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const numCols_padded = (numCols + factor - 1) / factor * factor;
  int const numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;

  // Each global thread processes one packed vector per iteration.
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
    // Calculate which row and column this global thread should process
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    // Find expert using binary search for better performance with large m_topk
    int rowIdx_in_expert = 0;
    int expert_idx = 0;

    // Binary search through experts using shared memory
    int left = 0, right = n_experts - 1;
    while (left <= right) {
      int mid = (left + right) / 2;
      // Get offsets: shared_input_offsets[i] corresponds to
      // input_offset_by_experts[i]
      uint32_t mid_offset = shared_input_offsets[mid];
      uint32_t next_offset = shared_input_offsets[mid + 1];

      if (rowIdx >= mid_offset && rowIdx < next_offset) {
        rowIdx_in_expert = rowIdx - mid_offset;
        expert_idx = mid;
        break;
      } else if (rowIdx < mid_offset) {
        right = mid - 1;
      } else {
        left = mid + 1;
      }
    }

    // Skip rows beyond the masked length of their expert.
    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
      continue;
    }

    // Widen before multiplying: the product of two ints is otherwise
    // evaluated in 32 bits and can overflow before it reaches the int64_t.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * actualColsPerRow + colIdx;

    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    if (use_mask) {
      PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      silu_and_mul(in_vec, in_vec_mul);
    }

    // Each packed input vector produces exactly one uint32_t of output.
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    // Per-expert global scale; 1.0 when no scale tensor is supplied.
    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

    // Start of this expert's scale-factor region (64-bit pointer offset).
    uint32_t* SFout_in_expert =
        SFout + static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) * numCols_SFout;

    auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
        rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
  }
#endif
}
|
||||
|
||||
// Host-side launcher for the FP4 expert quantization kernels.
//
// Picks a 1-D launch configuration from the amount of work
// (m_topk * k / ELTS_PER_THREAD packed vectors) and the SM count of the
// current device, then dispatches on `stream`:
//   * mask != nullptr            -> cvt_fp16_to_fp4_expert (grid rounded up to
//                                   a multiple of n_experts)
//   * more than one pass per     -> shared-memory overload of cvt_fp16_to_fp4,
//     thread is needed              with (n_experts + 1) * 4 bytes of dynamic
//                                   shared memory
//   * otherwise                  -> low-latency overload of cvt_fp16_to_fp4
//
// All pointers are forwarded to the kernels. `input_offset_by_experts` and
// `output_scale_offset_by_experts` are not used on the mask path.
//
// Returns immediately when there is nothing to quantize. Raises through
// TORCH_CHECK when a CUDA runtime query or a kernel launch fails, or when
// n_experts is not positive.
template <typename T>
void quant_impl(
    void* output,
    void* output_scale,
    void* input,
    void* input_global_scale,
    void* input_offset_by_experts,
    void* output_scale_offset_by_experts,
    void* mask,
    bool use_silu_and_mul,
    int m_topk,
    int k,
    int n_experts,
    cudaStream_t stream) {
  // Each thread converts ELTS_PER_THREAD values.
  int const workSizePerRow = k / ELTS_PER_THREAD;
  int const totalWorkSize = m_topk * workSizePerRow;
  // Empty input: launching would divide by a zero block size below and ask
  // for a zero-sized grid, so there is nothing to do.
  if (totalWorkSize <= 0) {
    return;
  }
  // n_experts is used as a divisor on the mask path and in the kernels.
  TORCH_CHECK(n_experts > 0, "n_experts must be positive");

  // TODO: this multiProcessorCount should be cached.
  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  TORCH_CHECK(status == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(status));
  int multiProcessorCount = 0;
  status = cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(status == cudaSuccess, "cudaDeviceGetAttribute failed: ", cudaGetErrorString(status));

  // Surfaces launch-configuration errors, which kernel launches do not report
  // on their own.
  auto const check_launch = []() {
    cudaError_t const err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fp4 expert quant kernel launch failed: ", cudaGetErrorString(err));
  };

  // Grid, Block size.
  dim3 block(std::min(workSizePerRow, 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM));
  // Trade block size for more blocks while the grid does not cover every SM.
  while (grid.x <= multiProcessorCount && block.x > 64) {
    grid.x *= 2;
    block.x = (block.x + 1) / 2;
  }

  // TODO(kaixih@nvidia): Should relax this to allow any grid size.
  if (mask != nullptr) {
    grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
    cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
        m_topk,
        k,
        reinterpret_cast<T*>(input),
        reinterpret_cast<float*>(input_global_scale),
        reinterpret_cast<uint32_t*>(output),
        reinterpret_cast<uint32_t*>(output_scale),
        reinterpret_cast<int32_t*>(mask),
        use_silu_and_mul,
        n_experts);
    check_launch();
    return;
  }

  // Number of passes each thread needs to cover all the work.
  int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
  if (blockRepeat > 1) {
    // Several passes per thread: stage the expert offsets in shared memory.
    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
    if (n_experts >= 4) {
      cvt_fp16_to_fp4<T, false, false><<<grid, block, shared_mem_size, stream>>>(
          m_topk,
          k,
          reinterpret_cast<T*>(input),
          reinterpret_cast<float*>(input_global_scale),
          reinterpret_cast<uint32_t*>(output),
          reinterpret_cast<uint32_t*>(output_scale),
          reinterpret_cast<uint32_t*>(input_offset_by_experts),
          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
          reinterpret_cast<int32_t*>(mask),
          n_experts);
    } else {
      cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
          m_topk,
          k,
          reinterpret_cast<T*>(input),
          reinterpret_cast<float*>(input_global_scale),
          reinterpret_cast<uint32_t*>(output),
          reinterpret_cast<uint32_t*>(output_scale),
          reinterpret_cast<uint32_t*>(input_offset_by_experts),
          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
          reinterpret_cast<int32_t*>(mask),
          n_experts);
    }
  } else {
    // Single pass per thread: use the low-latency overload.
    if (n_experts >= 16) {
      cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
          m_topk,
          k,
          reinterpret_cast<T*>(input),
          reinterpret_cast<float*>(input_global_scale),
          reinterpret_cast<uint32_t*>(output),
          reinterpret_cast<uint32_t*>(output_scale),
          reinterpret_cast<uint32_t*>(input_offset_by_experts),
          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
          reinterpret_cast<int32_t*>(mask),
          n_experts,
          /* bool low_latency */ true);
    } else {
      cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
          m_topk,
          k,
          reinterpret_cast<T*>(input),
          reinterpret_cast<float*>(input_global_scale),
          reinterpret_cast<uint32_t*>(output),
          reinterpret_cast<uint32_t*>(output_scale),
          reinterpret_cast<uint32_t*>(input_offset_by_experts),
          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
          reinterpret_cast<int32_t*>(mask),
          n_experts,
          /* bool low_latency */ true);
    }
  }
  check_launch();
}
|
||||
|
||||
// Avoid redefinition warnings
|
||||
#undef CHECK_CONTIGUOUS
|
||||
#undef CHECK_TH_CUDA
|
||||
#undef CHECK_INPUT
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m);
|
||||
|
||||
// constexpr auto FP8 = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto HALF = at::ScalarType::Half;
|
||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
// Entry point for FP4 quantization of expert-grouped activations.
//
// Validates device, layout, rank, dtype and shape of every tensor, then
// dispatches quant_impl<half> or quant_impl<__nv_bfloat16> on the current
// CUDA stream of the input's device, without mask and without silu-and-mul.
//
// Shapes enforced below (m_topk = input.size(0), k = input.size(1),
// n_experts = input_global_scale.size(0)):
//   input                          [m_topk, k], half or bfloat16, k % 16 == 0
//   input_global_scale             [n_experts], float
//   input_offset_by_experts        [n_experts + 1], int32
//   output_scale_offset_by_experts [n_experts + 1], int32
//   output                         [m_topk, k / 2], uint8
//   output_scale                   [*, padded_k / 4], int32, where padded_k is
//                                  k / 16 rounded up to a multiple of 4
//
// Raises through TORCH_CHECK when any requirement is violated or the device
// is not sm100/sm103.
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
  auto sm_version = getSMVersion();
  TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");

  // CHECK_INPUT appends "must be a CUDA tensor" / "must be contiguous" to the
  // message itself, so only the tensor name (with a trailing space) is passed.
  CHECK_INPUT(output, "output ");
  CHECK_INPUT(output_scale, "output_scale ");
  CHECK_INPUT(input, "input ");
  CHECK_INPUT(input_global_scale, "input_global_scale ");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts ");
  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts ");

  TORCH_CHECK(output.dim() == 2);
  TORCH_CHECK(output_scale.dim() == 2);
  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input_global_scale.dim() == 1);
  TORCH_CHECK(input_offset_by_experts.dim() == 1);
  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8);
  TORCH_CHECK(output_scale.scalar_type() == INT);

  const int BLOCK_SIZE = 16;
  auto m_topk = input.size(0);
  auto k = input.size(1);
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
  TORCH_CHECK(output.size(0) == m_topk);
  TORCH_CHECK(output.size(1) == k / 2);
  int scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);

  auto in_dtype = input.dtype();
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
  if (in_dtype == at::ScalarType::Half) {
    quant_impl<half>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        input_offset_by_experts.data_ptr(),
        output_scale_offset_by_experts.data_ptr(),
        nullptr,  // mask
        false,    // use_silu_and_mul
        m_topk,
        k,
        n_experts,
        stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        input_offset_by_experts.data_ptr(),
        output_scale_offset_by_experts.data_ptr(),
        nullptr,  // mask
        false,    // use_silu_and_mul
        m_topk,
        k,
        n_experts,
        stream);
  } else {
    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
  }
}
|
||||
|
||||
// Entry point for masked FP4 quantization of expert-grouped activations,
// optionally fused with silu-and-mul.
//
// Validates device, layout, rank, dtype and shape of every tensor, then
// dispatches quant_impl<half> or quant_impl<__nv_bfloat16> on the current
// CUDA stream of the input's device with `mask` and without expert offsets.
//
// Shapes enforced below (m_topk = input.size(0),
// n_experts = input_global_scale.size(0)); k is input.size(1), halved when
// `use_silu_and_mul` is set:
//   input              [m_topk, k or 2 * k], half or bfloat16, k % 16 == 0
//   input_global_scale [n_experts], float
//   mask               [n_experts], int32
//   output             [m_topk, k / 2], uint8
//   output_scale       [*, padded_k / 4], int32, where padded_k is k / 16
//                      rounded up to a multiple of 4
//
// Raises through TORCH_CHECK when any requirement is violated or the device
// is not sm100/sm103.
void silu_and_mul_scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& mask,
    bool use_silu_and_mul) {
  auto sm_version = getSMVersion();
  TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");

  // CHECK_INPUT appends "must be a CUDA tensor" / "must be contiguous" to the
  // message itself, so only the tensor name (with a trailing space) is passed.
  CHECK_INPUT(output, "output ");
  CHECK_INPUT(output_scale, "output_scale ");
  CHECK_INPUT(input, "input ");
  CHECK_INPUT(input_global_scale, "input_global_scale ");
  CHECK_INPUT(mask, "mask ");

  TORCH_CHECK(output.dim() == 2);
  TORCH_CHECK(output_scale.dim() == 2);
  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input_global_scale.dim() == 1);
  // mask is indexed per expert, same rank rule as the other 1-D inputs.
  TORCH_CHECK(mask.dim() == 1);

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
  TORCH_CHECK(mask.scalar_type() == INT);
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8);
  TORCH_CHECK(output_scale.scalar_type() == INT);

  const int BLOCK_SIZE = 16;
  auto m_topk = input.size(0);
  auto k_by_2 = input.size(1);
  auto k = k_by_2;
  if (use_silu_and_mul) {
    // The input row holds two halves of k values each.
    TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
    k = k_by_2 / 2;
  }
  // Same requirement as scaled_fp4_experts_quant_sm100a: the scale-factor
  // count below is derived from k / 16.
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(mask.size(0) == n_experts);
  TORCH_CHECK(output.size(0) == m_topk);
  TORCH_CHECK(output.size(1) == k / 2);
  int scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);

  auto in_dtype = input.dtype();
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
  if (in_dtype == at::ScalarType::Half) {
    quant_impl<half>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        nullptr,  // input_offset_by_experts
        nullptr,  // output_scale_offset_by_experts
        mask.data_ptr(),
        use_silu_and_mul,
        m_topk,
        k,
        n_experts,
        stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(
        output.data_ptr(),
        output_scale.data_ptr(),
        input.data_ptr(),
        input_global_scale.data_ptr(),
        nullptr,  // input_offset_by_experts
        nullptr,  // output_scale_offset_by_experts
        mask.data_ptr(),
        use_silu_and_mul,
        m_topk,
        k,
        n_experts,
        stream);
  } else {
    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
  }
}
|
||||
176
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
Normal file
176
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
Normal file
@@ -0,0 +1,176 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cutlass/arch/config.h>
|
||||
|
||||
// Compile-time map between a 16-bit scalar type and its packed two-element
// vector type (half <-> half2, __nv_bfloat16 <-> __nv_bfloat162).
// The unspecialized template resolves to half2 and is kept for generality.
template <typename T>
struct TypeConverter {
  typedef half2 Type;
};

// Scalar -> packed pair.
template <>
struct TypeConverter<half> {
  typedef half2 Type;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  typedef __nv_bfloat162 Type;
};

// Packed pair -> scalar.
template <>
struct TypeConverter<half2> {
  typedef half Type;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  typedef __nv_bfloat16 Type;
};
|
||||
|
||||
// Number of input elements each thread converts (macro form, used by the
// host-side launch code).
#define ELTS_PER_THREAD 8

// Number of elements that share one scale factor.
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Number of input elements each thread converts (constexpr form, used by the
// device code).
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
|
||||
// Packs eight float32 values into eight e2m1 values, returned as one uint32_t.
// Each cvt instruction consumes two floats and yields one byte; the four
// bytes are then assembled into the 32-bit result.
// When neither CUTLASS_ARCH_MMA_SM100A_ENABLED nor
// CUTLASS_ARCH_MMA_SM103A_ENABLED is set, the function returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if !(CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED)
  return 0;
#else
  // PTX instructions used here requires sm100a/sm103a.
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0]),
        "f"(array[1]),
        "f"(array[2]),
        "f"(array[3]),
        "f"(array[4]),
        "f"(array[5]),
        "f"(array[6]),
        "f"(array[7]));
  return packed;
#endif
}
|
||||
|
||||
// Packs four float2 values (eight floats) into eight e2m1 values, returned as
// one uint32_t. Each cvt instruction consumes the .x/.y of one float2 and
// yields one byte; the four bytes are then assembled into the 32-bit result.
// When neither CUTLASS_ARCH_MMA_SM100A_ENABLED nor
// CUTLASS_ARCH_MMA_SM103A_ENABLED is set, the function returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if !(CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED)
  return 0;
#else
  // PTX instructions used here requires sm100a/sm103a.
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0].x),
        "f"(array[0].y),
        "f"(array[1].x),
        "f"(array[1].y),
        "f"(array[2].x),
        "f"(array[2].y),
        "f"(array[3].x),
        "f"(array[3].y));
  return packed;
#endif
}
|
||||
|
||||
// Approximate reciprocal of `a`, computed with the PTX instruction
// `rcp.approx.ftz.f32`.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float inverse;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(inverse) : "f"(a));
  return inverse;
}
|
||||
|
||||
// Returns the byte address inside `SFout` where the scale factor for
// (rowIdx, colIdx) has to be stored, or nullptr when the calling thread is not
// the one responsible for writing it (only threads whose threadIdx.x is a
// multiple of CVT_FP4_NUM_THREADS_PER_SF write). Also returns nullptr when
// compiled for an architecture below 1000.
//
// The scale factors use the layout
//   [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
// indexed as
//   [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx].
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t const kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t const mIdx = rowIdx;

  // Strides of the five layout dimensions, innermost last.
  int64_t const innerKStride = 1;
  int64_t const innerMStride = 4;
  int64_t const outerMStride = 4 * 4;
  int64_t const kTileStride = 32 * 4 * 4;
  // SF vector size 16; one K tile spans 4 scale factors.
  int const colsPerKTile = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const numKTiles = (numCols + colsPerKTile - 1) / colsPerKTile;
  int64_t const mTileStride = numKTiles * 32 * 4 * 4;

  // Coordinates along each layout dimension.
  int32_t const mTileIdx = mIdx / (32 * 4);
  int32_t const kTileIdx = kIdx / 4;
  // M tile layout [32, 4] is column-major.
  int32_t const outerMIdx = mIdx % 32;
  int32_t const innerMIdx = (mIdx % (32 * 4)) / 32;
  int32_t const innerKIdx = kIdx % 4;

  // Compute the global offset.
  int64_t const sfOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
                           innerMIdx * innerMStride + innerKIdx * innerKStride;

  return reinterpret_cast<uint8_t*>(SFout) + sfOffset;
#else
  return nullptr;
#endif
}
|
||||
|
||||
// A 16-byte packed vector.
// For the generic case it stores four two-element vectors of `Type`
// (TypeConverter<Type>::Type), i.e. eight scalar elements.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// fp8 e4m3 variant: eight two-element fp8 vectors.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
|
||||
74
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
Normal file
74
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
Normal file
@@ -0,0 +1,74 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
void scaled_fp4_quant_sm100a(
|
||||
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
|
||||
|
||||
void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& mask,
|
||||
bool use_silu_and_mul);
|
||||
|
||||
#endif
|
||||
|
||||
// Public entry point for NVFP4 quantization.
// Forwards to scaled_fp4_quant_sm100a when the build defines a non-zero
// ENABLE_NVFP4; otherwise raises a not-implemented error.
void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
#endif
}
|
||||
|
||||
// Public entry point for per-expert NVFP4 quantization.
// Forwards to scaled_fp4_experts_quant_sm100a when the build defines a
// non-zero ENABLE_NVFP4; otherwise raises a not-implemented error.
void scaled_fp4_experts_quant(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  scaled_fp4_experts_quant_sm100a(
      output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts);
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
#endif
}
|
||||
|
||||
// Public entry point for masked per-expert NVFP4 quantization with optional
// SiLU-and-mul. Forwards to silu_and_mul_scaled_fp4_experts_quant_sm100a when
// the build defines a non-zero ENABLE_NVFP4; otherwise raises a
// not-implemented error.
void silu_and_mul_scaled_fp4_experts_quant(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& mask,
    bool use_silu_and_mul) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  silu_and_mul_scaled_fp4_experts_quant_sm100a(
      output, output_scale, input, input_global_scale, mask, use_silu_and_mul);
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
#endif
}
|
||||
239
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
Normal file
239
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
Normal file
@@ -0,0 +1,239 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "nvfp4_quant.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
// Quantizes the eight values held in `vec` to e2m1 and returns them packed in
// one uint32_t.
//
//   vec        - eight half/bfloat16 values owned by this thread.
//   SFScaleVal - global scale applied to the per-vector scale factor.
//   SFout      - where to store the 8-bit scale factor; may be null, in which
//                case nothing is stored.
//
// The absolute maximum is combined with the neighbouring lane (lane ^ 1), so
// two threads together cover the 16 values that share one scale factor.
// Returns 0 when compiled for an architecture below 1000.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr int kNumPairs = CVT_FP4_ELTS_PER_THREAD / 2;

  // Absolute maximum over this thread's eight values, kept as a packed pair.
  auto pairMax = __habs2(vec.elts[0]);
#pragma unroll
  for (int p = 1; p < kNumPairs; ++p) {
    pairMax = __hmax2(pairMax, __habs2(vec.elts[p]));
  }

  // Merge with the partner lane to cover all 16 values (two threads).
  pairMax = __hmax2(__shfl_xor_sync(uint32_t(-1), pairMax, 1), pairMax);
  // Collapse the pair into the final absolute maximum.
  float const vecMax = float(__hmax(pairMax.x, pairMax.y));

  // Scale factor = max value of the vector / max value of e2m1 (6.0),
  // multiplied by the global scale.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));

  // Round the scale factor to its 8-bit storage format and read back the
  // value that the 8-bit encoding actually represents.
  uint8_t fp8SFVal;
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 rounded;
    rounded.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
    SFValue = static_cast<float>(rounded);
    fp8SFVal = rounded.__x;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 rounded = __nv_fp8_e4m3(SFValue);
    fp8SFVal = rounded.__x;
    SFValue = static_cast<float>(rounded);
  }

  // Multiplier applied to the inputs before conversion:
  //   reciprocal(SFValue * reciprocal(SFScaleVal)), or 0 when SFValue is 0.
  float outputScale = 0.0f;
  if (SFValue != 0) {
    outputScale = reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal));
  }

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Widen the inputs to float and apply the output scale.
  float2 scaled[kNumPairs];
#pragma unroll
  for (int p = 0; p < kNumPairs; ++p) {
    if constexpr (std::is_same_v<Type, half>) {
      scaled[p] = __half22float2(vec.elts[p]);
    } else {
      scaled[p] = __bfloat1622float2(vec.elts[p]);
    }
    scaled[p].x *= outputScale;
    scaled[p].y *= outputScale;
  }

  // Convert to e2m1 values; the caller stores them to global memory.
  return fp32_vec_to_e2m1(scaled);
#else
  return 0;
#endif
}
|
||||
|
||||
// Kernel: quantizes a [numRows, numCols] half/bfloat16 matrix to packed e2m1.
// Use UE4M3 by default.
//
//   numRows, numCols - input shape; each thread handles
//                      CVT_FP4_ELTS_PER_THREAD consecutive columns at a time.
//   in               - input matrix, read as PackedVec<Type> vectors.
//   SFScale          - optional pointer to the global scale; 1.0f is used
//                      when it is null.
//   out              - packed output, one uint32_t per eight input elements.
//   SFout            - destination buffer for the per-vector scale factors.
//
// Launch layout: rows are strided over blockIdx.x / gridDim.x and the packed
// columns over threadIdx.x / blockDim.x, so any 1-D grid/block size is valid.
// The body is compiled only for architectures >= 1000; elsewhere the kernel
// is a no-op.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Number of packed vectors per row (loop-invariant).
  int32_t const vecsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < vecsPerRow; colIdx += blockDim.x) {
      // Widen before multiplying: rowIdx * vecsPerRow can exceed the int32
      // range for very large inputs, which would wrap the offset.
      int64_t inOffset = static_cast<int64_t>(rowIdx) * vecsPerRow + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);

      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
|
||||
|
||||
// Launches cvt_fp16_to_fp4 on `stream` for an [m, n] input matrix.
//
//   m, n                - input shape (rows, columns).
//   input               - device pointer to the half/bfloat16 input.
//   SFScale             - device pointer to the global scale.
//   output              - device pointer to the packed fp4 output.
//   SFOuput             - device pointer to the scale-factor output.
//   useUE8M0            - selects the UE8M0 scale-factor kernel variant.
//   multiProcessorCount - SM count used to size the grid.
//   stream              - CUDA stream to launch on.
//
// An empty input (m <= 0 or n <= 0) is a no-op. Raises (via TORCH_CHECK) when
// n is positive but smaller than ELTS_PER_THREAD.
template <typename T>
void invokeFP4Quantization(
    int m,
    int n,
    T const* input,
    float const* SFScale,
    int64_t* output,
    int32_t* SFOuput,
    bool useUE8M0,
    int multiProcessorCount,
    cudaStream_t stream) {
  // Nothing to quantize. Without this guard, n == 0 makes block.x zero (host
  // division by zero below) and m == 0 requests a zero-sized grid, which is an
  // invalid launch configuration.
  if (m <= 0 || n <= 0) {
    return;
  }

  // Each thread converts 8 values.
  int const vecsPerRow = n / ELTS_PER_THREAD;
  TORCH_CHECK(vecsPerRow > 0, "invokeFP4Quantization: n must be at least ", ELTS_PER_THREAD, ", got ", n);

  // Grid, Block size.
  dim3 block(std::min(vecsPerRow, 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
  } else {
    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
  }
}
|
||||
|
||||
// Instantiate the function.
|
||||
template void invokeFP4Quantization(
|
||||
int m,
|
||||
int n,
|
||||
half const* input,
|
||||
float const* SFScale,
|
||||
int64_t* output,
|
||||
int32_t* SFOuput,
|
||||
bool useUE8M0,
|
||||
int multiProcessorCount,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void invokeFP4Quantization(
|
||||
int m,
|
||||
int n,
|
||||
__nv_bfloat16 const* input,
|
||||
float const* SFScale,
|
||||
int64_t* output,
|
||||
int32_t* SFOuput,
|
||||
bool useUE8M0,
|
||||
int multiProcessorCount,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Returns the multiprocessor count of the device that was current when this
// function was first called. The value is computed once and reused on every
// later call.
inline int getMultiProcessorCount() {
  static int cached_sm_count = [] {
    // Get the current CUDA device ID
    int current_device = 0;
    CHECK_CUDA_SUCCESS(cudaGetDevice(&current_device));

    // Get the number of multiprocessors for the current device
    int sm_count = 0;
    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, current_device));

    return sm_count;
  }();

  return cached_sm_count;
}
|
||||
|
||||
// Host entry point: quantizes a 2-D half/bfloat16 `input` into packed fp4
// `output` and writes the per-vector scale factors to `output_sf`.
// `input_sf` supplies the global scale as float data.
//
// Raises when the SM version is not 100/103, when the second input dimension
// is not a multiple of 16, or when the input dtype is neither half nor
// bfloat16.
void scaled_fp4_quant_sm100a(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
  auto sm_version = getSMVersion();
  TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");

  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");

  int multiProcessorCount = getMultiProcessorCount();

  auto global_scale_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto scale_out_ptr = static_cast<int32_t*>(output_sf.data_ptr());
  auto packed_out_ptr = static_cast<int64_t*>(output.data_ptr());
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  auto const dtype = input.scalar_type();
  if (dtype == torch::kHalf) {
    auto typed_input = reinterpret_cast<half const*>(input.data_ptr());
    invokeFP4Quantization(
        m, n, typed_input, global_scale_ptr, packed_out_ptr, scale_out_ptr, useUE8M0, multiProcessorCount, stream);
  } else if (dtype == torch::kBFloat16) {
    auto typed_input = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
    invokeFP4Quantization(
        m, n, typed_input, global_scale_ptr, packed_out_ptr, scale_out_ptr, useUE8M0, multiProcessorCount, stream);
  } else {
    std::cerr << "Observing: " << dtype << " for the input datatype which is invalid";
    throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
  }
}
|
||||
39
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
Normal file
39
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
Normal file
@@ -0,0 +1,39 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
void cutlass_scaled_fp4_mm_sm100a(
|
||||
torch::Tensor& D,
|
||||
torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
// Public entry point for the NVFP4 scaled matrix multiply.
// Forwards to cutlass_scaled_fp4_mm_sm100a when the build defines a non-zero
// ENABLE_NVFP4; otherwise raises a not-implemented error.
void cutlass_scaled_fp4_mm(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
#endif
}
|
||||
369
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
Normal file
369
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
Normal file
@@ -0,0 +1,369 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// clang-format off
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
// clang-format on
|
||||
|
||||
/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates `status` once and raises (via TORCH_CHECK) with the CUTLASS status
 * string when it is not cutlass::Status::kSuccess.
 *
 * Wrapped in do { } while (0) so that `CUTLASS_CHECK(x);` behaves as a single
 * statement, including inside an unbraced if/else.
 */
#define CUTLASS_CHECK(status)                                                       \
  do {                                                                              \
    cutlass::Status error = (status);                                               \
    TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
  } while (0)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config, selected by the GEMM output element type T.
//
// Generic case: 256x256x256 MMA tile with the 2-SM warp-specialized
// epilogue/mainloop schedules and an explicit 128x64 epilogue tile.
template <typename T>
struct KernelTraits {
  using MmaTileShape = Shape<_256, _256, _256>;
  // Cluster M/N extents are runtime values; the third extent is fixed to 1.
  using ClusterShape = Shape<int, int, _1>;
  using EpilogueTile = Shape<_128, _64>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
};

// float output: 128x128x256 MMA tile with the 1-SM schedules and an
// automatically chosen epilogue tile.
template <>
struct KernelTraits<float> {
  using MmaTileShape = Shape<_128, _128, _256>;
  // Cluster M/N extents are runtime values; the third extent is fixed to 1.
  using ClusterShape = Shape<int, int, _1>;
  using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
};
|
||||
|
||||
// Bundles every CUTLASS type needed for the SM100 NVFP4 block-scaled GEMM
// whose output element type is T. The tile, cluster and schedule choices come
// from KernelTraits<T>.
template <typename T>
struct Fp4GemmSm100 {
  // ---- A operand: row-major nvfp4 (e2m1) ----
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutATag = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 32;

  // ---- B operand: column-major nvfp4 (e2m1) ----
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutBTag = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 32;

  // ---- C/D operands: row-major, element type T, 128-bit alignment ----
  using ElementD = T;
  using ElementC = T;
  using LayoutCTag = cutlass::layout::RowMajor;
  using LayoutDTag = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // ---- Functional configuration ----
  using ElementAccumulator = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

  // ---- Performance configuration (from KernelTraits<T>) ----
  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
  using ClusterShape = typename KernelTraits<T>::ClusterShape;
  using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
  using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
  using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;

  // ---- Epilogue collective: linear combination, no C source (void) ----
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      MmaTileShape,
      ClusterShape,
      EpilogueTile,
      ElementAccumulator,
      ElementAccumulator,
      void,
      LayoutCTag,
      AlignmentC,
      ElementD,
      LayoutDTag,
      AlignmentD,
      EpilogueSchedule,
      cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>>::CollectiveOp;

  // ---- Mainloop collective: stage count is chosen automatically after
  //      carving out the epilogue's shared storage ----
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutATag,
      AlignmentA,
      ElementB,
      LayoutBTag,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopSchedule>::CollectiveOp;

  // ---- Kernel and device-level adapter ----
  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // ---- Stride / layout types derived from the kernel ----
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};
|
||||
|
||||
// Builds the CUTLASS GEMM argument structure for an (M, N, K) problem.
//
// T is a GEMM configuration type (e.g. Fp4GemmSm100<...>) that exposes
// `Gemm`, `StrideA`, `StrideB` and `StrideD`.
//
//   D          - output tensor; its buffer is also passed as the C pointer.
//   A, B       - packed fp4 operands.
//   A_sf, B_sf - block scale factors, read as ue4m3.
//   alpha      - tensor whose data pointer is used as the epilogue alpha.
//   M, N, K    - problem extents (narrowed to int for CUTLASS).
//
// Returns the populated arguments, including the preferred and fallback
// cluster shapes.
template <typename T>
typename T::Gemm::Arguments args_from_options(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t M,
    int64_t N,
    int64_t K) {
  using ElementA = typename T::Gemm::ElementA;
  using ElementB = typename T::Gemm::ElementB;
  using ElementSFA = cutlass::float_ue4m3_t;
  using ElementSFB = cutlass::float_ue4m3_t;
  using ElementD = typename T::Gemm::ElementD;
  using ElementCompute = float;
  using StrideA = typename T::StrideA;
  using StrideB = typename T::StrideB;
  using StrideD = typename T::StrideD;
  using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;

  int m = static_cast<int>(M);
  int n = static_cast<int>(N);
  int k = static_cast<int>(K);
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});

  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));

  typename T::Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {// Mainloop arguments
       static_cast<ElementA const*>(A.data_ptr()),
       stride_A,
       static_cast<ElementB const*>(B.data_ptr()),
       stride_B,
       static_cast<ElementSFA const*>(A_sf.data_ptr()),
       layout_SFA,
       static_cast<ElementSFB const*>(B_sf.data_ptr()),
       layout_SFB},
      {     // Epilogue arguments
       {},  // epilogue.thread
       static_cast<ElementD const*>(D.data_ptr()),
       stride_D,
       static_cast<ElementD*>(D.data_ptr()),
       stride_D}};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
  // Select the cluster shape from the OUTPUT element type. T itself is the
  // GEMM configuration wrapper (it provides T::Gemm / T::StrideA), so
  // comparing T against float can never be true and would leave the float
  // configuration unreachable.
  if constexpr (std::is_same_v<ElementD, float>) {
    arguments.hw_info.cluster_shape = dim3(1, 4, 1);
    arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
  } else {
    arguments.hw_info.cluster_shape = dim3(4, 4, 1);
    arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
  }
  return arguments;
}
|
||||
|
||||
// Runs the SM100 NVFP4 GEMM with output element type T on `stream`.
// Builds the CUTLASS arguments, allocates a uint8 workspace on A's device,
// then checks can_implement, initializes and runs the kernel. Any non-success
// CUTLASS status raises through CUTLASS_CHECK.
template <typename T>
void runGemm(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t m,
    int64_t n,
    int64_t k,
    cudaStream_t stream) {
  using Config = Fp4GemmSm100<T>;
  using GemmOp = typename Config::Gemm;

  GemmOp gemm_op;

  auto gemm_args = args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);

  // Scratch buffer sized by CUTLASS for this problem.
  size_t const workspace_bytes = GemmOp::get_workspace_size(gemm_args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
  auto workspace = torch::empty(workspace_bytes, workspace_options);

  CUTLASS_CHECK(gemm_op.can_implement(gemm_args));

  CUTLASS_CHECK(gemm_op.initialize(gemm_args, workspace.data_ptr(), stream));

  CUTLASS_CHECK(gemm_op.run(gemm_args, workspace.data_ptr(), stream));
}
|
||||
#else
|
||||
// Fallback used when CUTLASS_ARCH_MMA_SM100_SUPPORTED is not defined: it has
// the same signature as the real implementation but always fails with a
// descriptive error instead of launching anything.
template <typename T>
void runGemm(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t m,
    int64_t n,
    int64_t k,
    cudaStream_t stream) {
  constexpr bool kSm100GemmAvailable = false;
  TORCH_CHECK(
      kSm100GemmAvailable,
      "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
      "a CUTLASS 3.8 source directory to enable support.");
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// Input-validation helpers. `x` is the tensor, `st` the expected scalar type
// and `m` a short display name used in the error message.
// TORCH_CHECK concatenates its message arguments with no separator, so each
// literal carries its own leading/trailing space.
#define CHECK_TYPE(x, st, m) TORCH_CHECK((x).scalar_type() == (st), "Inconsistency of Tensor type: ", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK((x).is_cuda(), m, " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK((x).is_contiguous(), m, " must be contiguous")
// Wrapped in do/while(0) so the three checks behave as a single statement
// (e.g. inside an unbraced `if`).
#define CHECK_INPUT(x, st, m) \
  do {                        \
    CHECK_TH_CUDA(x, m);      \
    CHECK_CONTIGUOUS(x, m);   \
    CHECK_TYPE(x, st, m);     \
  } while (0)

// Scalar type required of the packed FP4 operands A and B (checked below).
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
// Scalar type required of the scale-factor tensors A_sf and B_sf.
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
|
||||
// Host entry point for the SM100 NVFP4 scaled matrix multiply.
//
// Validates the operands, derives the problem size and dispatches runGemm on
// the dtype of D (Half, BFloat16 or Float).
//
//   D      output tensor; must be a CUDA tensor holding m * n elements.
//   A, B   contiguous CUDA Byte matrices with equal size(1); k is taken as
//          A.size(1) * 2.
//   A_sf   contiguous CUDA Float8_e4m3fn matrix of shape
//          (round_up(m, 128), round_up(k / 16, 4)).
//   B_sf   same as A_sf with n in place of m.
//   alpha  contiguous CUDA Float tensor.
//
// Raises via TORCH_CHECK on any violated precondition or unsupported dtype.
void cutlass_scaled_fp4_mm_sm100a(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha) {
  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");

  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");

  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");

  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
  TORCH_CHECK(
      A.size(1) == B.size(1),
      "a and b shapes cannot be multiplied (",
      A.size(0),
      "x",
      A.size(1),
      " and ",
      B.size(0),
      "x",
      B.size(1),
      ")");

  int64_t const m = A.size(0);
  int64_t const n = B.size(0);
  int64_t const k = A.size(1) * 2;

  // The GEMM writes m * n outputs into D; reject a buffer that cannot hold them.
  TORCH_CHECK(D.is_cuda(), "out must be a CUDA tensor");
  TORCH_CHECK(
      D.numel() == m * n, "out must hold ", m, "x", n, " elements, but it has ", D.numel(), " elements");

  constexpr int alignment = 32;
  TORCH_CHECK(
      k % alignment == 0,
      "Expected k to be divisible by ",
      alignment,
      ", but got a shape: (",
      A.size(0),
      "x",
      A.size(1),
      "), k: ",
      k,
      ".");
  TORCH_CHECK(
      n % alignment == 0,
      "Expected n to be divisible by ",
      alignment,
      ", but got b shape: (",
      B.size(0),
      "x",
      B.size(1),
      ").");

  // 64-bit arithmetic: the sizes above are int64_t and must not be narrowed.
  auto round_up = [](int64_t x, int64_t y) { return (x + y - 1) / y * y; };
  int64_t const rounded_m = round_up(m, 128);
  int64_t const rounded_n = round_up(n, 128);
  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
  // integer.
  int64_t const rounded_k = round_up(k / 16, 4);

  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
  TORCH_CHECK(
      A_sf.size(1) == B_sf.size(1),
      "scale_a and scale_b shapes cannot be multiplied (",
      A_sf.size(0),
      "x",
      A_sf.size(1),
      " and ",
      B_sf.size(0),
      "x",
      B_sf.size(1),
      ")");
  TORCH_CHECK(
      A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
      "scale_a must be padded and swizzled to a shape (",
      rounded_m,
      "x",
      rounded_k,
      "), but got a shape (",
      A_sf.size(0),
      "x",
      A_sf.size(1),
      ")");
  TORCH_CHECK(
      B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
      "scale_b must be padded and swizzled to a shape (",
      rounded_n,
      "x",
      rounded_k,
      "), but got a shape (",
      B_sf.size(0),
      "x",
      B_sf.size(1),
      ")");

  auto out_dtype = D.dtype();
  // Guard on the Device itself rather than a (char)-cast device index.
  at::cuda::CUDAGuard device_guard{A.device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

  if (out_dtype == at::ScalarType::Half) {
    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
|
||||
123
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
123
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
@@ -0,0 +1,123 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Computes max(|input[i]|) / FP8_E4M3_MAX over the whole tensor and merges it
// into *output_s with atomicMaxFloat.
//
// Each thread strides over the input by the total thread count, first in
// 16-byte vectors and then over the scalar tail; blockReduceMax combines the
// per-thread maxima and thread 0 of each block publishes the block result.
// NOTE(review): because the result is merged with an atomic max, *output_s
// presumably has to be initialised by the caller before launch — confirm.
//
// All index arithmetic is 64-bit so that element counts beyond 2^31 do not
// wrap.
template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  const unsigned int tid = threadIdx.x;
  const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int64_t num_vec_elems = num_elements / vec_size;

  // Vectorised body.
  for (int64_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  // Scalar tail: elements that do not fill a whole vector.
  const int64_t remaining_start = num_vec_elems * vec_size;
  for (int64_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = static_cast<float>(input[idx]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  max_value = blockReduceMax(max_value);

  if (tid == 0) {
    atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
  }
}
|
||||
|
||||
// Quantizes `input` to DST_DTYPE using a single tensor-wide scale:
//   output[i] = clamp(input[i] / *scale, -FP8_E4M3_MAX, FP8_E4M3_MAX).
//
// Threads stride over the input by the total thread count. The main loop
// handles 16 elements per iteration and stores them with one 128-bit write;
// a scalar loop handles the remainder. Index arithmetic is 64-bit so element
// counts beyond 2^31 do not wrap.
// NOTE(review): the 128-bit store assumes `output` is 16-byte aligned —
// confirm against the allocator used by callers.
template <typename T, typename DST_DTYPE>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const float scale_val = 1.0f / (*scale);

  // We want to store 128 bits of data at a time. 16 = 128 / 8 bits
  // Load is already vectorized, so 16 elements work for T.
  constexpr uint32_t VEC_SIZE = 16;
  using vec_t = flashinfer::vec_t<T, VEC_SIZE>;

  const int64_t num_vec_elems = num_elements / VEC_SIZE;

  for (int64_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * VEC_SIZE);

    // Aligned so that reinterpreting the staging buffer as uint4 is valid.
    alignas(16) DST_DTYPE output_arr[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3)
      output_arr[j] = static_cast<DST_DTYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }
    *(uint4*)(output + i * VEC_SIZE) = *(uint4*)output_arr;
  }

  // Scalar tail.
  const int64_t remaining_start = num_vec_elems * VEC_SIZE;
  for (int64_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3)
    output[idx] = static_cast<DST_DTYPE>(val);
#else
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
|
||||
|
||||
// Per-tensor FP8 quantization entry point.
//
//   input     tensor to quantize (validated with CHECK_INPUT).
//   output_q  destination for the quantized values.
//   output_s  single float scale; read by the quantize kernel and, when
//             is_static is false, first updated by the abs-max kernel.
//   is_static when true the scale already in output_s is used as-is.
//
// Both kernels are launched on the current CUDA stream with 256-thread
// blocks and at most 1024 blocks. An empty input is a no-op.
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int block_size = 256;
  // Keep the full 64-bit element count; the kernels take int64_t.
  const int64_t num_elements = input.numel();
  if (num_elements == 0) {
    // Nothing to quantize; a zero-sized grid would be an invalid launch.
    return;
  }
  const int64_t blocks_needed = (num_elements + block_size - 1) / block_size;
  const int num_blocks = blocks_needed < 1024 ? static_cast<int>(blocks_needed) : 1024;

  dim3 grid(num_blocks);
  dim3 block(block_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (is_static == false) {
      // Dynamic mode: derive the scale from the data first.
      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
    }

    per_tensor_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
|
||||
240
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
Normal file
240
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
Normal file
@@ -0,0 +1,240 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Max-reduction across one 16-lane group using XOR shuffles (offsets 8, 4, 2,
// 1); every lane of the group ends up with the group maximum.
//
// Two groups share a 32-lane warp, so the participation mask must name the
// half-warp the calling lane belongs to: a lane calling a *_sync shuffle has
// to be included in its own mask. `tid` is unused and kept for interface
// compatibility.
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  const unsigned mask = (threadIdx.x % 32 >= 16) ? 0xffff0000u : 0x0000ffffu;

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}
|
||||
|
||||
// Quantizes `input` in groups of `group_size` elements, one 16-thread group
// of the block per data group.
//
// For every group the kernel:
//   1. finds max(eps, max|x|) over the group with vectorised loads and a
//      16-lane shuffle reduction,
//   2. derives the scale y_s = absmax / max_8bit (rounded up to a power of
//      two and stored as a biased exponent byte when SCALE_UE8M0),
//   3. writes the scale from lane 0, and
//   4. writes clamp(x / y_s, min_8bit, max_8bit) converted to DST_DTYPE.
//
// With IS_COLUMN_MAJOR the scale location is computed from
// num_groups_per_row and scale_stride; otherwise scales are written at
// output_s[global_group_id].
template <
    typename T,
    typename DST_DTYPE,
    bool IS_COLUMN_MAJOR = false,
    bool SCALE_UE8M0 = false,
    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const int groups_per_block,
    const float eps,
    const float min_8bit,
    const float max_8bit,
    const int num_groups_per_row = 0,
    const int scale_stride = 0) {
  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  // Which group this thread serves, and its position inside that group.
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  const int64_t block_group_offset = global_group_id * group_size;

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;

  // Resolve where this group's scale is stored.
  scale_element_t* scale_output;
  if constexpr (IS_COLUMN_MAJOR) {
    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int row_idx = global_group_id / num_groups_per_row;
    const int col_idx_unpacked = global_group_id % num_groups_per_row;
    const int col_idx = col_idx_unpacked / num_elems_per_pack;
    const int pack_idx = col_idx_unpacked % num_elems_per_pack;
    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                   (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
  } else {
    static_assert(!SCALE_UE8M0);
    scale_output = output_s + global_group_id;
  }

  const int32_t num_vec_elems = group_size / vec_size;

  // Pass 1: per-lane absolute maximum, seeded with eps.
  float local_absmax = eps;
  for (int32_t i = lane_id; i < num_vec_elems; i += threads_per_group) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      local_absmax = fmaxf(local_absmax, fabsf(static_cast<float>(input_vec[j])));
    }
  }

  local_absmax = GroupReduceMax(local_absmax, lane_id);

  // Scale, plus its stored representation.
  // TODO can optimize
  float y_s = local_absmax / max_8bit;
  scale_element_t y_s_quant;
  if constexpr (SCALE_UE8M0) {
    y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
  } else {
    y_s_quant = y_s;
  }

  if (lane_id == 0) {
    *scale_output = y_s_quant;
  }

  // Pass 2: re-load, scale, clamp and store.
  for (int32_t i = lane_id; i < num_vec_elems; i += threads_per_group) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      const float scaled = static_cast<float>(input_vec[j]) / y_s;
      const float q_val = fminf(fmaxf(scaled, min_8bit), max_8bit);
      group_output[i * vec_size + j] = DST_DTYPE(q_val);
    }
  }
}
|
||||
|
||||
// Host launcher for per_token_group_quant_8bit_kernel.
//
//   input        tensor to quantize; numel must be a multiple of group_size.
//   output_q     destination; Char selects int8_t, Float8_e4m3fn selects
//                __nv_fp8_e4m3.
//   output_s     2-D scale tensor; treated as column-major when
//                stride(0) < stride(1).
//   group_size   elements per quantization group (must be positive).
//   eps          lower bound used when computing each group's abs-max.
//   min_8bit / max_8bit  clamp range of the destination type.
//   scale_ue8m0  store scales as packed exponent bytes; only supported with a
//                column-major output_s.
//
// The launch uses 16 threads per group and 1, 2, 4, 8 or 16 groups per block
// (the largest of those that divides the group count). Zero groups is a
// no-op.
void sgl_per_token_group_quant_8bit(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double min_8bit,
    double max_8bit,
    bool scale_ue8m0 = false) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);

  // Validate before dividing by group_size.
  TORCH_CHECK(group_size > 0, "group_size must be positive, but got ", group_size);
  CHECK_EQ(input.numel() % group_size, 0);
  CHECK_EQ(output_s.dim(), 2);

  const int num_groups = input.numel() / group_size;
  if (num_groups == 0) {
    // Nothing to quantize; a zero-sized grid would be an invalid launch.
    return;
  }

  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  // Enforced in every build type: the row-major kernel writes float scales
  // and cannot produce the packed UE8M0 layout.
  TORCH_CHECK(is_column_major || !scale_ue8m0, "scale_ue8m0 requires a column-major output_s");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  const int hidden_dim = input.size(input.dim() - 1);
  const int num_groups_per_row = hidden_dim / group_size;
  const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE)                                                               \
  do {                                                                                            \
    dim3 grid(num_blocks);                                                                        \
    dim3 block(num_threads);                                                                      \
    if (is_column_major) {                                                                        \
      if (scale_ue8m0) {                                                                          \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>(  \
            static_cast<T*>(input.data_ptr()),                                                    \
            output_q.data_ptr(),                                                                  \
            static_cast<uint32_t*>(output_s.data_ptr()),                                          \
            group_size,                                                                           \
            num_groups,                                                                           \
            groups_per_block,                                                                     \
            (float)eps,                                                                           \
            (float)min_8bit,                                                                      \
            (float)max_8bit,                                                                      \
            num_groups_per_row,                                                                   \
            scale_stride);                                                                        \
      } else {                                                                                    \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
            static_cast<T*>(input.data_ptr()),                                                    \
            output_q.data_ptr(),                                                                  \
            static_cast<float*>(output_s.data_ptr()),                                             \
            group_size,                                                                           \
            num_groups,                                                                           \
            groups_per_block,                                                                     \
            (float)eps,                                                                           \
            (float)min_8bit,                                                                      \
            (float)max_8bit,                                                                      \
            num_groups_per_row,                                                                   \
            scale_stride);                                                                        \
      }                                                                                           \
    } else {                                                                                      \
      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>(         \
          static_cast<T*>(input.data_ptr()),                                                      \
          output_q.data_ptr(),                                                                    \
          static_cast<float*>(output_s.data_ptr()),                                               \
          group_size,                                                                             \
          num_groups,                                                                             \
          groups_per_block,                                                                       \
          (float)eps,                                                                             \
          (float)min_8bit,                                                                        \
          (float)max_8bit);                                                                       \
    }                                                                                             \
  } while (0)

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (dst_type == at::ScalarType::Char) {
      LAUNCH_KERNEL(scalar_t, int8_t);
      return true;
    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
      LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
      return true;
    }
    return false;
  });

#undef LAUNCH_KERNEL
}
|
||||
|
||||
// INT8 flavour of the per-token-group quantizer: forwards to
// sgl_per_token_group_quant_8bit with int8_min / int8_max as the clamp range
// and UE8M0 scale packing disabled.
void sgl_per_token_group_quant_int8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double int8_min,
    double int8_max) {
  sgl_per_token_group_quant_8bit(
      input, output_q, output_s, group_size, eps, int8_min, int8_max, /*scale_ue8m0=*/false);
}
|
||||
|
||||
// FP8 flavour of the per-token-group quantizer: forwards every argument to
// sgl_per_token_group_quant_8bit, using fp8_min / fp8_max as the clamp range
// and passing scale_ue8m0 through unchanged.
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max,
    bool scale_ue8m0) {
  const double clamp_lo = fp8_min;
  const double clamp_hi = fp8_max;
  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, clamp_lo, clamp_hi, scale_ue8m0);
}
|
||||
227
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
227
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
@@ -0,0 +1,227 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
static constexpr int kWarpSize = 32;

// ---------------------------------------------------------------------------
// 1. Warp‑local, no shared memory
//    • One warp handles one token.
//    • Eight tokens per 256‑thread CTA.
// ---------------------------------------------------------------------------
// For token t the kernel writes
//   output_s[t] = max|input[t, :]| / FP8_E4M3_MAX
//   output_q[t, :] = clamp(input[t, :] / output_s[t], ±FP8_E4M3_MAX)
// (all zeros when the scale is zero). hidden_dim is processed in chunks of
// kVecSize elements; any remainder that does not fill a chunk is not visited.
//
// The per-token scale is kept in a register. Each warp of the CTA serves a
// different token, so a CTA-wide __shared__ scalar would be overwritten
// concurrently by all warps.
// NOTE(review): this relies on warpReduceMax returning the maximum on every
// lane (not only lane 0) — confirm in utils.h.
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int warp_id = threadIdx.x / kWarpSize;        // 0‑7 (8 warps)
  const int lane_id = threadIdx.x & (kWarpSize - 1);  // 0‑31
  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
  // The whole warp shares token_id, so it exits together.
  if (token_id >= num_tokens) return;

  // Global tensors for this token
  const T* token_input = input + token_id * hidden_dim;
  DST_DTYPE* token_output = output_q + token_id * hidden_dim;
  float* token_scale = output_s + token_id;

  //
  // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_dim
  //
  float max_value = 0.f;
  using vec_t = flashinfer::vec_t<T, kVecSize>;
  const int32_t num_vec_elems = hidden_dim / kVecSize;

  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
    }
  }

  const float warp_max = warpReduceMax(max_value);

  // Warp-private scale; lane 0 publishes it.
  const float scale = warp_max / FP8_E4M3_MAX;
  if (lane_id == 0) {
    token_scale[0] = scale;
  }
  const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

  //
  // Pass-2: quantize and write back
  //
  for (int i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);
    // Aligned so that reinterpreting the staging buffer as uint4 is valid.
    alignas(16) DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]) * scale_inv;
      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3)
      output_arr[j] = static_cast<DST_DTYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }
    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
// Block b handles token b: all threads cooperate to find max|input[b, :]|,
// thread 0 stores scale = max / FP8_E4M3_MAX to output_s[b] and shares it with
// the block, then every thread quantizes its chunks of kVecSize elements.
// A zero scale (all-zero token) yields zero outputs, matching
// per_token_quant_fp8_kernel, instead of multiplying by 1/0.
template <typename T, typename DST_DTYPE, int kVecSize = 16>
__global__ void per_token_quant_fp8_small_batch_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  // Uniform across the block, so the barrier below is still reached by all
  // threads of every block that continues.
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_dim;
  DST_DTYPE* token_output = output_q + token_idx * hidden_dim;

  float max_value = 0.0f;

  // Use template parameter for vector size
  using vec_t = flashinfer::vec_t<T, kVecSize>;
  const int32_t num_vec_elems = hidden_dim / kVecSize;

  // Find max using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  max_value = blockReduceMax(max_value);

  // One token per block, so a single shared scalar is sufficient here.
  __shared__ float scale;
  if (tid == 0) {
    scale = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = scale;
  }
  __syncthreads();

  const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

  // Quantize using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);

    // Aligned so that reinterpreting the staging buffer as uint4 is valid.
    alignas(16) DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3)
      output_arr[j] = static_cast<DST_DTYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}
|
||||
|
||||
// Per-token FP8 quantization entry point.
//
//   input     2-D tensor (num_tokens x hidden_dim); hidden_dim must be a
//             multiple of 8.
//   output_q  destination for the quantized values.
//   output_s  destination for one float scale per token.
//
// Chooses the warp-per-token kernel when there are at least
// 2 * TOKENS_PER_CTA tokens per SM, otherwise the one-token-per-block kernel;
// each is instantiated with a 16-element vector when hidden_dim is a multiple
// of 16 and an 8-element vector otherwise. Zero tokens is a no-op.
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);
  // sizes()[0] / sizes()[1] below are only meaningful for a matrix.
  TORCH_CHECK(input.dim() == 2, "input must be a 2-D tensor, but got ", input.dim(), " dimensions");
  const auto input_sizes = input.sizes();
  const int64_t num_tokens = input_sizes[0];
  const int64_t hidden_dim = input_sizes[1];
  TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);

  if (num_tokens == 0) {
    // Nothing to quantize; a zero-sized grid would be an invalid launch.
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  const int TOKENS_PER_CTA = 8;
  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
  const bool use_vec16 = (hidden_dim % 16 == 0);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (use_warp_kernel) {
      // -------- warp‑local ---------------------------------------------------
      constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
      dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
      dim3 block(THREADS);

      if (use_vec16) {
        per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
            static_cast<const scalar_t*>(input.data_ptr()),
            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
            static_cast<float*>(output_s.data_ptr()),
            hidden_dim,
            num_tokens);
      } else {
        per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
            static_cast<const scalar_t*>(input.data_ptr()),
            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
            static_cast<float*>(output_s.data_ptr()),
            hidden_dim,
            num_tokens);
      }
    } else {
      // -------- baseline -----------------------------------------------------
      constexpr int THREADS = 256;
      dim3 grid(num_tokens);
      dim3 block(THREADS);

      if (use_vec16) {
        per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
            static_cast<const scalar_t*>(input.data_ptr()),
            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
            static_cast<float*>(output_s.data_ptr()),
            hidden_dim,
            num_tokens);
      } else {
        per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
            static_cast<const scalar_t*>(input.data_ptr()),
            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
            static_cast<float*>(output_s.data_ptr()),
            hidden_dim,
            num_tokens);
      }
    }
    return true;
  });
}
|
||||
710
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu
Normal file
710
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu
Normal file
@@ -0,0 +1,710 @@
|
||||
// Implemented by Haotian Tang and Shang Yang.
|
||||
// @article{lin2024qserve,
|
||||
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
|
||||
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
|
||||
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
|
||||
// }
|
||||
// @article{yang2025lserve,
|
||||
// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention},
|
||||
// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and
|
||||
// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025}
|
||||
// }
|
||||
|
||||
// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Tile-shape and layout constants used by the kernels in this file.
#define OP_M 16
#define OP_N 8
#define OP_K 32
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
// L2_CACHEHINT(size) expands to the ".L2::<size>B" qualifier appended to
// cp.async on toolkits from CUDA 11.4 onwards, and to nothing on older ones.
// The version is compared as a (major, minor) pair so that 12.0–12.3 are not
// excluded by their small minor number.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
|
||||
|
||||
// KERNEL_LAUNCH_CODE: statement sequence that configures and launches
// dense_kernel0 for one tile configuration.
//
// Names it reads from the expansion site: CTA_M, CTA_N, CTA_K, WARP_M, WARP_N,
// WARP_K, STAGES, G, num_out_feats, num_out_channels, num_in_feats,
// num_in_channels, stream, and the kernel arguments in_feats, kernel, wscales,
// ascales, w_szs, a_ssums, out_feats.
//
// Steps:
//   1. Derive NUM_WARPS and the dynamic shared-memory size kSmemByteSize
//      (A tile + half-size B tile per stage, plus the scales buffer).
//   2. If kSmemByteSize >= 99 * 1024, print a message and `return;` without
//      launching (so the macro must be expanded inside a void function).
//   3. Build the grid from the M/N block counts and the log-tile factor, with
//      WARP_SIZE x NUM_WARPS threads per block.
//   4. Raise the kernel's max dynamic shared memory attribute and launch it on
//      `stream`.
// NOTE(review): the return value of cudaFuncSetAttribute and the launch status
// are not checked here.
#define KERNEL_LAUNCH_CODE \
|
||||
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
|
||||
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
|
||||
constexpr int kSmemByteSize = \
|
||||
((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) * \
|
||||
sizeof(int8_t); \
|
||||
if (kSmemByteSize >= 99 * 1024) { \
|
||||
printf( \
|
||||
"This kernel requires %d Bytes of shared memory, which exceeds " \
|
||||
"device limit.\n", \
|
||||
kSmemByteSize); \
|
||||
return; \
|
||||
} \
|
||||
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
|
||||
int num_blocks_n = num_out_channels / CTA_N / 1; \
|
||||
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
|
||||
const int tile_shift = 1 << log_tile; \
|
||||
dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
|
||||
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
|
||||
auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
|
||||
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
|
||||
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>( \
|
||||
in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats, num_in_feats, num_out_channels, num_in_channels);
|
||||
|
||||
// Picks the log2 of the tile-swizzle factor for `n` blocks, capped by the
// compile-time limit N:
//   3 if N >= 8 and n >= 6,
//   2 if N >= 4 and n >= 3,
//   1 if N >= 2 and n >= 2,
//   0 otherwise.
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
  if (N >= 8 && n >= 6) {
    return 3;
  }
  if (N >= 4 && n >= 3) {
    return 2;
  }
  if (N >= 2 && n >= 2) {
    return 1;
  }
  return 0;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// Remaps a launched (blockIdx_x, blockIdx_y) pair using a tile factor of
// 2^log_tile:
//   .x = blockIdx_x >> log_tile
//   .y = (blockIdx_y << log_tile) + (low log_tile bits of blockIdx_x)
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
  const int low_bits_mask = (1 << (log_tile)) - 1;
  return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & low_bits_mask));
}
|
||||
|
||||
// Converts a generic pointer into a 32-bit shared-state-space address
// (cvta.to.shared followed by a 64->32 bit narrowing), the form expected by
// the shared-memory PTX instructions used below.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) {
  uint32_t shared_addr;

  asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
      "smem_ptr; }\n"
      : "=r"(shared_addr)
      : "l"(ptr));

  return shared_addr;
}
|
||||
|
||||
// Issues ldmatrix (m8n8, x4, b16) from shared address `addr` and stores the
// four resulting 32-bit registers at shared_warp + ax0_0 * 16 bytes.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  unsigned* frag = (unsigned*)(shared_warp + (ax0_0 * 16));
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
|
||||
|
||||
// Transposing variant of ldmatrix_m8n8_x4_b16: issues ldmatrix with the
// .trans qualifier from shared address `addr` and stores the four resulting
// 32-bit registers at shared_warp + ax0_0 * 16 bytes.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  unsigned* frag = (unsigned*)(shared_warp + (ax0_0 * 16));
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
|
||||
|
||||
// function from lmdeploy
|
||||
// function from lmdeploy
// Predicated 16-byte cp.async.cg copy from global `src` to the shared address
// `smem_int_ptr`; the copy is issued only when `mask` is true.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) {
  const int cp_size = 16;
  const int issue_copy = (int)mask;
  asm volatile("{"
               " .reg .pred p;"
               " setp.ne.b32 p, %0, 0;"
               " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
               "}" ::"r"(issue_copy),
               "r"(smem_int_ptr),
               "l"(src),
               "n"(cp_size));
}
|
||||
|
||||
// One mma.sync m16n8k32 (s8 x s8 -> s32) step that accumulates in place:
//   C_warp[0..3] = A_shared_warp[0..3] * B_shared_warp[0..1] + C_warp[0..3]
// All three pointers are read as arrays of 32-bit registers.
__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) {
  int* acc = (int*)C_warp;
  unsigned* a_frag = (unsigned*)A_shared_warp;
  unsigned* b_frag = (unsigned*)B_shared_warp;
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
      : "=r"(acc[0]), "=r"(acc[1]), "=r"(acc[2]), "=r"(acc[3])
      : "r"(a_frag[0]),
        "r"(a_frag[1]),
        "r"(a_frag[2]),
        "r"(a_frag[3]),
        "r"(b_frag[0]),
        "r"(b_frag[1]),
        "r"(acc[0]),
        "r"(acc[1]),
        "r"(acc[2]),
        "r"(acc[3]));
}
|
||||
|
||||
// Copies this thread's share of one A tile stage from global to shared memory.
//
// `src` / `dst` are the per-thread hoisted base pointers. Each call handles
// 1/SHARED_K_ITERS of the thread's 16-byte (PACK_SIZE) chunks, selected by
// `shared_iter_k`. `global_iter_k` selects the CTA_K-wide column block and
// `global_ncols` is the row pitch of the global matrix in bytes.
// `mask` disables the whole call; `preds[i]` disables the i-th chunk.
// With STAGES > 1 the copy is issued as cp.async, otherwise as a plain 16-byte store.
// cta_offset_m / cta_offset_n are accepted for signature parity and are not used here.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask,
    bool* preds) {
  constexpr int kChunksPerStage = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int kChunksPerCall = kChunksPerStage / SHARED_K_ITERS;
  constexpr int kRowsPerChunk = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int kSmemRowBytes = CTA_K + SMEM_PAD_A;

  if (!mask) {
    return;
  }

  int8_t* const gmem_base = src + global_iter_k * CTA_K;
  const int first_chunk = shared_iter_k * kChunksPerCall;

#pragma unroll
  for (int i = 0; i < kChunksPerCall; ++i) {
    const int global_iter = first_chunk + i;
    void* dst_ptr = (void*)(dst + global_iter * kRowsPerChunk * kSmemRowBytes);
    uint4* src_ptr = (uint4*)(gmem_base + global_iter * kRowsPerChunk * global_ncols);
    if constexpr (STAGES > 1) {
      cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, preds[global_iter]);
    } else {
      if (preds[global_iter]) {
        *(uint4*)dst_ptr = *src_ptr;
      }
    }
  }
}
|
||||
|
||||
// Copies this thread's share of one B tile stage from global to shared memory.
//
// `src` / `dst` are the per-thread hoisted base pointers. `global_iter_k` selects
// the k-block (advancing `src` by CTA_K * PACK_SIZE bytes per block) and
// `global_ncols` is the global row pitch used for the per-iteration stride.
// All chunks of the stage are issued in a single call, each guarded by `mask`.
// With STAGES > 1 the copy is issued as cp.async, otherwise as a plain 16-byte store.
// cta_offset_m, cta_offset_n and shared_iter_k are accepted for signature parity
// and are not used here.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int kChunksPerStage = (CTA_N * CTA_K) / 32 / CTA_SIZE;
  constexpr int kNumWarps = CTA_SIZE / WARP_SIZE;
  constexpr int kWarpsPerRow = CTA_K / 32;
  constexpr int kRowsPerChunk = kNumWarps / kWarpsPerRow;
  constexpr int kSmemCol = CTA_K;

  int8_t* const gmem_base = src + global_iter_k * CTA_K * PACK_SIZE;

#pragma unroll
  for (int chunk = 0; chunk < kChunksPerStage; ++chunk) {
    void* dst_ptr = (void*)(dst + chunk * kRowsPerChunk * kSmemCol * PACK_SIZE);
    uint4* src_ptr = (uint4*)(gmem_base + chunk * kRowsPerChunk * global_ncols * PACK_SIZE);
    if constexpr (STAGES > 1) {
      cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, mask);
    } else {
      if (mask) {
        *(uint4*)dst_ptr = *src_ptr;
      }
    }
  }
}
|
||||
|
||||
// Copies one row of CTA_N bytes (16 bytes per participating thread) from `src`
// to `dst`. The source row is g_idx = global_iter_k * CTA_K / G, offset by
// `cta_offset_n`; `global_ncols` is the row pitch in bytes.
// Only the first CTA_N / PACK_SIZE threads of the block (capped at CTA_SIZE)
// take part; the rest are masked off.
// With STAGES > 1 the copy is issued as cp.async, otherwise as a plain 16-byte store.
// cta_offset_m and shared_iter_k are accepted for signature parity and are not used here.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_zeros(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
  constexpr int threads_per_row = CTA_N / PACK_SIZE;

  const bool thread_participates = threadIdx.y * WARP_SIZE + threadIdx.x < threads_used;
  const bool local_mask = mask && thread_participates;
  const int g_idx = global_iter_k * CTA_K / G;

  void* dst_ptr = (void*)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
  uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);

  if constexpr (STAGES > 1) {
    cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, local_mask);
  } else {
    if (local_mask) {
      *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Loads the A fragments for k-step `k_0_1` from shared memory (`src`) into the
// per-thread fragment buffer `dst` via ldmatrix, one call per INTRIN_M row block
// (`shared_iters` blocks in total, starting at row `warp_offset_m`).
// The 16-byte column index is XOR-swizzled with ((row / 2) & 3).
// warp_offset_n is accepted for signature parity and is not used here.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ __inline__ void
share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) {
  constexpr int kSmemRowBytes = CTA_K + SMEM_PAD_A;
  const int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;

  for (int block = 0; block < shared_iters; ++block) {
    const int ld_row = warp_offset_m + block * INTRIN_M + (threadIdx.x % 16);
    // `&` binds tighter than `^`; the parentheses only make that explicit.
    const int ld_col_swizzled = ld_col ^ ((ld_row / 2) & 3);
    void* addr_ptr = (void*)(src + ld_row * kSmemRowBytes + ld_col_swizzled * PACK_SIZE);
    ldmatrix_m8n8_x4_b16(dst, block, cast_smem_ptr_to_uint(addr_ptr));
  }
}
|
||||
|
||||
// Loads this thread's packed B data for k-step `k_0_1` from shared memory and
// splits every byte into its low and high nibble.
//
// Per iteration one uint4 is read; its words are written to `dst` as eight
// 32-bit values: slots 0..3 hold the low nibbles of (x, z, y, w) and slots 4..7
// hold the corresponding high nibbles shifted down by four bits.
// zeros, scales_i8, warp_offset_m and k_0_0 are accepted for signature parity
// and are not used here.
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void share_to_reg_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int8_t* zeros,
    int8_t* scales_i8,
    int warp_offset_m,
    int warp_offset_n,
    int k_0_0,
    int k_0_1,
    int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
  constexpr uint32_t kLowNibbles = 0x0F0F0F0F;

#pragma unroll
  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
    const uint4* row = (uint4*)(src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol;
    const uint4 loaded = *(row + k_0_1 * INTRIN_K + threadIdx.x);

    // Output slot i takes word i of this (x, z, y, w) ordering.
    const uint32_t words[4] = {loaded.x, loaded.z, loaded.y, loaded.w};
    uint32_t* out = (uint32_t*)dst + shared_iter * 8;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      out[i] = words[i] & kLowNibbles;
      out[i + 4] = (words[i] >> 4) & kLowNibbles;
    }
  }
}
|
||||
|
||||
// Per-channel W4A8 GEMM kernel (SM80+ path).
//
// Each block accumulates a CTA_M x CTA_N output tile in int32 with
// mma.m16n8k32 (s8 x s8) and writes it back as half:
//   C[row, col] = acc * wscales[col] * ascales[row] - w_szs[col] * a_ssums[row]
//
// Parameters
//   A        int8 matrix, addressed row-major with pitch K (A + row * K + col).
//   B        packed weights; every byte is split into two nibbles before the MMA.
//            NOTE(review): the exact interleaved layout is defined by the packer
//            outside this file - confirm against the caller.
//   wscales  per-column scales, read as half2 pairs at index col / 2.
//   ascales  per-row scales, indexed by row.
//   w_szs    per-column term, read as half2 pairs at index col / 2.
//   a_ssums  per-row term, indexed by row.
//   C        half output, addressed row-major with pitch N.
//   M, N, K  problem sizes. Rows >= M are masked on load and store; columns and
//            the K dimension are not bounds-checked here.
//
// Launch expectations (derived from the indexing below): threadIdx.x is the
// lane within a warp and threadIdx.y the warp index; blockIdx.x / blockIdx.y
// are remapped through get_block_idx_mapping. Dynamic shared memory holds
// STAGES stages of the A tile followed by STAGES stages of the B tile and is
// reused as an int32 scratch area for the cross-slice reduction.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half2* __restrict__ w_szs,
    half* __restrict__ a_ssums,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  // ---- Compile-time tiling -------------------------------------------------
  constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
  constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
  constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
  constexpr int SLICES = CTA_K / WARP_K;
  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
  constexpr int kAccPerThread = CTA_M * CTA_N / CTA_SIZE_MN;

  // ---- Shared-memory layout ------------------------------------------------
  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
  constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
  constexpr int kSmemSizeScales = CTA_N * STAGES;

  extern __shared__ int8_t mem_shared[];
  int8_t* A_shared = mem_shared;
  int8_t* B_shared = mem_shared + kSmemSizeA;
  // Only forwarded to share_to_reg_one_stage_B, which does not dereference them.
  int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
  int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
  // The accumulator exchange after the main loop reuses the same buffer.
  int* C_shared = reinterpret_cast<int*>(mem_shared);

  // ---- Per-thread load geometry --------------------------------------------
  constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_threads_per_row = CTA_K / PACK_SIZE;
  constexpr int B_warps_per_row = CTA_K / 32;

  // ---- Block / warp coordinates --------------------------------------------
  const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx.x, blockIdx.y, log_tile);
  const int blockIdx_n = block_idx_mapping.x;
  const int blockIdx_m = block_idx_mapping.y;

  const int cta_offset_m = blockIdx_m * CTA_M;
  const int cta_offset_n = blockIdx_n * CTA_N;
  const int warp_mn = threadIdx.y % NUM_WARPS_MN;
  const int slice_id = threadIdx.y / NUM_WARPS_MN;
  const int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
  const int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
  const int warp_offset_k = slice_id * WARP_K;

  // Position of accumulator element `local_id` of MMA tile (ax0, ax1) in C_warp.
  auto acc_index = [](int ax0, int ax1, int local_id) {
    return ax0 * WARP_N / INTRIN_N * 8 + ax1 * 8 + local_id;
  };
  // Position of the same element inside the CTA_N-pitched C_shared scratch tile.
  auto c_shared_index = [=](int ax0, int ax1, int local_id) {
    return warp_offset_m * CTA_N + ax0 * OP_M * CTA_N + warp_offset_n + ax1 * 16 +
           ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
           (threadIdx.x % 4) * 2;
  };

  // ---- Register state --------------------------------------------------------
  int C_warp[kAccPerThread];
  for (int i = 0; i < kAccPerThread; i++) {
    C_warp[i] = 0;
  }
  // Double-buffered operand fragments.
  int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
  int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];

  const int gemm_iters = (K + CTA_K - 1) / CTA_K;
  constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;

  // ---- Hoisted per-thread base pointers ---------------------------------------
  const int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
  const int A_hoisted_col = (threadIdx.x % A_threads_per_row);
  // `&` binds tighter than `^`; the parentheses only make that explicit.
  const int A_hoisted_col_swizzled = A_hoisted_col ^ ((A_hoisted_row / 2) & 3);

  int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
  int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                             (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
  int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
  int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                      (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;

  // Row predicates: chunk i of this thread is loaded only if its row is < M.
  bool A_g2s_preds[A_total_global_iters];
#pragma unroll
  for (int i = 0; i < A_total_global_iters; i++) {
    A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
  }

  // ---- Prologue: fill the first pipeline stages ---------------------------------
  int k_0_0_ld = 0;
#pragma unroll
  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
    global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        A_hoisted,
        A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage,
        K,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        true,
        A_g2s_preds);
    global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        B_hoisted, B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);

    if constexpr (STAGES > 1) __pipeline_commit();
  }
  if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2);
  __syncthreads();

  // Preload the fragments for the very first k-step.
  share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
      A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
  share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
      B_shared + warp_offset_k * PACK_SIZE,
      B_shared_warp_[0],
      zeros_shared,
      scales_i8_shared,
      warp_offset_m,
      warp_offset_n,
      0,
      0,
      WARP_N / 32);

  // ---- Main loop over CTA_K-wide k-blocks -----------------------------------------
  for (int k_0_0 = 0; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
    const int ld_stage = k_0_0_ld % STAGES;
    // Advanced inside the inner loop once the next stage has landed.
    int compute_stage = k_0_0 % STAGES;

    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
      // Recomputed every step because compute_stage may have just been advanced.
      int8_t* A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
      int8_t* B_shared_this_compute_stage =
          B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
      int8_t* zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N;
      int8_t* scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N;

      // Fetch the fragments of the next k-step into the other register buffer.
      share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
          A_shared_this_compute_stage,
          A_shared_warp_[(iter_k + 1) % 2],
          warp_offset_m,
          warp_offset_n,
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_M / INTRIN_M);
      share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
          B_shared_this_compute_stage,
          B_shared_warp_[(iter_k + 1) % 2],
          zeros_shared_this_compute_stage,
          scales_i8_shared_this_compute_stage,
          warp_offset_m,
          warp_offset_n,
          k_0_0 + (iter_k == SHARED_K_ITERS - 1),
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_N / 32);

      int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
      int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];

      // Two m16n8k32 MMAs cover one 16x16 output tile.
      for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
        for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
          int* acc_tile = C_warp + acc_index(i_0_3, j_0_4, 0);
          mma_m16n8k32((void*)(acc_tile), (void*)(A_shared_warp + i_0_3 * 16), (void*)(B_shared_warp + j_0_4 * 16));
          mma_m16n8k32(
              (void*)(acc_tile + 4), (void*)(A_shared_warp + i_0_3 * 16), (void*)(B_shared_warp + j_0_4 * 16 + 8));
        }
      }

      // Spread the global->shared loads of the next stage across the k-steps.
      if (iter_k < SHARED_K_ITERS - 1) {
        if constexpr (STAGES == 1) __syncthreads();
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
      }

      // On the second-to-last k-step: issue the final portion, commit the stage
      // and switch the compute stage for the remaining step.
      if (iter_k == SHARED_K_ITERS - 2) {
        if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) {
          __syncthreads();
        }
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters);
        if constexpr (STAGES > 1) {
          __pipeline_commit();
          __pipeline_wait_prior(STAGES - 2);
        }
        compute_stage = (k_0_0 + 1) % STAGES;
        __syncthreads();
      }
    }
  }

  // Drain outstanding async copies before shared memory is reused as C_shared.
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  // ---- Cross-slice reduction through shared memory ----------------------------------
  if constexpr (SLICES > 1) {
    // Slices take turns: slice z adds the running sum (for z > 0) and stores it back.
#pragma unroll
    for (int z = 0; z < SLICES; ++z) {
      if (slice_id == z) {
#pragma unroll
        for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
          for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
              const auto smem_idx = c_shared_index(ax0_0_1, ax1_0_1, local_id);
              const int reg_idx = acc_index(ax0_0_1, ax1_0_1, local_id);
              if (z > 0) {
                C_warp[reg_idx] += C_shared[smem_idx];
              }
              C_shared[smem_idx] = C_warp[reg_idx];
            }
          }
        }
      }
      // Reached by every thread: the barrier sits outside the slice_id test.
      __syncthreads();
    }
    // Slice 0 reads the fully reduced tile back into its registers.
    if (slice_id == 0) {
#pragma unroll
      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
            C_warp[acc_index(ax0_0_1, ax1_0_1, local_id)] = C_shared[c_shared_index(ax0_0_1, ax1_0_1, local_id)];
          }
        }
      }
    }
  }

  // ---- Epilogue: dequantize and store (slice 0 only) ----------------------------------
  // No barrier follows, so the other slices may leave here.
  if (slice_id != 0) {
    return;
  }

  const int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
  const int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;

  for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
    const int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
    for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
      const int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
      int* C_warp_local = C_warp + acc_index(ax0_0_1, ax1_0_1, 0);
      // Two adjacent columns are handled per step and stored as one half2.
      for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
        const int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
        if (row_wb >= M) {
          continue;
        }
        const int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
        const float2 wscale = __half22float2(wscales[col_wb / 2]);
        const float2 w_sz = __half22float2(w_szs[col_wb / 2]);
        const float ascale = __half2float(ascales[row_wb]);
        const float a_ssum = __half2float(a_ssums[row_wb]);
        float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
        psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum;
        psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum;
        *reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
      }
    }
  }
}
|
||||
#else
|
||||
// Stub compiled when the SM80+ implementation is not available; it keeps the
// same signature so the host-side launch code still links.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half2* __restrict__ w_szs,
    half* __restrict__ a_ssums,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  // Not implemented for SM < 800: fail loudly if this is ever launched.
  assert(false);
}
|
||||
#endif
|
||||
|
||||
// Host entry point of the per-channel W4A8 GEMM.
//
// Validates all tensors (device, contiguity, dtype, element counts), extracts
// raw device pointers and launches dense_kernel0 with a tile configuration
// chosen from the number of output rows. The result is written to `_out_feats`.
//
//   _in_feats   2-D int8 tensor.
//   _kernel     2-D int8 tensor whose first dimension equals the last dimension
//               of `_out_feats`.
//   _wscales    half tensor with one element per output channel.
//   _ascales    half tensor with one element per input row.
//   _w_szs      half tensor with one element per output channel.
//   _a_ssums    half tensor with one element per input row.
//   _out_feats  half output tensor; its second-to-last dimension must equal
//               _in_feats.size(0).
//
// Raises through TORCH_CHECK on any failed validation and through
// TORCH_CHECK_NOT_IMPLEMENTED when the device's SM version is below 80.
// The local names below are consumed by KERNEL_LAUNCH_CODE and must stay as they are.
void qserve_w4a8_per_chn_gemm(
    const torch::Tensor& _in_feats,
    const torch::Tensor& _kernel,
    const torch::Tensor& _wscales,
    const torch::Tensor& _ascales,
    const torch::Tensor& _w_szs,
    const torch::Tensor& _a_ssums,
    torch::Tensor& _out_feats) {
  // Input activations.
  TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor");
  TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor");
  TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous");
  TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8");

  // Packed weights.
  TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor");
  TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor");
  TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous");
  TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8");

  // Output.
  TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor");
  TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be contiguous");
  TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half");

  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  // Matmul shape.
  TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)");
  TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats");

  // Per-row activation scales.
  TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor");
  TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous");
  TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half");
  TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements");

  // Per-channel weight scales.
  TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor");
  TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous");
  TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half");
  TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements");

  // Per-channel w_szs term.
  TORCH_CHECK(_w_szs.is_cuda(), "_w_szs must be a CUDA tensor");
  TORCH_CHECK(_w_szs.is_contiguous(), "_w_szs must be contiguous");
  TORCH_CHECK(_w_szs.scalar_type() == torch::kHalf, "_w_szs must be half");
  TORCH_CHECK(_w_szs.numel() == num_out_channels, "_w_szs must have num_out_channels elements");

  // Per-row a_ssums term.
  TORCH_CHECK(_a_ssums.is_cuda(), "_a_ssums must be a CUDA tensor");
  TORCH_CHECK(_a_ssums.is_contiguous(), "_a_ssums must be contiguous");
  TORCH_CHECK(_a_ssums.scalar_type() == torch::kHalf, "_a_ssums must be half");
  TORCH_CHECK(_a_ssums.numel() == num_in_feats, "_a_ssums must have num_in_feats elements");

  // Raw device pointers in the types expected by dense_kernel0.
  int8_t* in_feats = _in_feats.data_ptr<int8_t>();
  int8_t* kernel = _kernel.data_ptr<int8_t>();
  half2* w_szs = reinterpret_cast<half2*>(_w_szs.data_ptr());
  half* a_ssums = reinterpret_cast<half*>(_a_ssums.data_ptr());
  half2* wscales = reinterpret_cast<half2*>(_wscales.data_ptr());
  half* ascales = reinterpret_cast<half*>(_ascales.data_ptr());
  half* out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device());

  auto sm_version = getSMVersion();
  TORCH_CHECK_NOT_IMPLEMENTED(
      sm_version >= 80, "No implemented qserve_w4a8_per_chn_gemm for current compute capability: ", sm_version);

  constexpr int G = 128;

  // Tile configuration by number of output rows: > 256, 128..256, < 128.
  if (num_out_feats > 256) {
    constexpr int CTA_M = 128;
    constexpr int CTA_N = 128;
    constexpr int CTA_K = 64;
    constexpr int WARP_M = 64;
    constexpr int WARP_N = 32;
    constexpr int WARP_K = 64;
    constexpr int STAGES = 3;
    KERNEL_LAUNCH_CODE
  } else if (num_out_feats >= 128) {
    constexpr int CTA_M = 64;
    constexpr int CTA_N = 64;
    constexpr int CTA_K = 64;
    constexpr int WARP_M = 32;
    constexpr int WARP_N = 32;
    constexpr int WARP_K = 64;
    constexpr int STAGES = 4;
    KERNEL_LAUNCH_CODE
  } else {
    constexpr int CTA_M = 32;
    constexpr int CTA_N = 64;
    constexpr int CTA_K = 128;
    constexpr int WARP_M = 32;
    constexpr int WARP_N = 32;
    constexpr int WARP_K = 64;
    constexpr int STAGES = 3;
    KERNEL_LAUNCH_CODE
  }
}
|
||||
795
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu
Normal file
795
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu
Normal file
@@ -0,0 +1,795 @@
|
||||
// Implemented by Haotian Tang and Shang Yang.
|
||||
// @article{lin2024qserve,
|
||||
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
|
||||
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
|
||||
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
|
||||
// }
|
||||
// @article{yang2025lserve,
|
||||
// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention},
|
||||
// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and
|
||||
// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025}
|
||||
// }
|
||||
|
||||
// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define OP_M 16
|
||||
#define OP_N 8
|
||||
#define OP_K 32
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
// L2 prefetch-size qualifier appended to cp.async in cp_async_cg_A.
// It is enabled from CUDA 11.4 onwards. Major and minor are compared as a pair:
// testing them independently would wrongly disable the hint on CUDA 12.0-12.3,
// whose minor version is below 4.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
|
||||
|
||||
// Host-side launch sequence shared by every tile configuration of the per-group GEMM.
//
// The expansion site must define the constexpr tile parameters CTA_M, CTA_N,
// CTA_K, WARP_M, WARP_N, WARP_K, STAGES and G, and the runtime variables
// in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats,
// num_in_feats, num_out_feats, num_out_channels, num_in_channels and stream.
//
// Failures are raised through TORCH_CHECK instead of being printed or ignored:
//   - the shared-memory request exceeding the 99 KiB limit,
//   - cudaFuncSetAttribute failing to raise the dynamic shared-memory cap,
//   - the kernel launch itself reporting an error.
#define KERNEL_LAUNCH_CODE                                                                                     \
  constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);                            \
  constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2);   \
  constexpr int kSmemByteSize =                                                                                \
      ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) *        \
      sizeof(int8_t);                                                                                          \
  TORCH_CHECK(                                                                                                 \
      kSmemByteSize < 99 * 1024,                                                                               \
      "This kernel requires ",                                                                                 \
      kSmemByteSize,                                                                                           \
      " Bytes of shared memory, which exceeds device limit.");                                                 \
  int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M;                                                      \
  int num_blocks_n = num_out_channels / CTA_N / 1;                                                             \
  const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M);                                   \
  const int tile_shift = 1 << log_tile;                                                                        \
  dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift);                     \
  dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                                \
  auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;                    \
  const cudaError_t smem_attr_status =                                                                         \
      cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);           \
  TORCH_CHECK(                                                                                                 \
      smem_attr_status == cudaSuccess,                                                                         \
      "cudaFuncSetAttribute(MaxDynamicSharedMemorySize) failed: ",                                             \
      cudaGetErrorString(smem_attr_status));                                                                   \
  kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(                                       \
      in_feats,                                                                                                \
      kernel,                                                                                                  \
      zeros,                                                                                                   \
      scales_i8,                                                                                               \
      wscales,                                                                                                 \
      ascales,                                                                                                 \
      out_feats,                                                                                               \
      num_in_feats,                                                                                            \
      num_out_channels,                                                                                        \
      num_in_channels);                                                                                        \
  const cudaError_t launch_status = cudaGetLastError();                                                        \
  TORCH_CHECK(launch_status == cudaSuccess, "dense_kernel0 launch failed: ", cudaGetErrorString(launch_status));
|
||||
|
||||
// Returns log2 of the block-swizzle tile width: the largest of 3, 2, 1, 0 that
// is allowed both by the compile-time cap N (8 / 4 / 2) and by the tile count n
// (at least 6 / 3 / 2 respectively).
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
  if (N >= 8 && n >= 6) {
    return 3;
  }
  if (N >= 4 && n >= 3) {
    return 2;
  }
  if (N >= 2 && n >= 2) {
    return 1;
  }
  return 0;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// Remaps the launch-grid block coordinates into logical (x, y) tile coordinates.
// The low `log_tile` bits of blockIdx_x are folded into y:
//   x = blockIdx_x >> log_tile
//   y = (blockIdx_y << log_tile) + (blockIdx_x & (2^log_tile - 1))
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
  const int tile_mask = (1 << log_tile) - 1;
  const int mapped_x = blockIdx_x >> log_tile;
  const int mapped_y = (blockIdx_y << log_tile) + (blockIdx_x & tile_mask);
  return make_uint2(mapped_x, mapped_y);
}
|
||||
|
||||
// Converts a generic pointer into a 32-bit shared-state-space address
// (cvta.to.shared followed by a 64->32 bit truncation), the operand form
// consumed by the inline-PTX helpers in this file.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) {
  uint32_t shared_addr;
  asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(shared_addr)
      : "l"(ptr));
  return shared_addr;
}
|
||||
|
||||
// Issues ldmatrix.sync.aligned.m8n8.x4.shared.b16 from shared address `addr`.
// The four 32-bit results are stored at byte offset ax0_0 * 16 inside `shared_warp`.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  unsigned* frag = (unsigned*)(shared_warp + (ax0_0 * 16));
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
|
||||
|
||||
// Transposed variant: issues ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 from
// shared address `addr`. The four 32-bit results are stored at byte offset
// ax0_0 * 16 inside `shared_warp`.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  unsigned* frag = (unsigned*)(shared_warp + (ax0_0 * 16));
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
|
||||
|
||||
// function from lmdeploy
// Predicated asynchronous 16-byte copy from global memory (`src`) to the shared
// address `smem_int_ptr` using cp.async.cg. When `mask` is false the copy
// instruction is not executed.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) {
  constexpr int kCopyBytes = 16;
  const int pred = (int)mask;
  asm volatile(
      "{"
      "  .reg .pred p;"
      "  setp.ne.b32 p, %0, 0;"
      "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
      "}" ::"r"(pred),
      "r"(smem_int_ptr),
      "l"(src),
      "n"(kCopyBytes));
}
|
||||
|
||||
// One mma.sync m16n8k32 (s8 x s8 -> s32) step that accumulates in place:
//   C_warp[0..3] = A_shared_warp[0..3] * B_shared_warp[0..1] + C_warp[0..3]
// All three pointers are read as arrays of 32-bit registers.
__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) {
  int* acc = (int*)C_warp;
  unsigned* a_frag = (unsigned*)A_shared_warp;
  unsigned* b_frag = (unsigned*)B_shared_warp;
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
      : "=r"(acc[0]), "=r"(acc[1]), "=r"(acc[2]), "=r"(acc[3])
      : "r"(a_frag[0]),
        "r"(a_frag[1]),
        "r"(a_frag[2]),
        "r"(a_frag[3]),
        "r"(b_frag[0]),
        "r"(b_frag[1]),
        "r"(acc[0]),
        "r"(acc[1]),
        "r"(acc[2]),
        "r"(acc[3]));
}
|
||||
|
||||
// Copies this thread's share of one A tile stage from global to shared memory.
//
// `src` / `dst` are the per-thread hoisted base pointers. Each call handles
// 1/SHARED_K_ITERS of the thread's 16-byte (PACK_SIZE) chunks, selected by
// `shared_iter_k`. `global_iter_k` selects the CTA_K-wide column block and
// `global_ncols` is the row pitch of the global matrix in bytes.
// `mask` disables the whole call; `preds[i]` disables the i-th chunk.
// With STAGES > 1 the copy is issued as cp.async, otherwise as a plain 16-byte store.
// cta_offset_m / cta_offset_n are accepted for signature parity and are not used here.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask,
    bool* preds) {
  constexpr int kChunksPerStage = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int kChunksPerCall = kChunksPerStage / SHARED_K_ITERS;
  constexpr int kRowsPerChunk = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int kSmemRowBytes = CTA_K + SMEM_PAD_A;

  if (!mask) {
    return;
  }

  int8_t* const gmem_base = src + global_iter_k * CTA_K;
  const int first_chunk = shared_iter_k * kChunksPerCall;

#pragma unroll
  for (int i = 0; i < kChunksPerCall; ++i) {
    const int global_iter = first_chunk + i;
    void* dst_ptr = (void*)(dst + global_iter * kRowsPerChunk * kSmemRowBytes);
    uint4* src_ptr = (uint4*)(gmem_base + global_iter * kRowsPerChunk * global_ncols);
    if constexpr (STAGES > 1) {
      cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, preds[global_iter]);
    } else {
      if (preds[global_iter]) {
        *(uint4*)dst_ptr = *src_ptr;
      }
    }
  }
}
|
||||
|
||||
// Copies this thread's share of one B tile stage from global to shared memory.
//
// `src` / `dst` are the per-thread hoisted global / shared base pointers, `global_ncols` is
// the global K extent used as row pitch, and `global_iter_k` selects the K block. Every
// 16-byte transfer is gated by `mask`. With STAGES > 1 the copy goes through
// `cp_async_cg_A`, otherwise it is a synchronous uint4 load/store.
// NOTE: `SHARED_K_ITERS`, `shared_iter_k`, `cta_offset_m` and `cta_offset_n` are not used,
// so each call transfers the whole stage.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int kIters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
  constexpr int kWarpsPerCta = CTA_SIZE / WARP_SIZE;
  constexpr int kWarpsPerRow = CTA_K / 32;
  constexpr int kRowStep = kWarpsPerCta / kWarpsPerRow;
  constexpr int kSmemCol = CTA_K;

  int8_t* const gmem_base = src + global_iter_k * CTA_K * PACK_SIZE;

#pragma unroll
  for (int iter = 0; iter < kIters; ++iter) {
    void* smem_dst = (void*)(dst + iter * kRowStep * kSmemCol * PACK_SIZE);
    uint4* gmem_src = (uint4*)(gmem_base + iter * kRowStep * global_ncols * PACK_SIZE);
    if constexpr (STAGES > 1) {
      const uint32_t smem_addr = cast_smem_ptr_to_uint(smem_dst);
      cp_async_cg_A(smem_addr, gmem_src, mask);
    } else {
      if (mask) {
        *(uint4*)smem_dst = *gmem_src;
      }
    }
  }
}
|
||||
|
||||
// Copies one row of CTA_N per-group int8 values (used for both zero points and int8 scales)
// from global to shared memory.
//
// The row is selected by the quantization group of the current K block
// (`global_iter_k * CTA_K / G`) and the column window starts at `cta_offset_n`;
// `global_ncols` is the global row pitch. Only the first CTA_N / PACK_SIZE threads of the
// CTA (by `threadIdx.y * WARP_SIZE + threadIdx.x`) take part, each moving one 16-byte
// chunk, and only when `mask` is true. With STAGES > 1 the copy goes through
// `cp_async_cg_A`, otherwise it is a synchronous uint4 load/store.
// `cta_offset_m` and `shared_iter_k` are not used.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_zeros(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int threads_needed = CTA_N / PACK_SIZE;
  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
  constexpr int threads_per_row = CTA_N / PACK_SIZE;

  // Logical AND: both operands are plain booleans without side effects.
  const bool local_mask = mask && (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
  const int g_idx = global_iter_k * CTA_K / G;
  const unsigned chunk_offset = (threadIdx.x % threads_per_row) * PACK_SIZE;

  void* dst_ptr = (void*)(dst + chunk_offset);
  uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + chunk_offset);

  // Compile-time dispatch, matching the A/B loaders, so only one path is instantiated.
  if constexpr (STAGES > 1) {
    const uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
    cp_async_cg_A(addr, src_ptr, local_mask);
  } else {
    if (local_mask) {
      *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Loads `shared_iters` A fragments from the shared-memory tile `src` into `dst` with
// `ldmatrix_m8n8_x4_b16`, one call per INTRIN_M rows starting at `warp_offset_m`.
// `k_0_1` selects the INTRIN_K-wide K step inside the tile. The 16-byte column index is
// XOR-swizzled with `(row / 2) & 3`, the same expression the kernel uses when it computes
// the shared-memory store address for A. `warp_offset_n` is not used.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ __inline__ void
share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) {
  constexpr int kSmemRowPitch = CTA_K + SMEM_PAD_A;
  const int chunk_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;
  // Row handled by this lane within the first fragment; later fragments add INTRIN_M each.
  const int lane_row = warp_offset_m + (threadIdx.x % 16);

  for (int frag = 0; frag < shared_iters; ++frag) {
    const int row = lane_row + frag * INTRIN_M;
    // `&` binds tighter than `^`; the parentheses only make that explicit.
    const int swizzled_col = chunk_col ^ ((row / 2) & 3);
    void* smem_src = (void*)(src + row * kSmemRowPitch + swizzled_col * PACK_SIZE);
    const uint32_t smem_addr = cast_smem_ptr_to_uint(smem_src);
    ldmatrix_m8n8_x4_b16(dst, frag, smem_addr);
  }
}
|
||||
|
||||
// Loads packed 4-bit B values from shared memory and expands them into `dst`.
//
// For each of `shared_iters` fragments, one uint4 (32 nibbles) is read from `src`, split
// into eight words of four byte-wide values, multiplied by a per-column 8-bit scale from
// `scales_i8` and offset with a byte-wise add (`__vadd4`) of the matching value from
// `zeros` replicated to all four bytes. Eight 32-bit results are written per fragment.
// `warp_offset_m` and `k_0_0` are not used.
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void share_to_reg_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int8_t* zeros,
    int8_t* scales_i8,
    int warp_offset_m,
    int warp_offset_n,
    int k_0_0,
    int k_0_1,
    int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
  constexpr uint32_t kLowNibbles = 0x0F0F0F0F;

#pragma unroll
  for (int frag = 0; frag < shared_iters; ++frag) {
    // Index is in uint4 units.
    const uint4 packed =
        *((uint4*)(src) + warp_offset_n / 32 * kSmemCol + frag * 32 / 32 * kSmemCol + k_0_1 * INTRIN_K +
          threadIdx.x);

    // Low / high nibble of every byte; the numeric suffix is the output slot it feeds.
    const uint32_t nib0 = packed.x & kLowNibbles;
    const uint32_t nib4 = (packed.x >> 4) & kLowNibbles;
    const uint32_t nib2 = packed.y & kLowNibbles;
    const uint32_t nib6 = (packed.y >> 4) & kLowNibbles;
    const uint32_t nib1 = packed.z & kLowNibbles;
    const uint32_t nib5 = (packed.z >> 4) & kLowNibbles;
    const uint32_t nib3 = packed.w & kLowNibbles;
    const uint32_t nib7 = (packed.w >> 4) & kLowNibbles;

    uint32_t* out = (uint32_t*)dst + frag * 8;

    // Four consecutive scales / zero points are fetched as one 32-bit word each.
    const int sz_offset = warp_offset_n + (threadIdx.x / 4) * 4 + frag * 32;
    const uint32_t scales4 = *reinterpret_cast<uint32_t*>(scales_i8 + sz_offset);
    const uint32_t zeros4 = *reinterpret_cast<uint32_t*>(zeros + sz_offset);

    // Byte i of the packed scale word.
    const uint32_t s0 = scales4 & 0xFF;
    const uint32_t s1 = (scales4 >> 8) & 0xFF;
    const uint32_t s2 = (scales4 >> 16) & 0xFF;
    const uint32_t s3 = scales4 >> 24;

    // Byte i of the packed zero word broadcast to all four byte lanes.
    const uint32_t z0 = __byte_perm(zeros4, 0, 0x00000000);
    const uint32_t z1 = __byte_perm(zeros4, 0, 0x00001111);
    const uint32_t z2 = __byte_perm(zeros4, 0, 0x00002222);
    const uint32_t z3 = __byte_perm(zeros4, 0, 0x00003333);

    out[0] = __vadd4(nib0 * s0, z0);
    out[1] = __vadd4(nib1 * s0, z0);

    out[2] = __vadd4(nib2 * s1, z1);
    out[3] = __vadd4(nib3 * s1, z1);

    out[4] = __vadd4(nib4 * s2, z2);
    out[5] = __vadd4(nib5 * s2, z2);

    out[6] = __vadd4(nib6 * s3, z3);
    out[7] = __vadd4(nib7 * s3, z3);
  }
}
|
||||
|
||||
// W4A8 per-group GEMM kernel (tensor-core path).
//
// Each CTA produces one CTA_M x CTA_N tile of C. A holds int8 activations with row pitch K;
// B holds packed 4-bit weights that `share_to_reg_one_stage_B` expands using the per-group
// `zeros` / `scales_i8` rows (row pitch N). Products are accumulated as int32 with
// `mma_m16n8k32`, then scaled by `wscales` (per output column, read as half2 pairs) and
// `ascales` (per row) and written to C as half.
//
// Thread layout used by the indexing below: threadIdx.x is the lane within a warp and
// threadIdx.y is the warp index (split into an M/N position and a K slice). The launch
// configuration itself is not visible here — presumably blockDim = (WARP_SIZE, NUM_WARPS).
//
// Dynamic shared memory layout: [A stages | B stages | zeros stages | scales stages]. When
// CTA_K / WARP_K > 1 the start of the same buffer is reused as an int32 CTA_M x CTA_N
// scratch tile for the cross-slice reduction.
//
// Only rows are bounds-checked (against M). Loads along K and stores along N are
// unguarded, so this assumes K % CTA_K == 0 and N % CTA_N == 0 — TODO confirm that every
// launcher enforces this.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    int8_t* __restrict__ zeros,
    int8_t* __restrict__ scales_i8,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
  constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
  constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
  constexpr int SLICES = CTA_K / WARP_K;

  // Remap the hardware block index to a logical (n, m) tile index.
  int blockIdx_n = blockIdx.x;
  int blockIdx_m = blockIdx.y;
  const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
  blockIdx_n = block_idx_mapping.x;
  blockIdx_m = block_idx_mapping.y;

  // Per-thread int32 accumulators.
  int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
  constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
  constexpr int kSmemSizeScales = CTA_N * STAGES;

  extern __shared__ int8_t mem_shared[];
  int8_t* A_shared = mem_shared;

  int8_t* B_shared = mem_shared + kSmemSizeA;
  int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
  int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;

  // Double-buffered register fragments: one buffer is consumed by the MMAs while the next
  // K step is being loaded into the other.
  int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
  int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
  constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_threads_per_row = CTA_K / PACK_SIZE;

  constexpr int B_warps_per_row = CTA_K / 32;

  int cta_offset_m = blockIdx_m * CTA_M;
  int cta_offset_n = blockIdx_n * CTA_N;
  int warp_mn = threadIdx.y % NUM_WARPS_MN;
  int slice_id = threadIdx.y / NUM_WARPS_MN;
  int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
  int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
  int warp_offset_k = slice_id * WARP_K;

  for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
    C_warp[i] = 0;

  int gemm_iters = (K + CTA_K - 1) / CTA_K;

  int k_0_0_ld = 0;
  int k_0_0 = 0;
  constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
  int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
  int A_hoisted_col = (threadIdx.x % A_threads_per_row);
  // Same swizzle as in share_to_reg_one_stage_A (`&` binds tighter than `^`).
  int A_hoisted_col_swizzled = A_hoisted_col ^ ((A_hoisted_row / 2) & 3);

  int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
  int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                             (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
  int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
  int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                      (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;

  // Row-validity predicate for each of this thread's A transfers (guards M only).
  bool A_g2s_preds[A_total_global_iters];
#pragma unroll
  for (int i = 0; i < A_total_global_iters; i++) {
    A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
  }

  // Aliases the start of the dynamic shared buffer; only used after the main loop.
  int* C_shared = reinterpret_cast<int*>(mem_shared);

  // ---- Prologue: fill the first pipeline stages. ----
#pragma unroll
  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
    // A prologue stage can lie beyond the last K block when K is small relative to
    // (STAGES - 1) * CTA_K. Mask the A/B loads like the zeros/scales loads so those stages
    // do not read past the end of the rows.
    const bool stage_in_range = k_0_0_ld < gemm_iters;
    global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        A_hoisted,
        A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage,
        K,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        stage_in_range,
        A_g2s_preds);
    global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        B_hoisted,
        B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage,
        K,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        stage_in_range);
    global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
        zeros, zeros_shared + (k_0_0_ld)*CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, stage_in_range);
    global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
        scales_i8,
        scales_i8_shared + (k_0_0_ld)*CTA_N,
        N,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        stage_in_range);

    if constexpr (STAGES > 1) __pipeline_commit();
  }
  if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2);
  __syncthreads();

  // Preload the register fragments for the first K step of stage 0.
  share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
      A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
  share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
      B_shared + warp_offset_k * PACK_SIZE,
      B_shared_warp_[0],
      zeros_shared,
      scales_i8_shared,
      warp_offset_m,
      warp_offset_n,
      0,
      0,
      WARP_N / 32);
  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

  // ---- Main loop over K blocks: compute stage k_0_0 while loading stage k_0_0_ld. ----
  for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
    int ld_stage = k_0_0_ld % STAGES;
    int compute_stage = k_0_0 % STAGES;
    int8_t* A_shared_this_compute_stage;
    int8_t* B_shared_this_compute_stage;
    int8_t* zeros_shared_this_compute_stage;
    int8_t* scales_i8_shared_this_compute_stage;

    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
      // `compute_stage` is advanced near the end of the second-to-last iteration, so on
      // the last iteration these pointers already refer to the next stage.
      A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
      B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
      zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N;
      scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N;

      // Prefetch the fragments for the next K step into the other register buffer.
      share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
          A_shared_this_compute_stage,
          A_shared_warp_[(iter_k + 1) % 2],
          warp_offset_m,
          warp_offset_n,
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_M / INTRIN_M);
      share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
          B_shared_this_compute_stage,
          B_shared_warp_[(iter_k + 1) % 2],
          zeros_shared_this_compute_stage,
          scales_i8_shared_this_compute_stage,
          warp_offset_m,
          warp_offset_n,
          k_0_0 + (iter_k == SHARED_K_ITERS - 1),
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_N / 32);
      int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
      int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];

      // Two m16n8k32 MMAs per (i, j) pair: one per group of four accumulators.
      for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
        for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16));
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16 + 8));
        }
      }

      if (iter_k < SHARED_K_ITERS - 1) {
        if constexpr (STAGES == 1) __syncthreads();
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
      }

      if (iter_k == SHARED_K_ITERS - 2) {
        if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) {
          __syncthreads();
        }
        // Last portion of the stage being loaded, plus its zeros/scales row.
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters);
        global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            zeros,
            zeros_shared + (ld_stage)*CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
        global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            scales_i8,
            scales_i8_shared + (ld_stage)*CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
        if constexpr (STAGES > 1) {
          __pipeline_commit();
          __pipeline_wait_prior(STAGES - 2);
        }
        compute_stage = (k_0_0 + 1) % STAGES;
        __syncthreads();
      }
    }
  }
  // Drain outstanding async copies before shared memory is reused as C_shared.
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  // ---- Cross-slice reduction: K slices add their partial sums through C_shared. ----
  if constexpr (SLICES > 1) {
#pragma unroll
    for (int z = 0; z < SLICES; ++z) {
      if (slice_id == z) {
#pragma unroll
        for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
          for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
              if (z > 0) {
                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared
                    [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                     (threadIdx.x % 4) * 2];
              }
              C_shared
                  [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                   ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                   (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
            }
          }
        }
      }
      // Reached by every thread of the CTA (outside the slice_id branch).
      __syncthreads();
    }
    // Slice 0 reads back the fully reduced tile.
    if (slice_id == 0) {
#pragma unroll
      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
            C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared
                [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                 ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                 (threadIdx.x % 4) * 2];
          }
        }
      }
    }
  }

  // ---- Epilogue: scale the int32 sums and write half2 pairs to C (slice 0 only). ----
  int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
  int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
  if (slice_id == 0) {
    for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
      int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
      for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
        int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
        int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
        // Step of 2: each iteration emits two adjacent columns as one half2.
        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
          int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
          if (row_wb < M) {
            int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
            float2 wscale = __half22float2(*(wscales + col_wb / 2));
            float ascale = __half2float(ascales[row_wb]);
            float2 psums =
                make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
            psums.x *= wscale.x * ascale;
            psums.y *= wscale.y * ascale;
            *reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
          }
        }
      }
    }
  }
}
|
||||
#else
|
||||
// Fallback definition of dense_kernel0 for the `#else` branch of the surrounding
// preprocessor condition. It has the same signature as the real kernel so host code still
// compiles, performs no work, and trips a device-side assert if it is ever executed.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    int8_t* __restrict__ zeros,
    int8_t* __restrict__ scales_i8,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  // Not implemented for SM < 800
  assert(false);
}
|
||||
#endif
|
||||
|
||||
// Host entry point for the W4A8 per-group GEMM.
//
// Validates the operands, picks a tile configuration from the problem size and launches
// `dense_kernel0` through KERNEL_LAUNCH_CODE on the current CUDA stream of `_in_feats`'
// device. The result is written into `_out_feats` in place.
//
//   _in_feats  : int8, 2D, contiguous; size(0) rows, size(1) input channels
//   _kernel    : int8, 2D, contiguous; size(0) == number of output channels
//   _zeros     : int8, 2D, contiguous; one row per quantization group
//   _scales_i8 : int8, 2D, contiguous; one row per quantization group
//   _wscales   : half, contiguous; one value per output channel
//   _ascales   : half, contiguous; one value per input row
//   _out_feats : half, contiguous; last two dims are (rows, output channels)
//
// The quantization group size (input channels / groups) must be 128. Raises via
// TORCH_CHECK on any violated precondition and via TORCH_CHECK_NOT_IMPLEMENTED when the
// device's SM version is below 80.
//
// NOTE: the local variable names below are kept as-is on purpose; KERNEL_LAUNCH_CODE is a
// macro defined elsewhere and presumably refers to them by name.
void qserve_w4a8_per_group_gemm(
    const torch::Tensor& _in_feats,
    const torch::Tensor& _kernel,
    const torch::Tensor& _zeros,
    const torch::Tensor& _scales_i8,
    const torch::Tensor& _wscales,
    const torch::Tensor& _ascales,
    torch::Tensor& _out_feats) {
  // Check input tensor
  TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor");
  TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor");
  TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous");
  TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8");
  // Check kernel tensor
  TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor");
  TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor");
  TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous");
  TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8");
  // Check output tensor
  TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor");
  // size(-2) / size(-1) are read below, so at least two dimensions are required.
  TORCH_CHECK(_out_feats.dim() >= 2, "_out_feats must have at least 2 dimensions");
  TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be contiguous");
  TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half");

  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  // Check matmul shape
  TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)");
  TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats");

  // Check _ascales
  TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor");
  TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous");
  TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half");
  TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements");

  // Check _wscales
  TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor");
  TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous");
  TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half");
  TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements");

  // Check _scales_i8
  TORCH_CHECK(_scales_i8.is_cuda(), "_scales_i8 must be a CUDA tensor");
  TORCH_CHECK(_scales_i8.dim() == 2, "_scales_i8 must be a 2D tensor");
  TORCH_CHECK(_scales_i8.is_contiguous(), "_scales_i8 must be contiguous");
  TORCH_CHECK(_scales_i8.scalar_type() == torch::kInt8, "_scales_i8 must be int8");
  // Guard the modulo / division below against a zero group count.
  TORCH_CHECK(_scales_i8.size(0) > 0, "_scales_i8 must have at least one group");
  TORCH_CHECK(num_in_channels % _scales_i8.size(0) == 0, "num_in_channels must be divisible by _scales_i8.size(0)");
  TORCH_CHECK(num_out_channels == _scales_i8.size(1), "num_out_channels must be equal to _scales_i8.size(1)");

  // Check _zeros
  TORCH_CHECK(_zeros.is_cuda(), "_zeros must be a CUDA tensor");
  TORCH_CHECK(_zeros.dim() == 2, "_zeros must be a 2D tensor");
  TORCH_CHECK(_zeros.is_contiguous(), "_zeros must be contiguous");
  TORCH_CHECK(_zeros.scalar_type() == torch::kInt8, "_zeros must be int8");
  TORCH_CHECK(_zeros.size(0) > 0, "_zeros must have at least one group");
  TORCH_CHECK(num_in_channels % _zeros.size(0) == 0, "num_in_channels must be divisible by _zeros.size(0)");
  // The kernel indexes zeros and scales with the same group index, so both tensors must
  // cover the same number of groups.
  TORCH_CHECK(_zeros.size(0) == _scales_i8.size(0), "_zeros and _scales_i8 must have the same number of groups");
  TORCH_CHECK(num_out_channels == _zeros.size(1), "num_out_channels must be equal to _zeros.size(1)");

  // Check group size
  auto group_size = num_in_channels / _scales_i8.size(0);
  TORCH_CHECK(group_size == 128, "group_size must be 128");

  // Nothing to compute for an empty batch; avoid launching with an empty grid.
  if (num_in_feats == 0) {
    return;
  }

  auto in_feats = _in_feats.data_ptr<int8_t>();
  auto kernel = _kernel.data_ptr<int8_t>();
  auto zeros = _zeros.data_ptr<int8_t>();
  auto scales_i8 = _scales_i8.data_ptr<int8_t>();
  auto wscales = reinterpret_cast<half2*>(_wscales.data_ptr());
  auto ascales = reinterpret_cast<half*>(_ascales.data_ptr());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device());
  auto sm_version = getSMVersion();
  if (sm_version >= 80) {
    constexpr int G = 128;

    // Every tile configuration below uses CTA_N = 64 and the kernel does not bounds-check
    // output columns, so partial column tiles are not supported.
    TORCH_CHECK(num_out_channels % 64 == 0, "num_out_channels must be divisible by 64");

    if (num_out_feats > 128) {
      constexpr int CTA_M = 128;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 64;
      constexpr int WARP_M = 64;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 4;
      KERNEL_LAUNCH_CODE
    } else if (num_out_feats >= 128) {
      // Reached only when num_out_feats == 128.
      if (num_in_channels <= 4096) {
        constexpr int CTA_M = 64;
        constexpr int CTA_N = 64;
        constexpr int CTA_K = 64;
        constexpr int WARP_M = 32;
        constexpr int WARP_N = 32;
        constexpr int WARP_K = 64;
        constexpr int STAGES = 4;
        KERNEL_LAUNCH_CODE
      } else {
        constexpr int CTA_M = 64;
        constexpr int CTA_N = 64;
        constexpr int CTA_K = 128;
        constexpr int WARP_M = 32;
        constexpr int WARP_N = 32;
        constexpr int WARP_K = 64;
        constexpr int STAGES = 3;
        KERNEL_LAUNCH_CODE
      }
    } else {
      constexpr int CTA_M = 32;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 128;
      constexpr int WARP_M = 32;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 3;
      KERNEL_LAUNCH_CODE
    }
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "No implemented qserve_w4a8_per_group_gemm for current compute capability: ", sm_version);
  }
  return;
}
|
||||
Reference in New Issue
Block a user