diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 85d3d8933..d2c7975ba 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -321,6 +321,30 @@ def pack_cols( return q_res +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + def unpack_cols( packed_q_w: torch.Tensor, num_bits: int, diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 28509b60a..4b9732cce 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -254,13 +254,15 @@ set(SOURCES "csrc/gemm/per_token_quant_fp8.cu" "csrc/gemm/qserve_w4a8_per_chn_gemm.cu" "csrc/gemm/qserve_w4a8_per_group_gemm.cu" + "csrc/gemm/marlin/gptq_marlin.cu" + "csrc/gemm/marlin/gptq_marlin_repack.cu" + "csrc/gemm/marlin/awq_marlin_repack.cu" + "csrc/gemm/gptq/gptq_kernel.cu" "csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/marlin_moe_wna16/ops.cu" - "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu" - "csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 90f9b843c..5835146a2 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); + // GPTQ related method + m.def( + "gptq_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none," + "Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none," + "Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k," + "bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + + m.def( + "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool " + "use_shuffle, int bit) -> Tensor"); + m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); + + m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); + m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + + m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor"); + m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + + m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); + m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); /* * From csrc/moe */ @@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()"); m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); - m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor"); - m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack); - - m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); - m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack); - /* * From csrc/speculative */ diff --git a/sgl-kernel/csrc/gemm/gptq/compat.cuh b/sgl-kernel/csrc/gemm/gptq/compat.cuh new file mode 100644 index 000000000..506eeb769 --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/compat.cuh @@ -0,0 +1,62 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _compat_cuh +#define _compat_cuh + +namespace sglang { +namespace gptq { +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} +#endif + +#endif +#endif + +} // namespace gptq +} // namespace sglang +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu b/sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu new file mode 100644 index 000000000..4dd5d8a24 --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu @@ -0,0 +1,1950 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa +*/ + +#include +#include +#include +#include +#include + +#include +#include + +#include "compat.cuh" +#include "matrix_view.cuh" +#include "qdq_2.cuh" +#include "qdq_3.cuh" +#include "qdq_4.cuh" +#include "qdq_8.cuh" + +namespace sglang { +namespace gptq { + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_ROWS_8BIT 24 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#if defined(USE_ROCM) +#include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm( + handle, + transA, + transB, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(AP), + lda, + reinterpret_cast(BP), + ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), + ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, const half2 g_result) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, const float g_result, const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, const float g_result, const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float +dot22_32_f(half2 (&dq)[16], const half* a_ptr, const float g_result, const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, const half g_result, const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, const half g_result, const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, const half g_result, const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel)( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int*); + +template +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; + } + + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; + } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; +} + +void gemm_half_q_half_cuda_part( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, groups, b_q_perm); +} + +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4( + perm[lk++], + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + offset_k + lk++, + n, + __low2half(dq[0][j]), + __low2half(dq[1][j]), + __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4( + offset_k + lk++, + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4( + perm[lk++], + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + offset_k + lk++, + n, + __low2half(dq[0][j]), + __low2half(dq[1][j]), + __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4( + offset_k + lk++, + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4( + perm[lk++], + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + offset_k + lk++, + n, + __low2half(dq[0][j]), + __low2half(dq[1][j]), + __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4( + offset_k + lk++, + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4( + perm[lk++], + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4( + offset_k + lk++, + n, + __low2half(dq[0][j]), + __low2half(dq[1][j]), + __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4( + offset_k + lk++, + n, + __high2half(dq[0][j]), + __high2half(dq[1][j]), + __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +void reconstruct_exllama( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* out, + int height, + int width, + int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, out); +} + +__global__ void gemm_half_q_half_alt_4bit_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + auto b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + auto h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + auto val = threadIdx.x / 8; + auto off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +__global__ void gemm_half_q_half_alt_8bit_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + auto b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + auto h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +void gemm_half_q_half_alt( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + int size_m, + int size_n, + int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, size_m, size_k / 32 * bit, size_n); +} + +template +__global__ void reconstruct_gptq_kernel( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out) { + // Start of block + + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } +} + +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out) { + // Start of block + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); + } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } +} + +void reconstruct_gptq( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* out, + int height, + int width, + int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, width, groups, out); +} + +void gemm_half_q_half_cuda( + cublasHandle_t cublas_handle, + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + half* temp_dq, + int size_m, + int size_n, + int size_k, + int groups, + bool use_shuffle, + int bit) { + bool use_reconstruct; + if (use_shuffle) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_shuffle) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, size_k, size_n, groups, bit); + } else { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, size_k, size_n, groups, bit); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm( + cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, + size_m, + size_k, + &alpha, + temp_dq, + size_n, + a, + size_k, + &beta, + c, + size_n); + } else if (use_shuffle) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + c, + last_chunk, + size_n, + size_k, + BLOCK_M_SIZE_MAX, + groups, + bit); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part( + a + last_chunk * size_k, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + c + last_chunk * size_n, + last_chunk_size, + size_n, + size_k, + last_chunk_size, + groups, + bit); + } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, size_m, size_n, size_k, bit); + } +} + +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } +} + +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } +} + +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } +} + +__global__ void make_sequential_4bit_kernel( + const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +__global__ void make_sequential_2bit_kernel( + const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +__global__ void make_sequential_3bit_kernel( + const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, const int w_width) { + auto w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + auto w_new_row = blockIdx.y * 3; + auto q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); + } + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; +} + +__global__ void make_sequential_8bit_kernel( + const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, int width, int bit) { + if (q_perm) { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 32 * bit; + + auto kernel = make_sequential_4bit_kernel; + if (bit == 2) { + kernel = make_sequential_2bit_kernel; + } else if (bit == 3) { + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; + } else if (bit == 8) { + kernel = make_sequential_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(q_weight, new_qweight, q_perm, width); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); +} + +} // namespace gptq +} // namespace sglang + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_shuffle, + int64_t bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); + + sglang::gptq::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), + (const half*)a.data_ptr(), + (const uint32_t*)b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*)b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + (half*)c.data_ptr(), + (half*)temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_shuffle, + bit); + return c; +} + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + sglang::gptq::shuffle_exllama_weight( + (uint32_t*)q_weight.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*)q_perm.data_ptr(), + q_weight.size(0) * 32 / bit, + q_weight.size(1), + bit); +} diff --git a/sgl-kernel/csrc/gemm/gptq/matrix_view.cuh b/sgl-kernel/csrc/gemm/gptq/matrix_view.cuh new file mode 100644 index 000000000..3dfc8794c --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/matrix_view.cuh @@ -0,0 +1,269 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +namespace sglang { +namespace gptq { + +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { + return &data[row / 8 * width + column]; + } +}; + +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } +}; + +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } +}; + +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } +}; + +} // namespace gptq +} // namespace sglang +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/qdq_2.cuh b/sgl-kernel/csrc/gemm/gptq/qdq_2.cuh new file mode 100644 index 000000000..4a75d7b56 --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/qdq_2.cuh @@ -0,0 +1,74 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +namespace sglang { +namespace gptq { + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +} // namespace gptq +} // namespace sglang + +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/qdq_3.cuh b/sgl-kernel/csrc/gemm/gptq/qdq_3.cuh new file mode 100644 index 000000000..5996f342d --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/qdq_3.cuh @@ -0,0 +1,146 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +namespace sglang { +namespace gptq { +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32( + const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace gptq +} // namespace sglang + +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/qdq_4.cuh b/sgl-kernel/csrc/gemm/gptq/qdq_4.cuh new file mode 100644 index 000000000..c96af4718 --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/qdq_4.cuh @@ -0,0 +1,114 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace sglang { +namespace gptq { +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void +dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + +__forceinline__ __device__ void +dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} +} // namespace gptq +} // namespace sglang + +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/qdq_8.cuh b/sgl-kernel/csrc/gemm/gptq/qdq_8.cuh new file mode 100644 index 000000000..c6a49d6dc --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/qdq_8.cuh @@ -0,0 +1,30 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +namespace sglang { +namespace gptq { + +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void +dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) + dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) + dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +} // namespace gptq +} // namespace sglang + +#endif diff --git a/sgl-kernel/csrc/gemm/gptq/qdq_util.cuh b/sgl-kernel/csrc/gemm/gptq/qdq_util.cuh new file mode 100644 index 000000000..0977269c3 --- /dev/null +++ b/sgl-kernel/csrc/gemm/gptq/qdq_util.cuh @@ -0,0 +1,53 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +namespace sglang { +namespace gptq { + +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) { + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +} // namespace gptq +} // namespace sglang +#endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu b/sgl-kernel/csrc/gemm/marlin/awq_marlin_repack.cu similarity index 82% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu rename to sgl-kernel/csrc/gemm/marlin/awq_marlin_repack.cu index 63d9dd904..1c8de396d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu +++ b/sgl-kernel/csrc/gemm/marlin/awq_marlin_repack.cu @@ -1,15 +1,12 @@ -#ifndef MARLIN_NAMESPACE_NAME -#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 -#endif - -#include "core/registration.h" -#include "gptq_marlin/marlin.cuh" -#include "kernel.h" - -namespace MARLIN_NAMESPACE_NAME { +#include "marlin.cuh" +namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async in awq_marlin_repack_kernel +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + return; +} #else template @@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel( } } } +#endif +} // namespace marlin -#define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - awq_marlin_repack_kernel \ - <<>>(b_q_weight_ptr, out_ptr, size_k, size_n); \ +#define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>(b_q_weight_ptr, out_ptr, size_k, size_n); \ } torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); - TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + TORCH_CHECK( + size_k % marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + marlin::tile_k_size); + TORCH_CHECK( + size_n % marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); int const pack_factor = 32 / num_bits; @@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64 // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); - torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); + torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options); // Get ptrs uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); @@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64 return out; } - -torch::Tensor -awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { - int const pack_factor = 32 / num_bits; - auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); - return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options); -} - -#endif - -} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/gemm/marlin/dequant.h b/sgl-kernel/csrc/gemm/marlin/dequant.h new file mode 100644 index 000000000..e91b8b9f4 --- /dev/null +++ b/sgl-kernel/csrc/gemm/marlin/dequant.h @@ -0,0 +1,459 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu b/sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu new file mode 100644 index 000000000..b3437fc4b --- /dev/null +++ b/sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu @@ -0,0 +1,1120 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "kernel.h" +#include "marlin_template.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +} // namespace marlin + +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, + std::optional c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half* out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const& th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const sglang::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(sglang::kU4) + COMMON_GET_IF(sglang::kU4B8) + COMMON_GET_IF(sglang::kU8B128) + + FP4_GET_IF(sglang::kFE2M1f) + + BIGGROUP_GET_IF(sglang::kFE4M3fn) + + ACT_GET_IF(sglang::kU4B8) + ACT_GET_IF(sglang::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(sglang::kU4) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const sglang::ScalarType& q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} + +template +void marlin_mm( + const void* A, + const void* B, + void* C, + void* C_tmp, + void* s, + void* s2, + void* zp, + void* g_idx, + void* perm, + void* a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void* workspace, + sglang::ScalarType const& q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + TORCH_CHECK( + q_type == sglang::kU4 || q_type == sglang::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", + q_type.str()); + } else { + TORCH_CHECK( + q_type == sglang::kU4B8 || q_type == sglang::kU8B128 || q_type == sglang::kFE4M3fn || q_type == sglang::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + // avoid ">>>" being formatted to "> > >" + // clang-format off + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + // clang-format on + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + TORCH_CHECK( + false, + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + use_fp32_reduce, max_shared_mem_new); + // clang-format on + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } +} + +} // namespace marlin + +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, + std::optional c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + sglang::ScalarType const b_q_type = sglang::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k); + + // Verify B + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + (size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", + b_q_weight.size(0), + ", size_k = ", + size_k, + ", tile_size = ", + MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(1) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = (b_q_weight.size(1) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.stride(1) == 1, "A.stride(1) is not 1"); + // We use int4 (16 bytes) to load A, so A must aligned to 16 bytes + TORCH_CHECK(a.stride(0) % 8 == 0, "A.stride(0) must divisible by 8"); + TORCH_CHECK(((uint64_t)a.data_ptr()) % 16 == 0, "A must aligned to 16 bytes"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m, "Shape mismatch: c.size(0) = ", c.size(0), ", size_m = ", size_m); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), ", size_n = ", size_n); + } else { + c = torch::empty({size_m, size_n}, options); + } + if (size_m == 0) return c; + + // Alloc C tmp buffer that is going to be used for the global reduce + torch::Tensor c_tmp; + auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (use_fp32_reduce) { + int max_m_block_size = (size_m + 16 - 1) / 16 * 16; + max_m_block_size = min(max_m_block_size, 64); + int max_c_tmp_size = sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); + } else { + c_tmp = torch::empty({0}, options_fp32); + } + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); + num_groups = b_scales.size(0); + + torch::Tensor g_idx, perm, a_tmp; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK( + (g_idx.size(-1) == 0 && perm.size(-1) == 0) || (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", + g_idx.size(-1), + " and perm.size(-1) = ", + perm.size(-1), + ", where size_k = ", + size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + + if (has_act_order) { + a_tmp = torch::empty({size_m, size_k}, options); + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = torch::empty({0}, options); + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == sglang::kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == sglang::kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + if (has_zp) { + TORCH_CHECK( + b_q_type == sglang::kU4 || b_q_type == sglang::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", + b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == sglang::kU4B8 || b_q_type == sglang::kU8B128 || b_q_type == sglang::kFE4M3fn || + b_q_type == sglang::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK( + a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + TORCH_CHECK( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + TORCH_CHECK( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify workspace size + TORCH_CHECK( + size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int min_workspace_size = sms; + TORCH_CHECK( + workspace.numel() >= min_workspace_size, + "workspace.numel = ", + workspace.numel(), + " is below min_workspace_size = ", + min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == sglang::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + + marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + scales_ptr, + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + size_m, + size_n, + size_k, + a.stride(0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == sglang::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + + marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + scales_ptr, + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + size_m, + size_n, + size_k, + a.stride(0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu b/sgl-kernel/csrc/gemm/marlin/gptq_marlin_repack.cu similarity index 81% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu rename to sgl-kernel/csrc/gemm/marlin/gptq_marlin_repack.cu index 209d13e06..3bdf84161 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu +++ b/sgl-kernel/csrc/gemm/marlin/gptq_marlin_repack.cu @@ -1,15 +1,17 @@ -#ifndef MARLIN_NAMESPACE_NAME -#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 -#endif - -#include "gptq_marlin/marlin.cuh" - -namespace MARLIN_NAMESPACE_NAME { +#include "marlin.cuh" +namespace marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async in gptq_marlin_repack_kernel +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + return; +} #else - template __global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, @@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel( int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); - int start_k_tile = blockIdx.x * block_k_tiles; + auto start_k_tile = blockIdx.x * block_k_tiles; if (start_k_tile >= k_tiles) { return; } @@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel( if constexpr (has_perm) { if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); @@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel( } else { if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; int first_k_packed = first_k / pack_factor; @@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel( return; } - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; if (warp_id >= 4) { return; @@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel( } } } +#endif +} // namespace marlin -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - gptq_marlin_repack_kernel \ - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + marlin::gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); - TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + TORCH_CHECK( + size_k % marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + marlin::tile_k_size); + TORCH_CHECK( + size_n % marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); int const pack_factor = 32 / num_bits; @@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_ // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); - torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); + torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; @@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_ return out; } - -torch::Tensor gptq_marlin_repack_meta( - torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { - int const pack_factor = 32 / num_bits; - auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); - return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options); -} - -#endif - -// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { -// m.impl("gptq_marlin_repack", &gptq_marlin_repack); -// } - -// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) { -// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta); -// } - -} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/gemm/marlin/kernel.h b/sgl-kernel/csrc/gemm/marlin/kernel.h new file mode 100644 index 000000000..72137b244 --- /dev/null +++ b/sgl-kernel/csrc/gemm/marlin/kernel.h @@ -0,0 +1,36 @@ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace MARLIN_NAMESPACE_NAME { +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh b/sgl-kernel/csrc/gemm/marlin/marlin.cuh similarity index 98% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh rename to sgl-kernel/csrc/gemm/marlin/marlin.cuh index 29db8c934..99903968d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh +++ b/sgl-kernel/csrc/gemm/marlin/marlin.cuh @@ -10,11 +10,10 @@ #include #ifndef MARLIN_NAMESPACE_NAME -#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#define MARLIN_NAMESPACE_NAME marlin #endif namespace MARLIN_NAMESPACE_NAME { - // Marlin params // 8 warps are a good choice since every SM has 4 schedulers and having more @@ -91,6 +90,7 @@ template __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } + #endif } // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh b/sgl-kernel/csrc/gemm/marlin/marlin_dtypes.cuh similarity index 97% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh rename to sgl-kernel/csrc/gemm/marlin/marlin_dtypes.cuh index 618c86f95..f1b821f10 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh +++ b/sgl-kernel/csrc/gemm/marlin/marlin_dtypes.cuh @@ -1,4 +1,3 @@ - #ifndef _data_types_cuh #define _data_types_cuh #include @@ -7,7 +6,7 @@ #include "marlin.cuh" #ifndef MARLIN_NAMESPACE_NAME -#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#define MARLIN_NAMESPACE_NAME marlin #endif namespace MARLIN_NAMESPACE_NAME { diff --git a/sgl-kernel/csrc/gemm/marlin/marlin_template.h b/sgl-kernel/csrc/gemm/marlin/marlin_template.h new file mode 100644 index 000000000..01eb33878 --- /dev/null +++ b/sgl-kernel/csrc/gemm/marlin/marlin_template.h @@ -0,0 +1,1629 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "dequant.h" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void +mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void +scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub(typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void +sub_zp(typename ScalarType::FragB& frag_b, typename ScalarType::scalar_t2& frag_zp, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4( + typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = sglang::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == sglang::kU4 || w_type == sglang::kU8; + constexpr bool is_int_type = + w_type == sglang::kU4 || w_type == sglang::kU8 || w_type == sglang::kU4B8 || w_type == sglang::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == sglang::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == sglang::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh; + int4* sh_red = sh; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != sglang::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == sglang::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == sglang::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == sglang::kFE2M1f) { + res = __hmul2(res, global_scale); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h deleted file mode 100644 index 9caf0795a..000000000 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#define SGLANG_IMPLIES(p, q) (!(p) || (q)) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) - -#define _STRINGIFY(A) #A -#define STRINGIFY(A) _STRINGIFY(A) - -// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME -// could be a macro instead of a literal token. -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) - -// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME -// could be a macro instead of a literal token. -#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) - -// REGISTER_EXTENSION allows the shared library to be loaded and initialized -// via python's import statement. -#define REGISTER_EXTENSION(NAME) \ - PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ - static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ - return PyModule_Create(&module); \ - } diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h index 9b36b477d..88d157507 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h @@ -3,8 +3,8 @@ #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif -#include "gptq_marlin/marlin.cuh" -#include "gptq_marlin/marlin_dtypes.cuh" +#include "gemm/marlin/marlin.cuh" +#include "gemm/marlin/marlin_dtypes.cuh" #include "scalar_type.hpp" #define MARLIN_KERNEL_PARAMS \ diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index 14df0aee5..71c91839d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -18,13 +18,12 @@ /* * Adapted from https://github.com/IST-DASLab/marlin */ - #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif -#include "gptq_marlin/marlin.cuh" -#include "gptq_marlin/marlin_dtypes.cuh" +#include "gemm/marlin/marlin.cuh" +#include "gemm/marlin/marlin_dtypes.cuh" #include "scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index 9d5f7c26b..f430390d1 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -23,7 +23,6 @@ #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif -#include "core/registration.h" #include "kernel.h" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -50,8 +49,7 @@ __global__ void permute_cols_kernel( int size_m, int size_k, int top_k) {}; - -} // namespace marlin +} torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, diff --git a/sgl-kernel/include/scalar_type.hpp b/sgl-kernel/include/scalar_type.hpp index 8a837a2b5..9161a38b7 100644 --- a/sgl-kernel/include/scalar_type.hpp +++ b/sgl-kernel/include/scalar_type.hpp @@ -298,6 +298,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); @@ -313,6 +314,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 3e4d4662c..b9a694ac8 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -224,6 +224,40 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const t void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b); +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, + std::optional c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float); + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_shuffle, + int64_t bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); + +torch::Tensor +gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); /* * From csrc/moe */ @@ -340,15 +374,6 @@ void scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); -namespace marlin_moe_wna16 { - -torch::Tensor -gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); - -} // namespace marlin_moe_wna16 - /* * From csrc/speculative */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 2abf0bcf3..ce0023b8a 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -44,6 +44,9 @@ from sgl_kernel.gemm import ( dsv3_router_gemm, fp8_blockwise_scaled_mm, fp8_scaled_mm, + gptq_gemm, + gptq_marlin_gemm, + gptq_shuffle, int8_scaled_mm, qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm, diff --git a/sgl-kernel/python/sgl_kernel/fused_moe.py b/sgl-kernel/python/sgl_kernel/fused_moe.py index f24a0fa42..49b59102a 100644 --- a/sgl-kernel/python/sgl_kernel/fused_moe.py +++ b/sgl-kernel/python/sgl_kernel/fused_moe.py @@ -2,6 +2,7 @@ import functools from typing import Optional import torch +from sgl_kernel import silu_and_mul def get_scalar_type(num_bits: int, has_zp: bool): @@ -165,7 +166,7 @@ def fused_marlin_moe( is_zp_float=False, ) - torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) if expert_map is not None: intermediate_cache3.zero_() diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 7435cfdda..642bd7015 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -1,6 +1,7 @@ -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch +from sgl_kernel.scalar_type import ScalarType from sgl_kernel.utils import _get_cache_buf, get_cuda_stream @@ -353,3 +354,62 @@ def scaled_fp4_experts_quant( ) output_scales = output_scales.view(torch.float8_e4m3fn) return output, output_scales + + +# GPTQ kernels +def gptq_marlin_gemm( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return torch.ops.sgl_kernel.gptq_marlin_gemm( + a, + c, + b_q_weight, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) + + +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_shuffle: bool, + bit: int, +) -> torch.Tensor: + return torch.ops.sgl_kernel.gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit + ) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: + torch.torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit) diff --git a/sgl-kernel/python/sgl_kernel/marlin.py b/sgl-kernel/python/sgl_kernel/marlin.py index 407834227..823d3fd4b 100644 --- a/sgl-kernel/python/sgl_kernel/marlin.py +++ b/sgl-kernel/python/sgl_kernel/marlin.py @@ -7,8 +7,8 @@ def gptq_marlin_repack( size_k, size_n, num_bits, -): - torch.ops.sgl_kernel.gptq_marlin_repack.default( +) -> torch.Tensor: + return torch.ops.sgl_kernel.gptq_marlin_repack( b_q_weight, perm, size_k, diff --git a/sgl-kernel/tests/test_gptq_kernel.py b/sgl-kernel/tests/test_gptq_kernel.py new file mode 100644 index 000000000..0ef4fad2c --- /dev/null +++ b/sgl-kernel/tests/test_gptq_kernel.py @@ -0,0 +1,131 @@ +import pytest +import torch +from sgl_kernel import gptq_gemm + +from sglang.srt.layers.quantization.utils import pack_cols, pack_rows + + +def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N): + assert bit == 4, "Reference dequantization only supports 4-bit" + group_size = K // scales.shape[0] + pack_factor = 32 // bit + + # unpack q_weight: (K//pack_factor, N) -> (K, N) + unpacked_q_weight = torch.empty( + q_weight.shape[0] * pack_factor, + q_weight.shape[1], + dtype=torch.uint8, + device=q_weight.device, + ) + for i in range(pack_factor): + unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F + + # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N) + unpacked_q_zeros = torch.empty( + q_zeros.shape[0], + q_zeros.shape[1] * pack_factor, + dtype=torch.uint8, + device=q_zeros.device, + ) + for i in range(pack_factor): + unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F + + unpacked_q_zeros += 1 + unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype) + + scale_zeros = unpacked_q_zeros * scales # (num_groups, N) + + current_g_idx = torch.tensor( + [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device + ) + + scale_mat = scales[current_g_idx] # (K, N) + scale_zeros_mat = scale_zeros[current_g_idx] # (K, N) + + # dequant: weight * scale - scale_zeros + dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat + + return dequantized_b.reshape(K, N) + + +def torch_gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit +): + K, N = a.shape[1], b_q_weight.shape[1] + + b_dequant = torch_dequantize( + b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N + ) + c = torch.matmul(a, b_dequant) + return c + + +def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"): + + b_fp = torch.randn(K, N, dtype=dtype, device=device) + + assert K % group_size == 0, "K must be divisible by group_size" + num_groups = K // group_size + + if use_shuffle: + return + else: + g_idx = torch.tensor( + [i // group_size for i in range(K)], dtype=torch.int32, device=device + ) + b_shuffled = b_fp[g_idx] + + b_grouped = b_shuffled.reshape(num_groups, group_size, N) + + b_max = torch.max(b_grouped, dim=1, keepdim=True)[0] + b_min = torch.min(b_grouped, dim=1, keepdim=True)[0] + + scales = (b_max - b_min) / (2**bit - 1) + scales = scales.clamp(min=1e-6) + + zeros_float = (-b_min / scales).round() + + q_b = ( + (b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8) + ) + + q_zeros_unpacked = zeros_float.to(torch.uint8) - 1 + + b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N) + + q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N) + b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N) + b_gptq_scales = scales.squeeze(1) + + a = torch.randn(M, K, dtype=dtype, device=device) + + c_ref = torch_gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit + ) + c_out = gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit + ) + + rtol = 4e-2 + atol = 4e-2 + torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol) + print( + f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}" + ) + + +@pytest.mark.parametrize("M", [1, 8, 128]) +@pytest.mark.parametrize("N", [2048, 4096]) +@pytest.mark.parametrize("K", [2048, 4096]) +@pytest.mark.parametrize("bit", [4]) +@pytest.mark.parametrize("group_size", [128]) +@pytest.mark.parametrize("use_shuffle", [False]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/sgl-kernel/tests/test_marlin_gemm.py b/sgl-kernel/tests/test_marlin_gemm.py new file mode 100644 index 000000000..9c4d48a39 --- /dev/null +++ b/sgl-kernel/tests/test_marlin_gemm.py @@ -0,0 +1,121 @@ +import pytest +import torch +from sgl_kernel import gptq_marlin_gemm +from sgl_kernel.scalar_type import scalar_types + +from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace +from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), + (257, 13, 11), + (658, 13, 11), +] + + +# uint4 for awq +# uint4b8 for gptq +@pytest.mark.parametrize("k_chunk", [128]) +@pytest.mark.parametrize("n_chunk", [64, 256]) +@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("act_order", [False, True]) +@pytest.mark.parametrize("is_k_full", [False, True]) +@pytest.mark.parametrize("use_atomic_add", [False, True]) +@pytest.mark.parametrize("use_fp32_reduce", [False, True]) +def test_gptq_marlin_gemm( + k_chunk, + n_chunk, + quant_type, + group_size, + mnk_factors, + act_order, + is_k_full, + use_atomic_add, + use_fp32_reduce, +): + m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + if has_zp: + return + + if size_k % group_size != 0: + return + + a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda") + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + + if has_zp: + # AWQ style, unsigned + runtime zero-point + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, quant_type, group_size + ) + g_idx = None + sort_indices = None + marlin_s2 = None + else: + # GPTQ style, unsigned + symmetric bias + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, act_order + ) + marlin_zp = None + marlin_s2 = None + + workspace = marlin_make_workspace(w_ref.device) + + # marlin gemm + output = gptq_marlin_gemm( + a_input, + None, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + # ref gemm + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref) + ) + + assert max_diff < 0.04 + + +if __name__ == "__main__": + import subprocess + + subprocess.call(["pytest", "--tb=short", str(__file__)]) diff --git a/sgl-kernel/tests/test_marlin_repack.py b/sgl-kernel/tests/test_marlin_repack.py index 698595e08..508099106 100644 --- a/sgl-kernel/tests/test_marlin_repack.py +++ b/sgl-kernel/tests/test_marlin_repack.py @@ -1,16 +1,32 @@ import numpy as np import pytest import torch -from sgl_kernel import awq_marlin_repack +from sgl_kernel import awq_marlin_repack, gptq_marlin_repack from sgl_kernel.scalar_type import scalar_types from sglang.srt.layers.quantization.utils import ( - get_pack_factor, + gptq_quantize_weights, pack_cols, + pack_rows, quantize_weights, + sort_weights, ) +from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights GPTQ_MARLIN_TILE = 16 +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 256] + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), + (257, 13, 11), + (658, 13, 11), +] def awq_pack( @@ -35,70 +51,6 @@ def awq_pack( return pack_cols(q_w, num_bits, size_k, size_n) -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)]) @pytest.mark.parametrize("group_size", [16, 32]) @@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size): torch.testing.assert_close(out_gpu, q_w_marlin) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [False, True]) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_gptq_marlin_repack( + k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors +): + m_factor, n_factor, k_factor = mnk_factors + + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if size_k % group_size != 0: + pytest.skip("size_k must be divisible by group_size") + + # Create input + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + b_weight, quant_type, group_size, act_order + ) + + q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + marlin_layout_perm = get_weight_perm(quant_type.size_bits) + q_w_marlin_ref = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm + ) + + # Run Marlin repack GPU kernel + q_w_marlin = gptq_marlin_repack( + q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits + ) + + torch.cuda.synchronize() + + torch.testing.assert_close(q_w_marlin, q_w_marlin_ref) + + if __name__ == "__main__": import subprocess