[2/n]decouple quantization implementation from vLLM dependency (#8112)

Co-authored-by: walker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: leoneo <1320612015@qq.com>
This commit is contained in:
Peng Zhang
2025-08-14 18:19:03 +08:00
committed by GitHub
parent 4dbf43601d
commit 5aa1ebd242
32 changed files with 6506 additions and 202 deletions

View File

@@ -321,6 +321,30 @@ def pack_cols(
return q_res
def pack_rows(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_k % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[i::pack_factor, :] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,

View File

@@ -254,13 +254,15 @@ set(SOURCES
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
"csrc/gemm/marlin/gptq_marlin.cu"
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"

View File

@@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// GPTQ related method
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
"Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
"Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
m.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
"use_shuffle, int bit) -> Tensor");
m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/*
* From csrc/moe
*/
@@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
/*
* From csrc/speculative
*/

View File

@@ -0,0 +1,62 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace sglang {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
unsigned int* address_as_ui = (unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
} while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif
#endif
#endif
} // namespace gptq
} // namespace sglang
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,269 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
class MatrixView_half {
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
return __half2half2(data[row * width + column]);
}
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
return &data[row * width + column];
}
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw {
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half* item_ptr(int row, int column) {
return &data[row * width + column];
}
__device__ __forceinline__ void set(int row, int column, half value) {
data[row * width + column] = value;
}
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
((half2*)data)[(row * width + column) / 2] = value;
}
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) {
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*)item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
class MatrixView_q4_column {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
return data[row / 8 * width + column];
}
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) {
return &data[row / 8 * width + column];
}
};
class MatrixView_q2_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -0,0 +1,74 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -0,0 +1,146 @@
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) {
uint32_t t0 = qa & 0x07;
uint32_t t1 = (qa & 0x38) >> 3;
qa >>= 6;
za |= (t0 << (i * 3));
za |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qb & 0x07;
uint32_t t1 = (qb & 0x38) >> 3;
qb >>= 6;
zb |= (t0 << (i * 3));
zb |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qc & 0x07;
uint32_t t1 = (qc & 0x38) >> 3;
qc >>= 6;
zc |= (t0 << (i * 3));
zc |= (t1 << (i * 3 + 16));
}
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32(
const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y8, z8);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y8, z8);
dq[4] = __hfma2(q4.as_half2, y64, z64);
dq[5] = __hadd2(q5.as_half2, z1);
dq[6] = __hfma2(q6.as_half2, y8, z8);
dq[7] = __hadd2(q7.as_half2, z1);
dq[8] = __hfma2(q8.as_half2, y8, z8);
dq[9] = __hfma2(q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -0,0 +1,114 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void
dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void
dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) {
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled) {
dq[0] = __hfma2(q0.as_half2, y1y16[0],
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
} else {
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1],
z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -0,0 +1,30 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
__forceinline__ __device__ void
dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) {
half dqh[8];
for (int i = 0; i < 4; i++)
dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++)
dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++)
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -0,0 +1,53 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace sglang {
namespace gptq {
union half2_uint32 {
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16 {
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) {
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
// return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) {
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) {
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace sglang
#endif

View File

@@ -1,15 +1,12 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "gptq_marlin/marlin.cuh"
#include "kernel.h"
namespace MARLIN_NAMESPACE_NAME {
#include "marlin.cuh"
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in awq_marlin_repack_kernel
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
return;
}
#else
template <int const num_threads, int const num_bits>
@@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel(
}
}
}
#endif
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
TORCH_CHECK(
size_k % marlin::tile_k_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_k_size = ",
marlin::tile_k_size);
TORCH_CHECK(
size_n % marlin::tile_n_size == 0,
"size_n = ",
size_n,
" is not divisible by tile_n_size = ",
marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
@@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
// Get ptrs
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
@@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
return out;
}
torch::Tensor
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

View File

@@ -0,0 +1,459 @@
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:
weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight
Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.
The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.
## INT4 => FP16/BF16 or INT8 => FP16
The floating-point computation is `__hsub2`
If has zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)
so we don't need additional modification.
If has float zero points:
flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)
where the `fzp + bias` can be computed at weight loading. But this
may have accuracy issue, so we should not use this in most cases.
If has not zero points:
scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, scale(bias))
where the `scale(bias)` can be cached. But this may have accuracy issue,
so we should not use this in most cases.
## INT8 => BF16
INT8 => BF16 is a special case, it use byte_perm instead of flop.
We cannot fused byte_perm with scaling.
## FP4/FP8 => FP16/BF16
scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be computed at weight loading.
*/
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, sglang::ScalarTypeId w_type_id, bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), true>(int q, half2* frag_b) {
const int MASK = 0x000f000f;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), false>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<half2, sglang::kU4.id(), true>(int q, half2* frag_b) {
dequant<half2, sglang::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, sglang::kU4.id(), false>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43084308;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), true>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43004300;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), true>(int q, half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, sglang::kU8.id(), true>(int q, half2* frag_b) {
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, sglang::kU8.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kU8.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8.id(), false>(int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), true>(int q, half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), true>(int q, half2* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
int Out2 = (q & 0xFF00FF00) >> 1;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
};
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,17 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
namespace MARLIN_NAMESPACE_NAME {
#include "marlin.cuh"
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in gptq_marlin_repack_kernel
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr,
uint32_t* __restrict__ out_ptr,
int size_k,
int size_n) {
return;
}
#else
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
@@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel(
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
@@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel(
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel(
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
@@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel(
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
auto warp_id = threadIdx.x / 32;
auto th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
@@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel(
}
}
}
#endif
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM> \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
TORCH_CHECK(
size_k % marlin::tile_k_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_k_size = ",
marlin::tile_k_size);
TORCH_CHECK(
size_n % marlin::tile_n_size == 0,
"size_n = ",
size_n,
" is not divisible by tile_n_size = ",
marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
@@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
@@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
return out;
}
torch::Tensor gptq_marlin_repack_meta(
torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
// }
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
// }
} // namespace MARLIN_NAMESPACE_NAME

View File

@@ -0,0 +1,36 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <
typename scalar_t, // compute dtype, half or nv_float16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}

View File

@@ -10,11 +10,10 @@
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more
@@ -91,6 +90,7 @@ template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

View File

@@ -1,4 +1,3 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include <cuda_bf16.h>
@@ -7,7 +6,7 @@
#include "marlin.cuh"
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {

File diff suppressed because it is too large Load Diff

View File

@@ -1,25 +0,0 @@
#pragma once
#include <Python.h>
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}

View File

@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \

View File

@@ -18,13 +18,12 @@
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \

View File

@@ -23,7 +23,6 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "kernel.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -50,8 +49,7 @@ __global__ void permute_cols_kernel(
int size_m,
int size_k,
int top_k) {};
} // namespace marlin
}
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,

View File

@@ -298,6 +298,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
@@ -313,6 +314,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;

View File

@@ -224,6 +224,40 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const t
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_shuffle,
int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
* From csrc/moe
*/
@@ -340,15 +374,6 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
namespace marlin_moe_wna16 {
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
} // namespace marlin_moe_wna16
/*
* From csrc/speculative
*/

View File

@@ -44,6 +44,9 @@ from sgl_kernel.gemm import (
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
gptq_gemm,
gptq_marlin_gemm,
gptq_shuffle,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,

View File

@@ -2,6 +2,7 @@ import functools
from typing import Optional
import torch
from sgl_kernel import silu_and_mul
def get_scalar_type(num_bits: int, has_zp: bool):
@@ -165,7 +166,7 @@ def fused_marlin_moe(
is_zp_float=False,
)
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
if expert_map is not None:
intermediate_cache3.zero_()

View File

@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
@@ -353,3 +354,62 @@ def scaled_fp4_experts_quant(
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
# GPTQ kernels
def gptq_marlin_gemm(
a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool = True,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_gemm(
a,
c,
b_q_weight,
b_scales,
global_scale,
b_zeros,
g_idx,
perm,
workspace,
b_q_type.id,
size_m,
size_n,
size_k,
is_k_full,
use_atomic_add,
use_fp32_reduce,
is_zp_float,
)
def gptq_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_shuffle: bool,
bit: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
torch.torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)

View File

@@ -7,8 +7,8 @@ def gptq_marlin_repack(
size_k,
size_n,
num_bits,
):
torch.ops.sgl_kernel.gptq_marlin_repack.default(
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight,
perm,
size_k,

View File

@@ -0,0 +1,131 @@
import pytest
import torch
from sgl_kernel import gptq_gemm
from sglang.srt.layers.quantization.utils import pack_cols, pack_rows
def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
assert bit == 4, "Reference dequantization only supports 4-bit"
group_size = K // scales.shape[0]
pack_factor = 32 // bit
# unpack q_weight: (K//pack_factor, N) -> (K, N)
unpacked_q_weight = torch.empty(
q_weight.shape[0] * pack_factor,
q_weight.shape[1],
dtype=torch.uint8,
device=q_weight.device,
)
for i in range(pack_factor):
unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F
# unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N)
unpacked_q_zeros = torch.empty(
q_zeros.shape[0],
q_zeros.shape[1] * pack_factor,
dtype=torch.uint8,
device=q_zeros.device,
)
for i in range(pack_factor):
unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F
unpacked_q_zeros += 1
unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype)
scale_zeros = unpacked_q_zeros * scales # (num_groups, N)
current_g_idx = torch.tensor(
[i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device
)
scale_mat = scales[current_g_idx] # (K, N)
scale_zeros_mat = scale_zeros[current_g_idx] # (K, N)
# dequant: weight * scale - scale_zeros
dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat
return dequantized_b.reshape(K, N)
def torch_gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
K, N = a.shape[1], b_q_weight.shape[1]
b_dequant = torch_dequantize(
b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N
)
c = torch.matmul(a, b_dequant)
return c
def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):
b_fp = torch.randn(K, N, dtype=dtype, device=device)
assert K % group_size == 0, "K must be divisible by group_size"
num_groups = K // group_size
if use_shuffle:
return
else:
g_idx = torch.tensor(
[i // group_size for i in range(K)], dtype=torch.int32, device=device
)
b_shuffled = b_fp[g_idx]
b_grouped = b_shuffled.reshape(num_groups, group_size, N)
b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]
scales = (b_max - b_min) / (2**bit - 1)
scales = scales.clamp(min=1e-6)
zeros_float = (-b_min / scales).round()
q_b = (
(b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
)
q_zeros_unpacked = zeros_float.to(torch.uint8) - 1
b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)
q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
b_gptq_scales = scales.squeeze(1)
a = torch.randn(M, K, dtype=dtype, device=device)
c_ref = torch_gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
)
c_out = gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
)
rtol = 4e-2
atol = 4e-2
torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
print(
f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
)
@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
_test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,121 @@
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]
# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
if has_zp:
return
if size_k % group_size != 0:
return
a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
if has_zp:
# AWQ style, unsigned + runtime zero-point
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size
)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
# GPTQ style, unsigned + symmetric bias
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace(w_ref.device)
# marlin gemm
output = gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
# ref gemm
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref)
)
assert max_diff < 0.04
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])

View File

@@ -1,16 +1,32 @@
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
gptq_quantize_weights,
pack_cols,
pack_rows,
quantize_weights,
sort_weights,
)
from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights
GPTQ_MARLIN_TILE = 16
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]
def awq_pack(
@@ -35,70 +51,6 @@ def awq_pack(
return pack_cols(q_w, num_bits, size_k, size_n)
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_packed
def get_weight_perm(num_bits: int):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32])
@@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
torch.testing.assert_close(out_gpu, q_w_marlin)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
if size_k % group_size != 0:
pytest.skip("size_k must be divisible by group_size")
# Create input
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order
)
q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
marlin_layout_perm = get_weight_perm(quant_type.size_bits)
q_w_marlin_ref = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
)
# Run Marlin repack GPU kernel
q_w_marlin = gptq_marlin_repack(
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
)
torch.cuda.synchronize()
torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)
if __name__ == "__main__":
import subprocess