[2/n]decouple quantization implementation from vLLM dependency (#8112)
Co-authored-by: walker-ai <yiyun.wyt@antgroup.com> Co-authored-by: leoneo <1320612015@qq.com>
This commit is contained in:
@@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
|
||||
|
||||
// GPTQ related method
|
||||
m.def(
|
||||
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
|
||||
"Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
|
||||
"Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
|
||||
"bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
||||
|
||||
m.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
|
||||
"use_shuffle, int bit) -> Tensor");
|
||||
m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||
|
||||
m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||
m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||
|
||||
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||
|
||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
|
||||
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
|
||||
|
||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
62
sgl-kernel/csrc/gemm/gptq/compat.cuh
Normal file
62
sgl-kernel/csrc/gemm/gptq/compat.cuh
Normal file
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
|
||||
unsigned int* address_as_ui = (unsigned int*)((char*)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) {
|
||||
atomicAdd_half(address, val);
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
|
||||
atomicAdd_half2(address, val);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
1950
sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu
Normal file
1950
sgl-kernel/csrc/gemm/gptq/gptq_kernel.cu
Normal file
File diff suppressed because it is too large
Load Diff
269
sgl-kernel/csrc/gemm/gptq/matrix_view.cuh
Normal file
269
sgl-kernel/csrc/gemm/gptq/matrix_view.cuh
Normal file
@@ -0,0 +1,269 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and
|
||||
https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
class MatrixView_half {
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
|
||||
return __half2half2(data[row * width + column]);
|
||||
}
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw {
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ void set(int row, int column, half value) {
|
||||
data[row * width + column] = value;
|
||||
}
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
|
||||
((half2*)data)[(row * width + column) / 2] = value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) {
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
|
||||
return data[row / 8 * width + column];
|
||||
}
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) {
|
||||
return &data[row / 8 * width + column];
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q2_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
items[2] = (d >> 4) & 0x03;
|
||||
items[3] = (d >> 6) & 0x03;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q3_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int z_w = column * 3 / 32;
|
||||
int z_mod = column & 0x1f;
|
||||
|
||||
if (z_mod == 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||
} else if (z_mod == 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||
} else if (z_mod < 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||
} else if (z_mod < 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
||||
} else {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x1f);
|
||||
uint32_t d;
|
||||
if (shift <= 4) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||
} else if (shift == 8) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||
} else if (shift <= 16) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||
} else if (shift == 20) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||
} else {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||
}
|
||||
items[0] = d & 0x07;
|
||||
items[1] = (d >> 3) & 0x07;
|
||||
items[2] = (d >> 6) & 0x07;
|
||||
items[3] = (d >> 9) & 0x07;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q8_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
|
||||
int shift = (column & 0x03) * 2;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
items[2] = (d >> 16) & 0xff;
|
||||
items[3] = (d >> 24) & 0xff;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
74
sgl-kernel/csrc/gemm/gptq/qdq_2.cuh
Normal file
74
sgl-kernel/csrc/gemm/gptq/qdq_2.cuh
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint32_t qa0 = qa & 0x03;
|
||||
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||
qa >>= 4;
|
||||
qb |= (qa1 << (i * 2 + 16));
|
||||
qb |= (qa0 << (i * 2));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y4 = __halves2half2(y4_, y4_);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
|
||||
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||
const half2 z1 = __half2half2(z1_.as_half);
|
||||
const half2 z4 = __half2half2(z4_);
|
||||
const half2 z16 = __half2half2(z16_);
|
||||
const half2 z64 = __half2half2(z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
|
||||
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||
dq[4] = __hadd2(q4.as_half2, z1);
|
||||
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
146
sgl-kernel/csrc/gemm/gptq/qdq_3.cuh
Normal file
146
sgl-kernel/csrc/gemm/gptq/qdq_3.cuh
Normal file
@@ -0,0 +1,146 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qa & 0x07;
|
||||
uint32_t t1 = (qa & 0x38) >> 3;
|
||||
qa >>= 6;
|
||||
za |= (t0 << (i * 3));
|
||||
za |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qb & 0x07;
|
||||
uint32_t t1 = (qb & 0x38) >> 3;
|
||||
qb >>= 6;
|
||||
zb |= (t0 << (i * 3));
|
||||
zb |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qc & 0x07;
|
||||
uint32_t t1 = (qc & 0x38) >> 3;
|
||||
qc >>= 6;
|
||||
zc |= (t0 << (i * 3));
|
||||
zc |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32(
|
||||
const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
|
||||
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y8, z8);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y8, z8);
|
||||
dq[4] = __hfma2(q4.as_half2, y64, z64);
|
||||
dq[5] = __hadd2(q5.as_half2, z1);
|
||||
dq[6] = __hfma2(q6.as_half2, y8, z8);
|
||||
dq[7] = __hadd2(q7.as_half2, z1);
|
||||
dq[8] = __hfma2(q8.as_half2, y8, z8);
|
||||
dq[9] = __hfma2(q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
114
sgl-kernel/csrc/gemm/gptq/qdq_4.cuh
Normal file
114
sgl-kernel/csrc/gemm/gptq/qdq_4.cuh
Normal file
@@ -0,0 +1,114 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
const half2 z1 = __half2half2(z1_.as_half);
|
||||
const half2 z16 = __half2half2(z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void
|
||||
dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void
|
||||
dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled) {
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0],
|
||||
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
} else {
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
30
sgl-kernel/csrc/gemm/gptq/qdq_8.cuh
Normal file
30
sgl-kernel/csrc/gemm/gptq/qdq_8.cuh
Normal file
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_8_cuh
|
||||
#define _qdq_8_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
|
||||
|
||||
__forceinline__ __device__ void
|
||||
dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) {
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 4; i++)
|
||||
dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
|
||||
for (int i = 0; i < 4; i++)
|
||||
dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
|
||||
#endif
|
||||
53
sgl-kernel/csrc/gemm/gptq/qdq_util.cuh
Normal file
53
sgl-kernel/csrc/gemm/gptq/qdq_util.cuh
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
namespace sglang {
|
||||
namespace gptq {
|
||||
|
||||
union half2_uint32 {
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16 {
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) {
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
|
||||
// return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) {
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) {
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace sglang
|
||||
#endif
|
||||
@@ -1,15 +1,12 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "kernel.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
#include "marlin.cuh"
|
||||
|
||||
namespace marlin {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in awq_marlin_repack_kernel
|
||||
template <int const num_threads, int const num_bits>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits>
|
||||
@@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \
|
||||
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
|
||||
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
|
||||
TORCH_CHECK(
|
||||
size_k % marlin::tile_k_size == 0,
|
||||
"size_k = ",
|
||||
size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
marlin::tile_k_size);
|
||||
TORCH_CHECK(
|
||||
size_n % marlin::tile_n_size == 0,
|
||||
"size_n = ",
|
||||
size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
@@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
@@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
459
sgl-kernel/csrc/gemm/marlin/dequant.h
Normal file
459
sgl-kernel/csrc/gemm/marlin/dequant.h
Normal file
@@ -0,0 +1,459 @@
|
||||
/*
|
||||
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
|
||||
|
||||
The process of fast dequantization can be summarized as a combination
|
||||
of bitwise operations and floating-point computations:
|
||||
|
||||
weight =>(bit_op / bitwise operations)=>
|
||||
f16_value =>(flop / floating-point computation)=>
|
||||
dequantized_weight
|
||||
|
||||
Since the dequantized weights typically require subtracting the zero point and
|
||||
applying a scale factor, the floating-point computation step can be fused with
|
||||
the zero-point subtraction and scaling operations.
|
||||
|
||||
The following are the parts that need to be modified for the fused operation
|
||||
of zero-point subtraction and scaling.
|
||||
|
||||
## INT4 => FP16/BF16 or INT8 => FP16
|
||||
|
||||
The floating-point computation is `__hsub2`
|
||||
|
||||
If has zero points:
|
||||
|
||||
flop(bit_op(weight)) - flop(bit_op(zp))
|
||||
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
|
||||
= bit_op(weight) - bit_op(zp)
|
||||
|
||||
so we don't need additional modification.
|
||||
|
||||
If has float zero points:
|
||||
|
||||
flop(bit_op(weight)) - fzp
|
||||
= sub(bit_op(weight), bias) - fzp
|
||||
= bit_op(weight) - (fzp + bias)
|
||||
|
||||
where the `fzp + bias` can be computed at weight loading. But this
|
||||
may have accuracy issue, so we should not use this in most cases.
|
||||
|
||||
If has not zero points:
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(sub(bit_op(weight), bias))
|
||||
= scale(bit_op(weight)) - scale(bias)
|
||||
= fma(bit_op(weight), scale_factor, scale(bias))
|
||||
|
||||
where the `scale(bias)` can be cached. But this may have accuracy issue,
|
||||
so we should not use this in most cases.
|
||||
|
||||
|
||||
## INT8 => BF16
|
||||
|
||||
INT8 => BF16 is a special case, it use byte_perm instead of flop.
|
||||
We cannot fused byte_perm with scaling.
|
||||
|
||||
|
||||
## FP4/FP8 => FP16/BF16
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(mul(bit_op(weight), multiplier))
|
||||
= mul(bit_op(weight), scale_factor * multiplier)
|
||||
|
||||
where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
*/
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, sglang::ScalarTypeId w_type_id, bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
|
||||
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
||||
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
||||
// with some small changes:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4B8.id(), true>(int q, half2* frag_b) {
|
||||
const int MASK = 0x000f000f;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4B8.id(), false>(int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// clang-format on
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(
|
||||
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU4.id(), false>(int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// clang-format on
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(
|
||||
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
|
||||
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43084308;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kU4.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43004300;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
//
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
||||
// bf16 Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8B128.id(), true>(int q, half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8B128.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8.id(), true>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kU8.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kU8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388736.f;
|
||||
fp32_intermediates[1] -= 8388736.f;
|
||||
fp32_intermediates[2] -= 8388736.f;
|
||||
fp32_intermediates[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kU8.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), true>(int q, half2* frag_b) {
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||
// position
|
||||
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to bfloat162 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), true>(int q, half2* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), false>(int q, half2* frag_b) {
|
||||
dequant<half2, sglang::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and BF16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||
// position
|
||||
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <typename scalar_t2>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
;
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0xFF00FF00) >> 1;
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
1120
sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu
Normal file
1120
sgl-kernel/csrc/gemm/marlin/gptq_marlin.cu
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,17 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
#include "marlin.cuh"
|
||||
|
||||
namespace marlin {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in gptq_marlin_repack_kernel
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr,
|
||||
uint32_t* __restrict__ out_ptr,
|
||||
int size_k,
|
||||
int size_n) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
@@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
@@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
|
||||
@@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
@@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
@@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM> \
|
||||
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
|
||||
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
|
||||
TORCH_CHECK(
|
||||
size_k % marlin::tile_k_size == 0,
|
||||
"size_k = ",
|
||||
size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
marlin::tile_k_size);
|
||||
TORCH_CHECK(
|
||||
size_n % marlin::tile_n_size == 0,
|
||||
"size_n = ",
|
||||
size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
@@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.size(0) != 0;
|
||||
@@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(
|
||||
torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
// }
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
// }
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
36
sgl-kernel/csrc/gemm/marlin/kernel.h
Normal file
36
sgl-kernel/csrc/gemm/marlin/kernel.h
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <
|
||||
typename scalar_t, // compute dtype, half or nv_float16
|
||||
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
@@ -10,11 +10,10 @@
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Marlin params
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
||||
@@ -91,6 +90,7 @@ template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include <cuda_bf16.h>
|
||||
@@ -7,7 +6,7 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
1629
sgl-kernel/csrc/gemm/marlin/marlin_template.h
Normal file
1629
sgl-kernel/csrc/gemm/marlin/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,25 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
#define _STRINGIFY(A) #A
|
||||
#define STRINGIFY(A) _STRINGIFY(A)
|
||||
|
||||
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
||||
|
||||
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||
// via python's import statement.
|
||||
#define REGISTER_EXTENSION(NAME) \
|
||||
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "gemm/marlin/marlin.cuh"
|
||||
#include "gemm/marlin/marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
|
||||
@@ -18,13 +18,12 @@
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "gemm/marlin/marlin.cuh"
|
||||
#include "gemm/marlin/marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
|
||||
@@ -23,7 +23,6 @@
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "kernel.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@@ -50,8 +49,7 @@ __global__ void permute_cols_kernel(
|
||||
int size_m,
|
||||
int size_k,
|
||||
int top_k) {};
|
||||
|
||||
} // namespace marlin
|
||||
}
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
|
||||
Reference in New Issue
Block a user