[2/n] decouple quantization implementation from vLLM dependency (#8112)
Co-authored-by: walker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: leoneo <1320612015@qq.com>
@@ -321,6 +321,30 @@ def pack_cols(
    return q_res


def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
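For reference, pack_rows can be sanity-checked with a round trip. The sketch below assumes get_pack_factor(num_bits) == 32 // num_bits and uses a hypothetical unpack_rows helper (this diff only ships unpack_cols); it is an illustration, not part of the change:

    import numpy
    import torch

    def unpack_rows(q_res, num_bits, size_k, size_n):
        # Inverse of pack_rows: pull each num_bits-wide field back out of the packed int32 rows.
        pack_factor = 32 // num_bits
        packed = q_res.cpu().numpy().astype(numpy.uint32)
        q_w = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
        mask = (1 << num_bits) - 1
        for i in range(pack_factor):
            q_w[i::pack_factor, :] = (packed >> (num_bits * i)) & mask
        return torch.from_numpy(q_w.astype(numpy.int32))

    q_w = torch.randint(0, 16, (128, 64), dtype=torch.int32)    # 4-bit values
    packed = pack_rows(q_w, num_bits=4, size_k=128, size_n=64)  # shape (128 // 8, 64)
    assert torch.equal(unpack_rows(packed, 4, 128, 64), q_w)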
@@ -254,13 +254,15 @@ set(SOURCES
    "csrc/gemm/per_token_quant_fp8.cu"
    "csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
    "csrc/gemm/qserve_w4a8_per_group_gemm.cu"
    "csrc/gemm/marlin/gptq_marlin.cu"
    "csrc/gemm/marlin/gptq_marlin_repack.cu"
    "csrc/gemm/marlin/awq_marlin_repack.cu"
    "csrc/gemm/gptq/gptq_kernel.cu"
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
    "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
    "csrc/moe/marlin_moe_wna16/ops.cu"
    "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
    "csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
@@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
  m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);

  // GPTQ related method
  m.def(
      "gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
      "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
      "Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
      "Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
      "bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

  m.def(
      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
      "use_shuffle, int bit) -> Tensor");
  m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

  m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

  m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
  m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

  m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
  m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);

  /*
   * From csrc/moe
   */
@@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
  m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);

  m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
  m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);

  m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
  m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);

  /*
   * From csrc/speculative
   */
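Once the extension is loaded, the schemas registered above are callable from Python through torch.ops.sgl_kernel. A hedged sketch (the import name, tensor shapes, and the empty-permutation convention are assumptions for illustration, not taken from this diff):

    import torch
    import sgl_kernel  # importing the extension registers the sgl_kernel ops (assumed import name)

    bit = 4
    k, n = 4096, 11008
    # Packed GPTQ weight: 32 // bit quantized values per int32 along K (assumed layout).
    q_weight = torch.randint(-2**31, 2**31 - 1, (k // (32 // bit), n), dtype=torch.int32, device="cuda")
    q_perm = torch.empty(0, dtype=torch.int32, device="cuda")  # no act-order reordering (assumed convention)

    # In-place shuffle of the packed weight into the layout the exllama-style GPTQ kernel expects.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)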
sgl-kernel/csrc/gemm/gptq/compat.cuh (new file, 62 lines)
@@ -0,0 +1,62 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

namespace sglang {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
  unsigned int* address_as_ui = (unsigned int*)((char*)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    __half_raw hsum;
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    half tmpres = __hadd(hsum, val);
    hsum = __half_raw(tmpres);
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
  unsigned int* address_as_ui = (unsigned int*)address;
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    half2 old_val = *((half2*)&old);
    half2 new_val = __hadd2(old_val, val);
    old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
  } while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) {
  atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
  atomicAdd_half2(address, val);
}
#endif

#endif
#endif

} // namespace gptq
} // namespace sglang
#endif
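Both helpers above emulate a 16-bit atomic add by doing a compare-and-swap on the aligned 32-bit word that contains the half, retrying if another thread changed the word in the meantime. A host-side NumPy sketch of the same word surgery (the CAS retry loop itself is omitted; illustration only):

    import numpy as np

    def fp16_bits(x):
        return int(np.array([x], dtype=np.float16).view(np.uint16)[0])

    def bits_fp16(b):
        return np.array([b], dtype=np.uint16).view(np.float16)[0]

    def add_half_in_word(word, addr_offset, val):
        # The device code rounds the address down to its containing 32-bit word and uses
        # (address & 2) to pick the low or high 16-bit lane; this mimics that selection.
        hi = addr_offset & 2
        lane = (word >> 16) & 0xFFFF if hi else word & 0xFFFF
        new_bits = fp16_bits(bits_fp16(lane) + np.float16(val))
        return (word & 0x0000FFFF) | (new_bits << 16) if hi else (word & 0xFFFF0000) | new_bits

    word = 0x00000000                      # one aligned word = two packed halves, both 0.0
    word = add_half_in_word(word, 2, 1.5)  # add 1.5 to the half stored in the upper lane
    print(np.array([word], dtype=np.uint32).view(np.float16))  # -> [0.0, 1.5]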
sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu (new file, 1950 lines)
File diff suppressed because it is too large.
sgl-kernel/csrc/gemm/gptq/matrix_view.cuh (new file, 269 lines)
@@ -0,0 +1,269 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and
|
||||
https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
class MatrixView_half {
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
|
||||
return __half2half2(data[row * width + column]);
|
||||
}
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw {
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ void set(int row, int column, half value) {
|
||||
data[row * width + column] = value;
|
||||
}
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
|
||||
((half2*)data)[(row * width + column) / 2] = value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) {
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
|
||||
return data[row / 8 * width + column];
|
||||
}
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) {
|
||||
return &data[row / 8 * width + column];
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q2_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
items[2] = (d >> 4) & 0x03;
|
||||
items[3] = (d >> 6) & 0x03;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q3_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int z_w = column * 3 / 32;
|
||||
int z_mod = column & 0x1f;
|
||||
|
||||
if (z_mod == 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||
} else if (z_mod == 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||
} else if (z_mod < 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||
} else if (z_mod < 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
||||
} else {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x1f);
|
||||
uint32_t d;
|
||||
if (shift <= 4) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||
} else if (shift == 8) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||
} else if (shift <= 16) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||
} else if (shift == 20) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||
} else {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||
}
|
||||
items[0] = d & 0x07;
|
||||
items[1] = (d >> 3) & 0x07;
|
||||
items[2] = (d >> 6) & 0x07;
|
||||
items[3] = (d >> 9) & 0x07;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q8_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x03) * 2;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
items[2] = (d >> 16) & 0xff;
|
||||
items[3] = (d >> 24) & 0xff;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
sgl-kernel/csrc/gemm/gptq/qdq_2.cuh (new file, 74 lines)
@@ -0,0 +1,74 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"

namespace sglang {
namespace gptq {

// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 8; i++) {
    uint32_t qa0 = qa & 0x03;
    uint32_t qa1 = (qa & 0x0c) >> 2;
    qa >>= 4;
    qb |= (qa1 << (i * 2 + 16));
    qb |= (qa0 << (i * 2));
  }
  q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y4_ = __float2half_rn(1.0f / 4.0f);
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
  const half2 y4 = __halves2half2(y4_, y4_);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half2 y64 = __halves2half2(y64_, y64_);

  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
  const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
  const half2 z1 = __half2half2(z1_.as_half);
  const half2 z4 = __half2half2(z4_);
  const half2 z16 = __half2half2(z16_);
  const half2 z64 = __half2half2(z64_);

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x00030003) | c0);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 q1((qa & 0x000c000c) | c0);  // half2(q[ 2], q[ 3]) *  4 + 1024
  half2_uint32 q2((qa & 0x00300030) | c0);  // half2(q[ 4], q[ 5]) * 16 + 1024
  half2_uint32 q3((qa & 0x00c000c0) | c0);  // half2(q[ 6], q[ 7]) * 64 + 1024
  qa >>= 8;
  half2_uint32 q4((qa & 0x00030003) | c0);  // half2(q[ 8], q[ 9])      + 1024
  half2_uint32 q5((qa & 0x000c000c) | c0);  // half2(q[10], q[11]) *  4 + 1024
  half2_uint32 q6((qa & 0x00300030) | c0);  // half2(q[12], q[13]) * 16 + 1024
  half2_uint32 q7((qa & 0x00c000c0) | c0);  // half2(q[14], q[15]) * 64 + 1024

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y4, z4);
  dq[2] = __hfma2(q2.as_half2, y16, z16);
  dq[3] = __hfma2(q3.as_half2, y64, z64);
  dq[4] = __hadd2(q4.as_half2, z1);
  dq[5] = __hfma2(q5.as_half2, y4, z4);
  dq[6] = __hfma2(q6.as_half2, y16, z16);
  dq[7] = __hfma2(q7.as_half2, y64, z64);
}

} // namespace gptq
} // namespace sglang

#endif
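The permutation comment above can be verified by running the shuffle on one 2-bit field at a time and recording where it lands. A small host-side Python re-implementation of shuffle_2bit_16, for illustration only:

    def shuffle_2bit_16(x):
        # Same bit manipulation as the device function, on a plain Python int.
        qa, qb = x, 0
        for i in range(8):
            qa0 = qa & 0x03
            qa1 = (qa & 0x0c) >> 2
            qa >>= 4
            qb |= qa1 << (i * 2 + 16)
            qb |= qa0 << (i * 2)
        return qb & 0xFFFFFFFF

    perm = []
    for src in range(16):
        out = shuffle_2bit_16(0x3 << (2 * src))
        perm.append(next(j for j in range(16) if (out >> (2 * j)) & 0x3))
    print(perm)  # [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
    # Even source fields land in the low half-word and odd ones in the high half-word,
    # which is the pairwise layout dequant_2bit_16 consumes.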
sgl-kernel/csrc/gemm/gptq/qdq_3.cuh (new file, 146 lines)
@@ -0,0 +1,146 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qa & 0x07;
|
||||
uint32_t t1 = (qa & 0x38) >> 3;
|
||||
qa >>= 6;
|
||||
za |= (t0 << (i * 3));
|
||||
za |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qb & 0x07;
|
||||
uint32_t t1 = (qb & 0x38) >> 3;
|
||||
qb >>= 6;
|
||||
zb |= (t0 << (i * 3));
|
||||
zb |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qc & 0x07;
|
||||
uint32_t t1 = (qc & 0x38) >> 3;
|
||||
qc >>= 6;
|
||||
zc |= (t0 << (i * 3));
|
||||
zc |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32(
|
||||
const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
|
||||
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y8, z8);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y8, z8);
|
||||
dq[4] = __hfma2(q4.as_half2, y64, z64);
|
||||
dq[5] = __hadd2(q5.as_half2, z1);
|
||||
dq[6] = __hfma2(q6.as_half2, y8, z8);
|
||||
dq[7] = __hadd2(q7.as_half2, z1);
|
||||
dq[8] = __hfma2(q8.as_half2, y8, z8);
|
||||
dq[9] = __hfma2(q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
sgl-kernel/csrc/gemm/gptq/qdq_4.cuh (new file, 114 lines)
@@ -0,0 +1,114 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
const half2 z1 = __half2half2(z1_.as_half);
|
||||
const half2 z16 = __half2half2(z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void
|
||||
dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void
|
||||
dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled) {
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0],
|
||||
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
} else {
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
sgl-kernel/csrc/gemm/gptq/qdq_8.cuh (new file, 30 lines)
@@ -0,0 +1,30 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"

namespace sglang {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

__forceinline__ __device__ void
dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) {
  half dqh[8];
  for (int i = 0; i < 4; i++)
    dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
  for (int i = 0; i < 4; i++)
    dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

  for (int i = 0; i < 4; i++)
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

} // namespace gptq
} // namespace sglang

#endif
sgl-kernel/csrc/gemm/gptq/qdq_util.cuh (new file, 53 lines)
@@ -0,0 +1,53 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

namespace sglang {
namespace gptq {

union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) {
  return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

} // namespace gptq
} // namespace sglang
#endif
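In Python terms, dq/exb amount to the following: extract a bit field from a packed word, subtract the zero point, and multiply by the scale in fp16. A NumPy sketch with a made-up packed word and scale (illustration only):

    import numpy as np

    def exb(q, shift, mask):
        # Extract a bit field from a packed 32-bit word.
        return (q >> shift) & mask

    def dq(q, qzero, scale):
        # Dequantize: (q - zero_point) * scale, done in fp16 on the device.
        return np.float16(q - qzero) * np.float16(scale)

    packed = 0x40302010          # four 8-bit fields: 0x10, 0x20, 0x30, 0x40
    zero, scale = 0x20, 0.05
    vals = [dq(exb(packed, 8 * i, 0xFF), zero, scale) for i in range(4)]
    print(vals)                  # [-0.8, 0.0, 0.8, 1.6], up to fp16 rounding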
@@ -1,15 +1,12 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "kernel.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
#include "marlin.cuh"
|
||||
|
||||
namespace marlin {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in awq_marlin_repack_kernel
|
||||
template <int const num_threads, int const num_bits>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits>
|
||||
@@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \
|
||||
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
|
||||
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
|
||||
TORCH_CHECK(
|
||||
size_k % marlin::tile_k_size == 0,
|
||||
"size_k = ",
|
||||
size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
marlin::tile_k_size);
|
||||
TORCH_CHECK(
|
||||
size_n % marlin::tile_n_size == 0,
|
||||
"size_n = ",
|
||||
size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
@@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
@@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
sgl-kernel/csrc/gemm/marlin/dequant.h (new file, 459 lines)
@@ -0,0 +1,459 @@
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)

The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:

weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight

Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.

The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.

## INT4 => FP16/BF16 or INT8 => FP16

The floating-point computation is `__hsub2`.

If there are integer zero points:

flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)

so no additional modification is needed.

If there are float zero points:

flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)

where `fzp + bias` can be computed at weight loading. But this
may have accuracy issues, so we should not use it in most cases.

If there are no zero points:

scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, scale(bias))

where `scale(bias)` can be cached. But this may have accuracy issues,
so we should not use it in most cases.


## INT8 => BF16

INT8 => BF16 is a special case: it uses byte_perm instead of a flop,
and byte_perm cannot be fused with scaling.


## FP4/FP8 => FP16/BF16

scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)

where `scale_factor * multiplier` can be computed at weight loading.

*/
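The `bias` in the identities above is the usual fp16 magic-number trick: for the INT4 path the kernel ORs each 4-bit value into the mantissa of 1024 (0x6400), so the bit_op alone already yields the half value 1024 + q, and one __hsub2 against 1024 + zero_point finishes the conversion. A NumPy check of that identity (illustration only, not part of the header):

    import numpy as np

    for q in range(16):
        bits = np.array([0x6400 | q], dtype=np.uint16)  # OR the 4-bit value into the mantissa of 1024.0
        as_half = bits.view(np.float16)[0]
        assert as_half == 1024.0 + q                    # bit_op(weight) == 1024 + q, exactly
        zp = np.array([0x6408], dtype=np.uint16).view(np.float16)[0]  # 1024 + 8, the symmetric zero point
        assert as_half - zp == q - 8                    # the fused __hsub2 leaves the signed int4 value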
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, sglang::ScalarTypeId w_type_id, bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
|
||||
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
||||
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
||||
// with some small changes:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4B8.id(), true>(int q, half2* frag_b) {
|
||||
const int MASK = 0x000f000f;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4B8.id(), false>(int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// clang-format on
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(
|
||||
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4.id(), false>(int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// clang-format on
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(
|
||||
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
|
||||
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43084308;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43004300;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
//
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
||||
// bf16 Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8B128.id(), true>(int q, half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8B128.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388736.f;
|
||||
fp32_intermediates[1] -= 8388736.f;
|
||||
fp32_intermediates[2] -= 8388736.f;
|
||||
fp32_intermediates[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU8.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), true>(int q, half2* frag_b) {
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||
// position
|
||||
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to bfloat162 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), true>(int q, half2* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), false>(int q, half2* frag_b) {
  dequant<half2, sglang::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}

template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {
  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70007000;

  // Extract and shift FP4 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}

template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);

template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
  int Out1 = (q & 0xFF00FF00) >> 1;
  q <<= 8;
  int Out2 = (q & 0xFF00FF00) >> 1;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to BF16 format
  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME
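Aside: a minimal numpy sketch (illustrative only, not part of this commit) of the exponent-bias trick the kFE2M1f dequant specializations above rely on. Placing an E2M1 value's exponent and mantissa bits into the low-exponent/top-mantissa slots of an fp16 lane and scaling by 2**BIAS_OFFSET (2**14 for fp16) reproduces the FP4 value exactly, subnormals included:

import numpy as np

# The eight non-negative E2M1 (FP4) values, indexed by their 3-bit exponent/mantissa code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

for code in range(8):
    # The kernel's (q & MASK) >> RIGHT_SHIFT lands the 3 payload bits at fp16 bits 11..9.
    as_fp16 = np.array([code << 9], dtype=np.uint16).view(np.float16)[0]
    # BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)) = 16 - 2 = 14
    assert float(as_fp16) * 2.0**14 == E2M1_VALUES[code]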
sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu (new file, 1120 lines; diff suppressed because it is too large)
@@ -1,15 +1,17 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "gptq_marlin/marlin.cuh"

namespace MARLIN_NAMESPACE_NAME {
#include "marlin.cuh"

namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in gptq_marlin_repack_kernel
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr,
    uint32_t* __restrict__ out_ptr,
    int size_k,
    int size_n) {
  return;
}
#else

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
@@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel(
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  auto start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }
@@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel(

  if constexpr (has_perm) {
    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;
      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);

@@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel(

  } else {
    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;
      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;
      int first_k_packed = first_k / pack_factor;
@@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel(
    return;
  }

  int warp_id = threadIdx.x / 32;
  int th_id = threadIdx.x % 32;
  auto warp_id = threadIdx.x / 32;
  auto th_id = threadIdx.x % 32;

  if (warp_id >= 4) {
    return;
@@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel(
    }
  }
}
#endif
}  // namespace marlin

#define CALL_IF(NUM_BITS, HAS_PERM)                                                                                  \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                                                           \
    cudaFuncSetAttribute(                                                                                            \
        gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>,                                               \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                                 \
        max_shared_mem);                                                                                             \
    gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>                                                    \
        <<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);     \
#define CALL_IF(NUM_BITS, HAS_PERM)                                                                                  \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                                                           \
    cudaFuncSetAttribute(                                                                                            \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>,                               \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                                 \
        max_shared_mem);                                                                                             \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>                                    \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(                                                \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);                                                      \
  }

torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
  TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
  TORCH_CHECK(
      size_k % marlin::tile_k_size == 0,
      "size_k = ",
      size_k,
      " is not divisible by tile_k_size = ",
      marlin::tile_k_size);
  TORCH_CHECK(
      size_n % marlin::tile_n_size == 0,
      "size_n = ",
      size_n,
      " is not divisible by tile_n_size = ",
      marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;
@@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
  torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;
@@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_

  return out;
}

torch::Tensor gptq_marlin_repack_meta(
    torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}

#endif

// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
//   m.impl("gptq_marlin_repack", &gptq_marlin_repack);
// }

// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
//   m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
// }

}  // namespace MARLIN_NAMESPACE_NAME
sgl-kernel/csrc/gemm/marlin/kernel.h (new file, 36 lines)
@@ -0,0 +1,36 @@

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS                                                                                          \
  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp,             \
      const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr,  \
      const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks,         \
      bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <
    typename scalar_t,                     // compute dtype, half or nv_bfloat16
    const sglang::ScalarTypeId w_type_id,  // weight ScalarType id
    const int threads,                     // number of threads in a threadblock
    const int thread_m_blocks,             // number of 16x16 blocks in the m
                                           // dimension (batchsize) of the
                                           // threadblock
    const int thread_n_blocks,             // same for n dimension (output)
    const int thread_k_blocks,             // same for k dimension (reduction)
    const bool m_block_size_8,             // whether m_block_size == 8
                                           // only works when thread_m_blocks == 1
    const int stages,                      // number of stages for the async global->shared
                                           // fetch pipeline
    const int group_blocks,                // number of consecutive 16x16 blocks
                                           // with a separate quantization scale
    const bool is_zp_float                 // is zero point of float16 type?
    >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}
@@ -10,11 +10,10 @@
#include <iostream>

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
@@ -91,6 +90,7 @@ template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME
@@ -1,4 +1,3 @@

#ifndef _data_types_cuh
#define _data_types_cuh
#include <cuda_bf16.h>
@@ -7,7 +6,7 @@
#include "marlin.cuh"

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {
sgl-kernel/csrc/gemm/marlin/marlin_template.h (new file, 1629 lines; diff suppressed because it is too large)
@@ -1,25 +0,0 @@
#pragma once

#include <Python.h>
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                                                       \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                             \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr};  \
    return PyModule_Create(&module);                                                                   \
  }
@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS \
@@ -18,13 +18,12 @@
/*
 * Adapted from https://github.com/IST-DASLab/marlin
 */

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -23,7 +23,6 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "core/registration.h"
#include "kernel.h"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -50,8 +49,7 @@ __global__ void permute_cols_kernel(
    int size_m,
    int size_k,
    int top_k) {};

}  // namespace marlin
}

torch::Tensor moe_wna16_marlin_gemm(
    torch::Tensor& a,
@@ -298,6 +298,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
@@ -313,6 +314,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
@@ -224,6 +224,40 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const t

void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

torch::Tensor gptq_marlin_gemm(
    torch::Tensor& a,
    std::optional<torch::Tensor> c_or_none,
    torch::Tensor& b_q_weight,
    torch::Tensor& b_scales,
    std::optional<torch::Tensor> const& global_scale_or_none,
    std::optional<torch::Tensor> const& b_zeros_or_none,
    std::optional<torch::Tensor> const& g_idx_or_none,
    std::optional<torch::Tensor> const& perm_or_none,
    torch::Tensor& workspace,
    sglang::ScalarTypeId const& b_q_type_id,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float);

torch::Tensor gptq_gemm(
    torch::Tensor a,
    torch::Tensor b_q_weight,
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_shuffle,
    int64_t bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
 * From csrc/moe
 */
@@ -340,15 +374,6 @@ void scaled_fp4_experts_quant(
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);

namespace marlin_moe_wna16 {

torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);

}  // namespace marlin_moe_wna16

/*
 * From csrc/speculative
 */
@@ -44,6 +44,9 @@ from sgl_kernel.gemm import (
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
@@ -2,6 +2,7 @@ import functools
from typing import Optional

import torch
from sgl_kernel import silu_and_mul


def get_scalar_type(num_bits: int, has_zp: bool):
@@ -165,7 +166,7 @@ def fused_marlin_moe(
        is_zp_float=False,
    )

    torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
    silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)

    if expert_map is not None:
        intermediate_cache3.zero_()
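Aside: a pure-torch sketch (illustrative only, not part of this commit) of the gated activation the rewritten call above computes, assuming sgl_kernel's silu_and_mul takes the concatenated gate/up tensor first and the output buffer second, as the new call site suggests:

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x holds [gate, up] concatenated along the last dim (width 2 * N);
    # the result is silu(gate) * up with width N, matching intermediate_cache2.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]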
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple
from typing import Optional, Tuple

import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
@@ -353,3 +354,62 @@ def scaled_fp4_experts_quant(
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales


# GPTQ kernels
def gptq_marlin_gemm(
    a: torch.Tensor,
    c: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    global_scale: Optional[torch.Tensor],
    b_zeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool = True,
    use_atomic_add: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.gptq_marlin_gemm(
        a,
        c,
        b_q_weight,
        b_scales,
        global_scale,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )


def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
@@ -7,8 +7,8 @@ def gptq_marlin_repack(
    size_k,
    size_n,
    num_bits,
):
    torch.ops.sgl_kernel.gptq_marlin_repack.default(
) -> torch.Tensor:
    return torch.ops.sgl_kernel.gptq_marlin_repack(
        b_q_weight,
        perm,
        size_k,
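Aside: a hedged usage sketch for the repack wrapper above (illustrative only, not part of this commit; assumes a CUDA build of sgl_kernel and 4-bit weights). The input uses GPTQ's row-packed int32 layout and the output uses the 16x64 Marlin tile layout, following the size_k / tile_size and size_n * tile_size / pack_factor shapes from the C++ entry point:

import torch
from sgl_kernel import gptq_marlin_repack

size_k, size_n, num_bits = 2048, 4096, 4
pack_factor = 32 // num_bits  # 8 nibbles per int32

# Dummy GPTQ row-packed weight; a real tensor would come from pack_rows().
b_q_weight = torch.zeros((size_k // pack_factor, size_n), dtype=torch.int32, device="cuda")
perm = torch.empty(0, dtype=torch.int, device="cuda")  # empty perm => no act_order

marlin_q_w = gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
assert marlin_q_w.shape == (size_k // 16, size_n * 16 // pack_factor)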
sgl-kernel/tests/test_gptq_kernel.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import pytest
import torch
from sgl_kernel import gptq_gemm

from sglang.srt.layers.quantization.utils import pack_cols, pack_rows


def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
    assert bit == 4, "Reference dequantization only supports 4-bit"
    group_size = K // scales.shape[0]
    pack_factor = 32 // bit

    # unpack q_weight: (K//pack_factor, N) -> (K, N)
    unpacked_q_weight = torch.empty(
        q_weight.shape[0] * pack_factor,
        q_weight.shape[1],
        dtype=torch.uint8,
        device=q_weight.device,
    )
    for i in range(pack_factor):
        unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F

    # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N)
    unpacked_q_zeros = torch.empty(
        q_zeros.shape[0],
        q_zeros.shape[1] * pack_factor,
        dtype=torch.uint8,
        device=q_zeros.device,
    )
    for i in range(pack_factor):
        unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F

    unpacked_q_zeros += 1
    unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype)

    scale_zeros = unpacked_q_zeros * scales  # (num_groups, N)

    current_g_idx = torch.tensor(
        [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device
    )

    scale_mat = scales[current_g_idx]  # (K, N)
    scale_zeros_mat = scale_zeros[current_g_idx]  # (K, N)

    # dequant: weight * scale - scale_zeros
    dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat

    return dequantized_b.reshape(K, N)


def torch_gptq_gemm(
    a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
    K, N = a.shape[1], b_q_weight.shape[1]

    b_dequant = torch_dequantize(
        b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N
    )
    c = torch.matmul(a, b_dequant)
    return c


def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):

    b_fp = torch.randn(K, N, dtype=dtype, device=device)

    assert K % group_size == 0, "K must be divisible by group_size"
    num_groups = K // group_size

    if use_shuffle:
        return
    else:
        g_idx = torch.tensor(
            [i // group_size for i in range(K)], dtype=torch.int32, device=device
        )
        b_shuffled = b_fp[g_idx]

    b_grouped = b_shuffled.reshape(num_groups, group_size, N)

    b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
    b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]

    scales = (b_max - b_min) / (2**bit - 1)
    scales = scales.clamp(min=1e-6)

    zeros_float = (-b_min / scales).round()

    q_b = (
        (b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
    )

    q_zeros_unpacked = zeros_float.to(torch.uint8) - 1

    b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)

    q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
    b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
    b_gptq_scales = scales.squeeze(1)

    a = torch.randn(M, K, dtype=dtype, device=device)

    c_ref = torch_gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )
    c_out = gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )

    rtol = 4e-2
    atol = 4e-2
    torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
    print(
        f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
    )


@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
sgl-kernel/tests/test_marlin_gemm.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types

from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]


# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    if size_k % group_size != 0:
        return

    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        # GPTQ style, unsigned + symmetric bias
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace(w_ref.device)

    # marlin gemm
    output = gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # ref gemm
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )

    assert max_diff < 0.04


if __name__ == "__main__":
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])
@@ -1,16 +1,32 @@
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
from sgl_kernel.scalar_type import scalar_types

from sglang.srt.layers.quantization.utils import (
    get_pack_factor,
    gptq_quantize_weights,
    pack_cols,
    pack_rows,
    quantize_weights,
    sort_weights,
)
from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights

GPTQ_MARLIN_TILE = 16
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]


def awq_pack(
@@ -35,70 +51,6 @@ def awq_pack(
    return pack_cols(q_w, num_bits, size_k, size_n)


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

    return q_packed


def get_weight_perm(num_bits: int):
    perm_list: list[int] = []
    for i in range(32):
        perm1: list[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32])
@@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
    torch.testing.assert_close(out_gpu, q_w_marlin)


@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
    m_factor, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if size_k % group_size != 0:
        pytest.skip("size_k must be divisible by group_size")

    # Create input
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    marlin_layout_perm = get_weight_perm(quant_type.size_bits)
    q_w_marlin_ref = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
    )

    # Run Marlin repack GPU kernel
    q_w_marlin = gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
    )

    torch.cuda.synchronize()

    torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)


if __name__ == "__main__":
    import subprocess