Sync from v0.13
This commit is contained in:
2
csrc/quantization/gptq_marlin/.gitignore
vendored
Normal file
2
csrc/quantization/gptq_marlin/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
288
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
Normal file
288
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
Normal file
@@ -0,0 +1,288 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace marlin {
|
||||
|
||||
// Repacks AWQ-format quantized weights (4- or 8-bit values packed into
// int32 words) into the tiled layout consumed by the marlin GEMM kernels.
// Each thread block handles a contiguous range of k-tiles and pipelines
// n-tiles through dynamic shared memory, `repack_stages` buffers deep, using
// the cp_async* helpers (these presumably wrap cp.async, i.e. SM80+ — TODO
// confirm against marlin.cuh).
//
// Preconditions (enforced by the host wrapper `awq_marlin_repack`):
// size_k % tile_k_size == 0 and size_n % tile_n_size == 0.
template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  // Number of quantized values packed into one 32-bit word.
  constexpr int pack_factor = 32 / num_bits;

  // For an 8-bit A operand the effective tile is half as wide in n and twice
  // as tall in k (same element count per tile).
  constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
  constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
  int k_tiles = size_k / target_tile_k_size;
  int n_tiles = size_n / target_tile_n_size;
  // Ceil-div: spread k-tiles across the grid.
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  auto start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  // Dynamic shared memory; sized by the host launch (one stage = stage_size
  // int4 per pipeline slot).
  extern __shared__ int4 sh[];

  // Width of one tile row in packed 32-bit words.
  constexpr int tile_n_ints = target_tile_n_size / pack_factor;

  // Each participating thread copies one int4 (four packed words) per stage.
  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = target_tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  // Issue the async copy of tile (k_tile_id, n_tile_id) into pipeline slot
  // `pipe`. Always fences — even for out-of-range tiles — so the pipeline
  // depth tracked by cp_async_wait stays consistent.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * target_tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * target_tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  // Read tile (k_tile_id, n_tile_id) from pipeline slot `pipe`, undo the AWQ
  // column interleaving, and store it in the marlin tensor-core layout.
  // Only the first 4 warps participate.
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    auto warp_id = threadIdx.x / 32;
    auto th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    // Lane position within the tensor-core fragment.
    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    // Low `num_bits` set — extracts one quantized value after shifting.
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    // Gather the 8 quantized values this lane is responsible for.
    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      if constexpr (is_a_8bit) {
        int cur_elem = tc_row + i;

        int packed_src_0 =
            sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
                             sh_stride * cur_elem];
        int packed_src_1 =
            sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
                             sh_stride * (cur_elem + 16)];

        vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
        vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
      } else {
        int cur_elem = tc_row + tc_offsets[i];

        int packed_src_0 =
            sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
        int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                            sh_stride * cur_elem];

        vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
        vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
      }
    }

    // Packed 32-bit words per output tile.
    constexpr int tile_size =
        target_tile_k_size * target_tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (!is_a_8bit && num_bits == 4) {
      // 4-bit weights, 16-bit A: interleave for the fp16 fast-dequant path.
      int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else if constexpr (is_a_8bit && num_bits == 4) {
      // 4-bit weights, 8-bit A: different interleave order.
      int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      // 8-bit weights: two output words per lane.
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        const int ii = is_a_8bit ? i : pack_idx[i];
        res1 |= vals[ii] << (i * 8);
        res2 |= vals[4 + ii] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  // Prime the pipeline: issue the first `repack_stages - 1` fetches, then
  // wait for the oldest one to land.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    // Steady state: fetch `repack_stages - 1` tiles ahead while repacking the
    // current slot.
    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
// Dispatch helper: expands to one `else if` arm of the launch chain in
// `awq_marlin_repack` (the chain is opened by an empty `if (false) {}`).
// For the matching (num_bits, is_a_8bit) pair it first raises the kernel's
// dynamic shared memory limit (required to opt in past the default per-block
// cap), then launches the kernel. Relies on `num_bits`, `is_a_8bit`,
// `blocks`, `max_shared_mem`, `stream`, `b_q_weight_ptr`, `out_ptr`,
// `size_k`, and `size_n` being in scope at the expansion site.
#define CALL_IF(NUM_BITS, IS_A_8BIT)                                       \
  else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) {               \
    cudaFuncSetAttribute(                                                  \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
                                         IS_A_8BIT>,                       \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);      \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,     \
                                     IS_A_8BIT>                            \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(      \
            b_q_weight_ptr, out_ptr, size_k, size_n);                      \
  }
|
||||
|
||||
// Host entry point: repacks an AWQ-quantized weight tensor into the marlin
// tiled layout.
//
// b_q_weight: (size_k, size_n / pack_factor) contiguous int32 CUDA tensor.
// num_bits:   quantization width; must be 4 or 8.
// is_a_8bit:  whether the A operand is 8-bit (selects a different tile shape
//             in the kernel).
// Returns a (size_k / tile_size, size_n * tile_size / pack_factor) int32
// tensor holding the repacked weights. Raises (TORCH_CHECK) on any shape,
// dtype, device, or config violation.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits,
                                bool is_a_8bit) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  // Quantized values per 32-bit word.
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  // One block per SM.
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  // Query the opt-in shared memory limit; CALL_IF raises the kernel's
  // attribute to this value before launching.
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // CALL_IF expands to `else if` arms; the empty `if (false)` opens the
  // chain so the final `else` catches unsupported configs.
  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(8, false)
  CALL_IF(4, true)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", is_a_8bit = ", is_a_8bit);
  }

  return out;
}
|
||||
|
||||
// Register the CUDA implementation of `awq_marlin_repack` with the PyTorch
// dispatcher.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("awq_marlin_repack", &awq_marlin_repack);
}
|
||||
609
csrc/quantization/gptq_marlin/dequant.h
Normal file
609
csrc/quantization/gptq_marlin/dequant.h
Normal file
@@ -0,0 +1,609 @@
|
||||
/*
|
||||
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
|
||||
|
||||
The process of fast dequantization can be summarized as a combination
|
||||
of bitwise operations and floating-point computations:
|
||||
|
||||
weight =>(bit_op / bitwise operations)=>
|
||||
f16_value =>(flop / floating-point computation)=>
|
||||
dequantized_weight
|
||||
|
||||
Since the dequantized weights typically require subtracting the zero point and
|
||||
applying a scale factor, the floating-point computation step can be fused with
|
||||
the zero-point subtraction and scaling operations.
|
||||
|
||||
The following are the parts that need to be modified for the fused operation
|
||||
of zero-point subtraction and scaling.
|
||||
|
||||
## INT4 => FP16/BF16 or INT8 => FP16
|
||||
|
||||
The floating-point computation is `__hsub2`
|
||||
|
||||
If there are integer zero points:
|
||||
|
||||
flop(bit_op(weight)) - flop(bit_op(zp))
|
||||
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
|
||||
= bit_op(weight) - bit_op(zp)
|
||||
|
||||
so we don't need additional modification.
|
||||
|
||||
If there are float zero points:
|
||||
|
||||
flop(bit_op(weight)) - fzp
|
||||
= sub(bit_op(weight), bias) - fzp
|
||||
= bit_op(weight) - (fzp + bias)
|
||||
|
||||
where the `fzp + bias` can be computed at weight loading. But this
|
||||
may have accuracy issue, so we should not use this in most cases.
|
||||
|
||||
If there are no zero points:
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(sub(bit_op(weight), bias))
|
||||
= scale(bit_op(weight)) - scale(bias)
|
||||
= fma(bit_op(weight), scale_factor, scale(bias))
|
||||
|
||||
where the `scale(bias)` can be cached. But this may have accuracy issue,
|
||||
so we should not use this in most cases.
|
||||
|
||||
|
||||
## INT8 => BF16
|
||||
|
||||
INT8 => BF16 is a special case: it uses byte_perm instead of flop.
|
||||
We cannot fused byte_perm with scaling.
|
||||
|
||||
|
||||
## FP4/FP8 => FP16/BF16
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(mul(bit_op(weight), multiplier))
|
||||
= mul(bit_op(weight), scale_factor * multiplier)
|
||||
|
||||
where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
*/
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
//
// `lut` is the 8-entry truth table of the desired boolean function, written
// in the conventional PTX form — e.g. `(0xf0 & 0xcc) | 0xaa` encodes
// `(a & b) | c`. Maps directly onto the `lop3.b32` instruction.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
//
// Wraps PTX `prmt.b32`: each nibble of `mask` selects one byte from the
// 8-byte pool formed by (`a`, `start_byte`). Both selector operands are
// template parameters so they become immediates.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}
|
||||
|
||||
// Dequantizes the packed value `q` (format selected by `w_type_id`) into
// `frag_b`. When `skip_flop` is true only the bitwise extraction step is
// performed and the floating-point fixup (zero-point subtraction / bias
// removal) is left to the caller, per the fused-dequant scheme described in
// the file header. Specialized below per (output type, weight type,
// skip_flop) combination.
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//

// kU4B8 (int4 with +8 bias), bitwise step only: each nibble is OR-ed into the
// mantissa of the magic constant 0x6400, so the caller's later `__hsub2`
// fixup recovers the value.
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
                                                              half2* frag_b) {
  const int MASK = 0x000f000f;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);

  frag_b[0] = *reinterpret_cast<half2*>(&lo);
  frag_b[1] = *reinterpret_cast<half2*>(&hi);
}

// kU4B8, full dequant: extracts low and high nibbles of each 16-bit lane and
// fuses the magic-constant removal with the zero-point subtraction.
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
                                                               half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}

// kU4 (unsigned int4): the bitwise extraction step is identical to kU4B8.
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
                                                            half2* frag_b) {
  dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
}

// kU4, full dequant.
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
                                                             half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // Unsigned int4 outputs: unlike the kU4B8 variant above, no `-8` zero
  // point is fused here — SUB/MUL/ADD only strip the 0x6400 magic bias
  // (note SUB = 0x6400, not 0x6408).
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}
|
||||
|
||||
// kU4B8 -> BF16, bitwise step only: nibbles are OR-ed into the mantissa of
// the magic constant 0x4300; both nibble positions share one MASK with the
// `q >>= 4` shift in between.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
    int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  // clang-format on

  frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
  frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}

// kU4B8 -> BF16, full dequant: subtract 0x4308 (bf16 for 128 + 8), removing
// the magic bias and the +8 zero point in a single __hsub2.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);

  static constexpr uint32_t SUB = 0x43084308;

  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}

// kU4 -> BF16: bitwise step identical to kU4B8.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
}

// kU4 -> BF16, full dequant: subtract only the magic bias 0x4300 (bf16 for
// 128) — no zero point for the unsigned type.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);

  static constexpr uint32_t SUB = 0x43004300;

  frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
  frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
|
||||
|
||||
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//

// kU8B128 (int8 with +128 bias), bitwise step only: prmt scatters each source
// byte into the low byte of a 16-bit lane whose high byte is 0x64, i.e. into
// the mantissa of a magic fp16 constant.
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
                                                                half2* frag_b) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  frag_b[0] = *reinterpret_cast<half2*>(&lo);
  frag_b[1] = *reinterpret_cast<half2*>(&hi);
}

// kU8B128, full dequant: subtract 0x6480 (fp16 for 1152 = 1024 + 128),
// removing the magic bias and the +128 zero point in one __hsub2.
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  frag_b[0] = __hsub2(frag_b[0],
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(frag_b[1],
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

// kU8 (unsigned int8): bitwise step identical to kU8B128.
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
                                                            half2* frag_b) {
  dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
}

// kU8, full dequant: subtract only the magic bias 0x6400 (fp16 for 1024) —
// no zero point for the unsigned type.
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
                                                             half2* frag_b) {
  dequant<half2, vllm::kU8.id(), true>(q, frag_b);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
  frag_b[0] = __hsub2(frag_b[0],
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(frag_b[1],
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
|
||||
|
||||
// kU8B128 -> BF16 (no skip_flop variant — the byte_perm route cannot be
// split, as noted in the file header). Each byte is placed in the mantissa
// of fp32 2^23 (0x4B000000), then 8388736.0f (= 2^23 + 128) is subtracted to
// remove the fp32 bias and the +128 zero point; finally the fp32 high halves
// are packed pairwise into bf16.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  // Byte order 0,2,1,3 matches the interleaved weight layout.
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388736.f;
  fp32_intermediates[1] -= 8388736.f;
  fp32_intermediates[2] -= 8388736.f;
  fp32_intermediates[3] -= 8388736.f;

  // Truncate each fp32 to bf16 by keeping its upper 16 bits.
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}

// kU8 -> BF16: same byte_perm route as kU8B128 above, but subtracting
// exactly 2^23 (8388608.0f) — no zero point for the unsigned type.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388608.f;
  fp32_intermediates[1] -= 8388608.f;
  fp32_intermediates[2] -= 8388608.f;
  fp32_intermediates[3] -= 8388608.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}
|
||||
|
||||
// FP8 (E4M3) -> FP16, bitwise step only: keeps each lane's sign bit and
// shifts the 7 exponent+mantissa bits into the fp16 field positions. The
// resulting values are off by a constant power-of-two factor, applied by the
// skip_flop=false variant below.
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
    int q, half2* frag_b) {
  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

// FP8 (E4M3) -> FP16, full dequant: multiplies by 2^BIAS_OFFSET to correct
// for the exponent-bias difference between the two formats.
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
|
||||
|
||||
// FP8 (E4M3) -> BF16, bitwise step only: keeps each lane's sign bit and
// shifts the exponent+mantissa bits into the bf16 field positions; the
// power-of-two correction is applied by the skip_flop=false variant below.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;

  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

// FP8 (E4M3) -> BF16, full dequant: multiplies by 2^BIAS_OFFSET. The factor
// is built by writing the biased exponent directly into an fp32 bit pattern,
// since 2^120 is not representable via an integer shift.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
|
||||
|
||||
// FP4 (E2M1) -> FP16, bitwise step only: keeps each lane's sign bit and
// shifts the 3 exponent+mantissa bits into the fp16 field positions; the
// power-of-two correction is applied by the skip_flop=false variant below.
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
                                                                half2* frag_b) {
  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70007000;

  // Extract and shift FP4 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

// FP4 (E2M1) -> FP16, full dequant: multiplies by 2^BIAS_OFFSET to correct
// for the exponent-bias difference between the two formats.
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
|
||||
|
||||
// FP4 (E2M1) -> BF16, bitwise step only: keeps each lane's sign bit and
// shifts the exponent+mantissa bits into the bf16 field positions; the
// power-of-two correction is applied by the skip_flop=false variant below.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70007000;

  // Extract and shift FP4 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

// FP4 (E2M1) -> BF16, full dequant: multiplies by 2^BIAS_OFFSET, built by
// writing the biased exponent directly into an fp32 bit pattern.
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
|
||||
|
||||
// FP4 (E2M1) -> FP8 (E4M3), bitwise step only. Operates per byte (four fp4
// values per output word): keeps each sign bit and shifts the
// exponent+mantissa bits into the fp8 field positions.
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
  constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70707070;

  // Extract and shift FP4 values to FP16 format
  int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);

  // Note1: reverse indexing is intentional because weights are permuted
  // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
  // `frag_b[1]` to fit the layout of tensorcore
  // NOTE(review): Note2 does not match the code below, which writes
  // frag_b[1] and frag_b[0] — confirm whether the frag_b[2] remark applies
  // to a caller-side view of the fragment or is stale.
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
|
||||
|
||||
// kU4B8 -> int32 holding four signed 8-bit lanes, bitwise step only.
// Per byte, for a 4-bit value x in [0, 15]: setting bit 7 before the
// subtraction ((x | 0x80) - 8) prevents borrows from crossing byte
// boundaries, and the final ^ 0x80 restores the correct sign bit, yielding
// x - 8 as a signed byte.
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
    int q, int32_t* frag_b) {
  constexpr int repeated_zp = 0x08080808;
  constexpr int MASK = 0x80808080;

  frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
  q >>= 4;
  frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
|
||||
|
||||
// kU4B8 -> FP8 (E4M3), bitwise step only. Per nibble: `s` isolates bit 3
// (set when the biased value is >= 8); the output keeps the low 3 bits,
// moves `s` up to the byte's sign position (s << 4), and adds (s >> 3),
// i.e. +1 to the low bits when bit 3 was set — producing the fp8 bytes the
// downstream tensor-core path expects for values with the -8 zero point.
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  int s = q & 0x08080808;
  int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
  q >>= 4;
  s = q & 0x08080808;
  int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
|
||||
|
||||
// Dequantizes four packed FP8 scale values in `q` into two scalar_t2
// fragments. Per the file header, the multiplier correction is expected to
// be folded into the stored scales at weight-loading time, so only a bit
// rearrangement is done here.
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);

// FP8 (E4M3) scales -> FP16: shift each byte (sitting in the high byte of
// its 16-bit lane) down by 1 so sign/exponent/mantissa land in the fp16
// field positions.
// Fix: removed a stray empty statement after `Out1` and a stray `;` after
// the closing brace left over from an earlier edit.
template <>
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
    int q, half2* frag_b) {
  int Out1 = (q & 0xFF00FF00) >> 1;
  q <<= 8;
  int Out2 = (q & 0xFF00FF00) >> 1;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
|
||||
|
||||
// FP8 (E4M3) scales -> nv_bfloat162 fragments.
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
    int q, nv_bfloat162* frag_b) {
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  // Exponent+mantissa bits of the FP8 byte in the upper half of each lane;
  // the sign bits (0x80008000) are repositioned separately.
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to BF16 format
  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
|
||||
|
||||
// FP8 (E8M0, exponent-only) scales -> nv_bfloat162 fragments.
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
    int q, nv_bfloat162* frag_b) {
  // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
  // but we assume that such an extreme value would not occur in real models.
  int Out1 = (q & 0xFF00FF00) >> 1;
  q <<= 7;
  int Out2 = q & 0x7F807F80;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
|
||||
|
||||
// Subtract a runtime zero point while the values are still in quantized
// form, then dequantize (scales are applied elsewhere when skip_flop).
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);

// INT4 with zero point -> int8 lanes.
// see https://github.com/vllm-project/vllm/pull/24722
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
    int q, int32_t* frag_b, int zp) {
  // Broadcast zp into every byte lane, then do a borrow-safe per-byte
  // subtraction: set each lane's high bit first and flip it back after.
  const int repeated_zp = 0x01010101 * zp;
  constexpr int kHighBits = 0x80808080;

  auto sub_zp = [&](int nibbles) {
    return ((nibbles | kHighBits) - repeated_zp) ^ kHighBits;
  };

  frag_b[0] = sub_zp(q & 0x0F0F0F0F);
  q >>= 4;
  frag_b[1] = sub_zp(q & 0x0F0F0F0F);
}
|
||||
|
||||
// INT4 with zero point -> FP8 (E4M3) lanes, without applying scales.
// see https://github.com/vllm-project/vllm/pull/24722
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
                                          true>(int q, __nv_fp8x4_e4m3* frag_b,
                                                int zp) {
  // Work in unsigned space so the per-lane carry arithmetic is well-defined.
  uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
  uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
  uint32_t u_zp1 = u_zp + 1;
  uint32_t repeated_zp = 0x01010101 * u_zp;

  uint32_t q0, s;
  // Bias each nibble with 0x70; after adding the broadcast zero point, bit 7
  // of each lane (`s`) flags whether that lane's nibble wrapped relative to
  // zp, which selects the sign and the (s >> 7) * (zp + 1) correction below.
  // NOTE(review): matches the upstream PR-24722 trick — confirm before
  // changing any constant here.
  q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
  s = (q0 + repeated_zp) & 0x80808080;
  uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;

  u_q >>= 4;
  q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
  s = (q0 + repeated_zp) & 0x80808080;
  uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
296
csrc/quantization/gptq_marlin/generate_kernels.py
Normal file
296
csrc/quantization/gptq_marlin/generate_kernels.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import jinja2
|
||||
|
||||
# Parse the target compute capabilities (argv[1], comma-separated, e.g.
# "8.0,8.9+PTX") into integer arch numbers and detect FP8-MMA support.
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
    # "8.9" / "8.9+PTX" / "12.0a" -> 89 / 120. Skip malformed entries
    # (no dot) instead of raising ValueError.
    dot = arch.find(".")
    if dot == -1:
        continue
    arch = int(arch[: dot + 2].replace(".", ""))
    ARCHS.append(arch)
    # only SM89 and SM120 fully support
    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
    # SM90 and SM100 can use this PTX, but it's simulated
    # with FP16 MMA, so it cannot achieve any acceleration.
    if arch in [89, 120]:
        SUPPORT_FP8 = True
|
||||
|
||||
# Header shared by every generated file (also used as the sentinel for an
# empty kernel_selector.h, see generate_new_kernels()).
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()

FILE_HEAD = (
    FILE_HEAD_COMMENT
    + """
#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
"""
)

# Jinja template for one explicit Marlin kernel instantiation.
TEMPLATE = (
    "template __global__ void Marlin<"
    "{{a_type_id}}, "
    "{{b_type_id}}, "
    "{{c_type_id}}, "
    "{{s_type_id}}, "
    "{{threads}}, "
    "{{thread_m_blocks}}, "
    "{{thread_n_blocks}}, "
    "{{thread_k_blocks}}, "
    "{{m_block_size_8}}, "
    "{{stages}}, "
    "{{group_blocks}}, "
    "{{is_zp_float}}>"
    "( MARLIN_KERNEL_PARAMS );"
)

# Candidate (thread_k, thread_n, num_threads) tile shapes.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]

# 0.5 selects the half-height (m_block_size_8) tile variant.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]

# One entry per supported quantization scheme; omitted keys fall back to the
# defaults applied in generate_new_kernels().
QUANT_CONFIGS = [
    # AWQ-INT4
    {
        "b_type": "kU4",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [-1, 2, 4, 8],
    },
    # HQQ
    {
        "a_type": ["kFloat16"],
        "b_type": "kU4",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [4],
        "is_zp_float": True,
    },
    # GPTQ-INT4
    {
        "b_type": "kU4B8",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [-1, 0, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": "kU8B128",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [-1, 0, 2, 4, 8],
    },
    # FP8
    {
        "b_type": "kFE4M3fn",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [-1, 8],
    },
    # NVFP4
    {
        "b_type": "kFE2M1f",
        "s_type": "kFE4M3fn",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [1],
    },
    # MXFP4
    {
        "a_type": ["kBFloat16"],
        "b_type": "kFE2M1f",
        "s_type": "kFE8M0fnu",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": THREAD_M_BLOCKS,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": ["kS8"],
        "b_type": "kU4",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": [1, 2, 3, 4],
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": ["kS8"],
        "b_type": "kU4B8",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": [1, 2, 3, 4],
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kU4B8",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": [1, 2, 3, 4],
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kU4",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": [1, 2, 3, 4],
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": ["kFE4M3fn"],
        "b_type": "kFE2M1f",
        "c_type": ["kBFloat16"],
        "s_type": "kFE8M0fnu",
        "thread_configs": THREAD_CONFIGS,
        "thread_m_blocks": [1, 2, 3, 4],
        "group_blocks": [2],
    },
]
|
||||
|
||||
|
||||
def remove_old_kernels():
    """Delete previously generated kernel sources and the selector header.

    Uses ``os.remove`` instead of shelling out to ``rm`` so the script also
    works on platforms without a POSIX shell and avoids spawning one
    subprocess per file.
    """
    here = os.path.dirname(__file__)
    for filename in glob.glob(os.path.join(here, "*kernel_*.cu")):
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass  # already gone; keep `rm -f` semantics

    try:
        os.remove(os.path.join(here, "kernel_selector.h"))
    except FileNotFoundError:
        pass
|
||||
|
||||
|
||||
def generate_new_kernels():
    """Emit one .cu file per (a_type, b_type, c_type) triple plus the
    kernel_selector.h dispatch chain covering every instantiation."""
    # (a_type, b_type, c_type) -> list of config dicts to instantiate.
    result_dict = {}

    for quant_config in QUANT_CONFIGS:
        # Defaults: instantiate for both 16-bit activation/output dtypes.
        c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
        a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
        b_type = quant_config["b_type"]
        is_zp_float = quant_config.get("is_zp_float", False)
        all_group_blocks = quant_config["group_blocks"]
        all_m_blocks = quant_config["thread_m_blocks"]
        all_thread_configs = quant_config["thread_configs"]

        for a_type, c_type in itertools.product(a_types, c_types):
            # FP8-activation kernels are only built on supporting archs.
            if not SUPPORT_FP8 and a_type == "kFE4M3fn":
                continue
            # A 16-bit activation dtype must match a 16-bit output dtype.
            if "16" in a_type and "16" in c_type and a_type != c_type:
                continue
            s_type = quant_config.get("s_type", c_type)
            if (a_type, b_type, c_type) not in result_dict:
                result_dict[(a_type, b_type, c_type)] = []

            for group_blocks, m_blocks, thread_configs in itertools.product(
                all_group_blocks, all_m_blocks, all_thread_configs
            ):
                thread_k, thread_n, threads = thread_configs

                if threads == 256:
                    # for small batch (m_blocks == 1),
                    # we only need (128, 128, 256)
                    # for large batch (m_blocks > 1),
                    # we only need (64, 256, 256)
                    if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
                        continue
                    if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
                        continue

                config = {
                    "threads": threads,
                    "s_type": s_type,
                    # m_blocks == 0.5 maps to one 16-row block with the
                    # half-height (m_block_size_8) variant enabled.
                    "thread_m_blocks": max(m_blocks, 1),
                    "thread_k_blocks": thread_k // 16,
                    "thread_n_blocks": thread_n // 16,
                    "m_block_size_8": "true" if m_blocks == 0.5 else "false",
                    "stages": "pipe_stages",
                    "group_blocks": group_blocks,
                    "is_zp_float": "true" if is_zp_float else "false",
                }

                result_dict[(a_type, b_type, c_type)].append(config)

    kernel_selector_str = FILE_HEAD_COMMENT

    for (a_type, b_type, c_type), config_list in result_dict.items():
        all_template_str_list = []
        for config in config_list:
            s_type = config["s_type"]
            # Explicit instantiation emitted into the generated .cu file.
            template_str = jinja2.Template(TEMPLATE).render(
                a_type_id=f"vllm::{a_type}.id()",
                b_type_id=f"vllm::{b_type}.id()",
                c_type_id=f"vllm::{c_type}.id()",
                s_type_id=f"vllm::{s_type}.id()",
                **config,
            )
            all_template_str_list.append(template_str)

            # Matching `if/else if` branch of the generated selector chain.
            conditions = [
                f"a_type == vllm::{a_type}",
                f"b_type == vllm::{b_type}",
                f"c_type == vllm::{c_type}",
                f"s_type == vllm::{s_type}",
                f"threads == {config['threads']}",
                f"thread_m_blocks == {config['thread_m_blocks']}",
                f"thread_n_blocks == {config['thread_n_blocks']}",
                f"thread_k_blocks == {config['thread_k_blocks']}",
                f"m_block_size_8 == {config['m_block_size_8']}",
                f"group_blocks == {config['group_blocks']}",
                f"is_zp_float == {config['is_zp_float']}",
            ]
            conditions = " && ".join(conditions)

            if kernel_selector_str == FILE_HEAD_COMMENT:
                kernel_selector_str += f"if ({conditions})\n  kernel = "
            else:
                kernel_selector_str += f"else if ({conditions})\n  kernel = "

            kernel_template2 = (
                "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
                "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
                "{{thread_n_blocks}}, {{thread_k_blocks}}, "
                "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
                "{{is_zp_float}}>;"
            )

            kernel_selector_str += (
                jinja2.Template(kernel_template2).render(
                    a_type_id=f"vllm::{a_type}.id()",
                    b_type_id=f"vllm::{b_type}.id()",
                    c_type_id=f"vllm::{c_type}.id()",
                    s_type_id=f"vllm::{s_type}.id()",
                    **config,
                )
                + "\n"
            )

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # FP8-activation kernels require SM89+; everything else targets SM80.
        if a_type == "kFE4M3fn":
            filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
        else:
            filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"

        filename = filename.lower()

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)

    # When FP8 kernels are not built, fail loudly at dispatch time instead of
    # silently falling through to the default kernel.
    if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
        kernel_selector_str += (
            "else if (a_type == vllm::kFE4M3fn)\n"
            "  TORCH_CHECK(false, "
            '"marlin kernel with fp8 activation is not built.");'
        )

    with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
        f.write(kernel_selector_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: generate_kernels.py <comma-separated CUDA arch list>
    remove_old_kernels()
    generate_new_kernels()
|
||||
850
csrc/quantization/gptq_marlin/gptq_marlin.cu
Normal file
850
csrc/quantization/gptq_marlin/gptq_marlin.cu
Normal file
@@ -0,0 +1,850 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace marlin {
|
||||
|
||||
// No-op sentinel kernel: get_marlin_kernel() returns this when no compiled
// specialization matches, and callers compare against it to detect failure.
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};

// Pointer type of a compiled Marlin kernel instantiation.
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
// Stub bodies for device-code passes targeting < SM80.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int lda, int block_rows) {}

}  // namespace marlin

// Keep this stub's parameter list in sync with the real implementation in
// the #else branch (it was missing a_scales_or_none / global_scale_or_none).
torch::Tensor gptq_marlin_gemm(
    torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
    torch::Tensor& b_q_weight,
    std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
    std::optional<torch::Tensor> const& a_scales_or_none,
    std::optional<torch::Tensor> const& global_scale_or_none,
    std::optional<torch::Tensor> const& b_zeros_or_none,
    std::optional<torch::Tensor> const& g_idx_or_none,
    std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
    vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
    bool is_zp_float) {
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}
|
||||
|
||||
#else
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Launch shape: each block handles `block_rows` consecutive rows; the
// block's `default_threads` threads stride across the K dimension.
// Pointers are int4 (16-byte) aligned; elements are half-typed.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int lda, int block_rows) {
  auto start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = finish_row - start_row;

  // Row strides in int4 units (input uses lda, output is densely packed).
  int input_row_stride = lda * sizeof(half) / 16;
  int output_row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int input_offset = row * input_row_stride;
    int output_offset = row * output_row_stride;

    half const* a_row_half =
        reinterpret_cast<half const*>(a_int4_ptr + input_offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + output_offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      auto cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    // Tail: K is not necessarily divisible by the thread count.
    if (rest) {
      if (threadIdx.x < rest) {
        auto cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
|
||||
|
||||
// A thread-block tile shape: K-span, N-span, and thread count.
typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

// Candidate tile shapes tried in order by determine_exec_config().
thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128}};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},
    {64, 128, 128},
    {128, 64, 128}};

// Chosen launch configuration: concurrent blocks per SM plus the tile shape.
typedef struct {
  int blocks_per_sm;
  thread_config_t tb_cfg;
} exec_config_t;
|
||||
|
||||
// Shared-memory bytes needed to stage quantization scales for one
// thread-block (scales are 16-bit, hence the trailing factor of 2).
// prob_m / prob_n / prob_k / num_bits are unused but kept for signature
// parity with the other cache-size helpers.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  int tile_n = th_config.thread_n;
  int tile_k = th_config.thread_k;

  // Scale groups touched by one K-tile. group_size == -1 means a single
  // per-channel group; group_size == 0 assumes the worst case of 32.
  int groups_per_tile;
  if (group_size == -1) {
    groups_per_tile = 1;
  } else {
    groups_per_tile = div_ceil(tile_k, group_size == 0 ? 32 : group_size);
  }

  if (has_act_order && !is_k_full) {
    // Act-order without full K caches a chunk of 2x the pipeline depth over
    // K, and never fewer than 32 groups.
    int load_groups = max(groups_per_tile * pipe_stages * 2, 32);
    return load_groups * tile_n * 2;
  }

  return groups_per_tile * tile_n * 2 * pipe_stages;
}
|
||||
|
||||
// Total shared-memory bytes one thread-block of the Marlin kernel needs for
// the given tile shape and quantization settings.
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
                          int prob_m, int prob_n, int prob_k, int num_bits,
                          int group_size, bool has_act_order, bool is_k_full,
                          int has_zp, int is_zp_float) {
  int pack_factor = 32 / num_bits;  // quantized values per 32-bit word

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;
  int tb_m = thread_m_blocks * 16;
  // A (16-bit) and B (packed 32-bit) staging, buffered over pipe_stages.
  int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
  // Reduction scratch; the +8 columns are padding (presumably to avoid bank
  // conflicts — confirm against the kernel before changing).
  int sh_red_size = tb_m * (tb_n + 8) * 2;
  int sh_bias_size = tb_n * 2;
  // B staging and reduction scratch overlap in the same region; the bias row
  // rides alongside the smaller of the two.
  int tmp_size =
      (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
  tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);

  int sh_s_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
  // Zero points: float zp mirrors the scale cache; packed int zp is 4 or 8
  // bits per value versus 16-bit scales.
  int sh_zp_size = 0;
  if (has_zp) {
    if (is_zp_float)
      sh_zp_size = sh_s_size;
    else if (num_bits == 4)
      sh_zp_size = sh_s_size / 4;
    else if (num_bits == 8)
      sh_zp_size = sh_s_size / 2;
  }

  int total_size =
      tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;

  return total_size;
}
|
||||
|
||||
// Returns true when `th_config` is launchable for the given problem: set,
// evenly dividing N/K, above the minimum tile sizes, with at least 4 warps,
// and with a shared-memory footprint that fits in `max_shared_mem`.
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int has_zp, int is_zp_float, int max_shared_mem) {
  // Reject the "unset" sentinel config {-1, -1, -1}.
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1)
    return false;

  // The tile must evenly divide the K and N problem dimensions.
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0)
    return false;

  // Respect the minimum tile extents.
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k)
    return false;

  // num_threads must be at least 128 (= 4 warps).
  if (th_config.num_threads < 128) return false;

  // Finally, the pipelined shared-memory footprint must fit.
  int cache_size = get_kernel_cache_size(
      th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
      has_act_order, is_k_full, has_zp, is_zp_float);
  return cache_size <= max_shared_mem;
}
|
||||
|
||||
// Look up the compiled kernel matching the runtime configuration. The
// generated kernel_selector.h is an if/else-if chain assigning to `kernel`;
// MarlinDefault is returned when nothing matches.
MarlinFuncPtr get_marlin_kernel(
    const vllm::ScalarType a_type, const vllm::ScalarType b_type,
    const vllm::ScalarType c_type, const vllm::ScalarType s_type,
    int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
    bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
    int threads, bool is_zp_float) {
  // NOTE(review): num_bits appears unused by the generated selector
  // conditions — confirm against kernel_selector.h before removing.
  int num_bits = b_type.size_bits();
  auto kernel = MarlinDefault;

#include "kernel_selector.h"

  return kernel;
}
|
||||
|
||||
// Pick the first candidate tile shape (priority-ordered by batch size) that
// is launchable AND has a compiled kernel. Returns the {-1,-1,-1} sentinel
// config when none qualifies.
exec_config_t determine_exec_config(
    const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
    const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
    int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
    int num_bits, int group_size, bool has_act_order, bool is_k_full,
    bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
  exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
  thread_config_t* thread_configs = thread_m_blocks > 1
                                        ? large_batch_thread_configs
                                        : small_batch_thread_configs;
  int thread_configs_size =
      thread_m_blocks > 1
          ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
          : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);

  for (int i = 0; i < thread_configs_size; i++) {
    thread_config_t th_config = thread_configs[i];

    // Keep 512 bytes of headroom below the hardware shared-memory limit.
    if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
                         num_bits, group_size, has_act_order, is_k_full, has_zp,
                         is_zp_float, max_shared_mem - 512)) {
      continue;
    }

    int group_blocks = 0;
    if (!has_act_order) {
      group_blocks = group_size == -1 ? -1 : group_size / 16;
    }

    // A valid config is only usable if a matching kernel was compiled in.
    auto kernel =
        get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
                          th_config.thread_n / 16, th_config.thread_k / 16,
                          m_block_size_8, has_act_order, has_zp, group_blocks,
                          th_config.num_threads, is_zp_float);

    if (kernel == MarlinDefault) continue;

    // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
    // int n_tiles = prob_n / th_config.thread_n;
    // int k_tiles = prob_k / th_config.thread_k;

    return {1, th_config};
  }

  return exec_cfg;
}
|
||||
|
||||
// Host-side dispatcher for the Marlin GEMM: validates the problem, resolves
// the group/zero-point configuration, optionally permutes A for act-order
// weights, then processes the M dimension in chunks, picking a tile shape
// and compiled kernel per chunk and launching it on `stream`.
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
               void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
               void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
               int lda, void* workspace, vllm::ScalarType const& a_type,
               vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
               vllm::ScalarType const& s_type, bool has_bias,
               bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
               int group_size, int dev, cudaStream_t stream, int thread_k_init,
               int thread_n_init, int sms, bool use_atomic_add,
               bool use_fp32_reduce, bool is_zp_float) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  // Translate group_size (elements along K) into group_blocks (16-row
  // blocks); -1 means a single per-channel group.
  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int num_bits = b_type.size_bits();
  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* C_tmp_ptr = (int4*)C_tmp;

  const int4* bias_ptr = (const int4*)b_bias;
  const float* a_s_ptr = (const float*)a_s;   // per-row activation scales
  const int4* b_s_ptr = (const int4*)b_s;     // weight scales
  const uint16_t* g_s_ptr = (const uint16_t*)g_s;  // global scale

  const int4* zp_ptr = (const int4*)zp;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;
  int* locks = (int*)workspace;  // inter-block reduction locks

  if (has_act_order) {
    // Permute A columns
    int block_rows = div_ceil(prob_m, sms);
    // avoid ">>>" being formatted to "> > >"
    // clang-format off
    permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);
    // clang-format on
    A_ptr = a_tmp_ptr;
    lda = prob_k;  // a_tmp is densely packed

    // If we have a full K, then we can run the non-act-order version of Marlin
    // (since the weight rows are reordered by increasing group ids, and by
    // having a full K, we have full original groups)
    if (is_k_full) has_act_order = false;
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  int major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         dev);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         dev);
  TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
              "marlin kernel only support Ampere or newer GPUs.");
  if (a_type == vllm::kFE4M3fn) {
    TORCH_CHECK(
        major_capability * 10 + minor_capability == 89 ||
            major_capability * 10 + minor_capability == 120,
        "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
        "Marlin W4A16 on other devices).");
  }

  // Cap on how many full (max_thread_m_blocks * 16)-row tiles one launch may
  // cover in parallel.
  int max_par = 16;
  if (prob_n <= 4096) max_par = 16 * 8;
  int max_shared_mem_new = max_shared_mem;
  int rest_m = prob_m;
  int max_thread_m_blocks = 4;
  // Process M in chunks; when no kernel exists at the current tile height,
  // retry the chunk with a shorter M tile (the `continue` below).
  while (rest_m) {
    int par_count = rest_m / (max_thread_m_blocks * 16);
    if (par_count > max_par) par_count = max_par;
    int prob_m_split =
        par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m;

    int thread_k = thread_k_init;
    int thread_n = thread_n_init;

    int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
    int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;

    // Set thread config
    exec_config_t exec_cfg;
    thread_config_t thread_tfg;
    if (thread_k != -1 && thread_n != -1) {
      // Caller pinned the tile shape.
      thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
      exec_cfg = exec_config_t{1, thread_tfg};
      TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                  " is not divisible by thread_n = ", thread_n);
      TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                  " is not divisible by thread_k = ", thread_k);
    } else {
      // Auto config
      exec_cfg = determine_exec_config(
          a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
          thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
          is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
      thread_tfg = exec_cfg.tb_cfg;
      if (thread_tfg.thread_n != -1) {
        // If the wave would underfill the GPU, prefer the smaller
        // {128, 64, 128} tile to launch more blocks.
        if (prob_n / thread_tfg.thread_n *
                div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
            sms) {
          if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
                              prob_n, prob_k, num_bits, group_size,
                              has_act_order, is_k_full, has_zp, is_zp_float,
                              max_shared_mem_new)) {
            thread_tfg = {128, 64, 128};
            exec_cfg = {1, thread_tfg};
          }
        }
      }

      if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
        // No usable kernel at this tile height; redo the split with a
        // shorter M tile.
        max_thread_m_blocks--;
        continue;
      }
    }

    int num_threads = thread_tfg.num_threads;
    thread_k = thread_tfg.thread_k;
    thread_n = thread_tfg.thread_n;
    int blocks = sms * exec_cfg.blocks_per_sm;
    if (exec_cfg.blocks_per_sm > 1)
      max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024;

    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;

    TORCH_CHECK(
        is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
                        prob_k, num_bits, group_size, has_act_order, is_k_full,
                        has_zp, is_zp_float, max_shared_mem_new),
        "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
        ", thread_k = ", thread_tfg.thread_k,
        ", thread_n = ", thread_tfg.thread_n,
        ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m,
        ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
        ", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
        ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
        ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
        ", max_shared_mem_new = ", max_shared_mem_new);

    auto kernel = get_marlin_kernel(
        a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
        thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
        num_threads, is_zp_float);

    if (kernel == MarlinDefault) {
      TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                  ", ", prob_k, "]", ", has_act_order = ", has_act_order,
                  ", num_groups = ", num_groups, ", group_size = ", group_size,
                  ", prob_m_split = ", prob_m_split,
                  ", thread_m_blocks = ", thread_m_blocks,
                  ", thread_n_blocks = ", thread_n_blocks,
                  ", thread_k_blocks = ", thread_k_blocks,
                  ", num_threads = ", num_threads, ", num_bits = ", num_bits);
    }

    // Opt in to the full shared-memory carveout (> 48 KB dynamic smem).
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         max_shared_mem_new);

    // Atomic-add reduction only pays off for small outputs.
    bool part_use_atomic_add =
        use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048;

    // avoid ">>>" being formatted to "> > >"
    // clang-format off
    kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
        A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
        g_idx_ptr, num_groups,
        prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
        use_fp32_reduce, max_shared_mem_new);
    // clang-format on

    // Advance A (int4 units: 16 bytes = 16 or 8 elements depending on the
    // activation width), the per-row scales, and C past the processed rows.
    bool is_a_8bit = a_type.size_bits() == 8;
    A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
    a_s_ptr += prob_m_split;
    C_ptr += prob_m_split * (prob_n / 8);
    rest_m -= prob_m_split;
  }
}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
// Host entry point for the Marlin quantized GEMM: C = A @ dequant(B) (+ bias).
//
// Validates shapes/dtypes/devices, allocates output and scratch buffers,
// resolves the runtime ScalarType combination (activation / weight / output /
// scale), then dispatches to marlin::marlin_mm.
//
// Interface notes (grounded in the checks below):
//  - a: [size_m, size_k]; b_q_weight: [size_k / tile_size,
//    size_n * tile_size / pack_factor] (packed); b_scales: [num_groups, size_n].
//  - c_or_none: optional pre-allocated output [size_m, size_n]; REQUIRED when
//    both a and b_scales are sub-16-bit (the W4A8-FP4 path), since the output
//    dtype cannot be inferred otherwise.
//  - a_scales is required iff the activation is 8-bit; global_scale is
//    required iff b is FE2M1f with FE4M3fn scales (NVFP4).
//  - Returns c (the caller-provided tensor, or a freshly allocated one).
torch::Tensor gptq_marlin_gemm(
    torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
    torch::Tensor& b_q_weight,
    std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
    std::optional<torch::Tensor> const& a_scales_or_none,
    std::optional<torch::Tensor> const& global_scale_or_none,
    std::optional<torch::Tensor> const& b_zeros_or_none,
    std::optional<torch::Tensor> const& g_idx_or_none,
    std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
    vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
    bool is_zp_float) {
  vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;

  // Resolve activation (a) and output (c) dtypes.
  // 16-bit activations: c matches a.
  auto c_dtype = a.dtype();
  if (a.scalar_type() == at::ScalarType::Half) {
    a_type_id = vllm::kFloat16.id();
    c_type_id = vllm::kFloat16.id();
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    a_type_id = vllm::kBFloat16.id();
    c_type_id = vllm::kBFloat16.id();
  } else {
    // 8-bit activations: output dtype follows b_scales; if b_scales is also
    // sub-16-bit (W4A8-FP4), the caller must pass c so we can read its dtype.
    c_dtype = b_scales.dtype();
    if (b_scales.scalar_type() == at::ScalarType::Half) {
      c_type_id = vllm::kFloat16.id();
    } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
      c_type_id = vllm::kBFloat16.id();
    } else {
      c_type_id = vllm::kBFloat16.id();

      TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
      torch::Tensor c = c_or_none.value();
      c_dtype = c.dtype();

      if (c.scalar_type() == at::ScalarType::Half) {
        c_type_id = vllm::kFloat16.id();
      } else if (c.scalar_type() == at::ScalarType::BFloat16) {
        c_type_id = vllm::kBFloat16.id();
      } else {
        TORCH_CHECK(false, "unsupported c dtype");
      }
    }

    if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
      a_type_id = vllm::kFE4M3fn.id();
    } else if (a.scalar_type() == at::ScalarType::Char) {
      a_type_id = vllm::kS8.id();
    } else {
      TORCH_CHECK(false, "unsupported `a` scalar_type");
    }
  }

  // Scale dtype defaults to the output dtype, except for FP4 weights where
  // the scale itself is an FP8 variant (e4m3fn for NVFP4, e8m0fnu for MXFP4).
  s_type_id = c_type_id;
  if (b_type_id == vllm::kFE2M1f.id()) {
    if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
      s_type_id = vllm::kFE4M3fn.id();
    } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
      s_type_id = vllm::kFE8M0fnu.id();
    } else {
      TORCH_CHECK(false,
                  "When b_type = float4_e2m1f, b_scale scalar type must be",
                  "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
    }
  }

  vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
  vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
  vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
  vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);

  // Number of quantized weight elements packed per 32-bit word.
  int pack_factor = 32 / b_type.size_bits();

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(
      size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
      " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k,
              ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  TORCH_CHECK(
      b_q_weight.size(1) % MARLIN_NAMESPACE_NAME::tile_size == 0,
      "b_q_weight.size(1) = ", b_q_weight.size(1),
      " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  int actual_size_n =
      (b_q_weight.size(1) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.stride(1) == 1, "A.stride(1) is not 1");
  // We use int4 (16 bytes) to load A, so A must be aligned to 16 bytes.
  // NOTE(review): the %8 check assumes a 16-bit element type (8 * 2B = 16B);
  // for 8-bit activations this is a stricter-than-needed row-stride check.
  TORCH_CHECK(a.stride(0) % 8 == 0, "A.stride(0) must divisible by 8");
  TORCH_CHECK(((uint64_t)a.data_ptr()) % 16 == 0, "A must aligned to 16 bytes");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  torch::Tensor a_scales;
  auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());

  // Per-row activation scales: mandatory for 8-bit A, forbidden otherwise.
  if (a_scales_or_none.has_value()) {
    a_scales = a_scales_or_none.value();
    TORCH_CHECK(a_type.size_bits() == 8,
                "a_scales can only be used for 8bit activation.");
  } else {
    a_scales = torch::empty({0}, options_fp32);
    TORCH_CHECK(a_type.size_bits() != 8,
                "the a_scales parameter must be passed for 8bit activation.");
  }

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel
  int sms = -1;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  torch::Tensor c;
  if (c_or_none.has_value()) {
    c = c_or_none.value();
    TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
    TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
    TORCH_CHECK(c.size(0) == size_m, "Shape mismatch: c.size(0) = ", c.size(0),
                ", size_m = ", size_m);
    TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
                ", size_n = ", size_n);
  } else {
    c = torch::empty({size_m, size_n}, options);
  }
  // Empty batch: nothing to compute.
  if (size_m == 0) return c;

  // Alloc C tmp buffer that is going to be used for the global reduce
  // (fp32 partials, sized for the worst-case m-block per SM).
  torch::Tensor c_tmp;
  if (use_fp32_reduce) {
    int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
    max_m_block_size = min(max_m_block_size, 64);
    int max_c_tmp_size =
        sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n;
    c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
  } else {
    c_tmp = torch::empty({0}, options_fp32);
  }

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;

  int rank = b_scales.sizes().size();
  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(0);

  // g_idx/perm (activation reordering) must be passed together; empty tensors
  // mean "no act_order".
  torch::Tensor g_idx, perm, a_tmp;
  if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
    g_idx = g_idx_or_none.value();
    perm = perm_or_none.value();

    TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
    TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
    TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
    TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

    // Verify g_idx and perm
    TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
                    (g_idx.size(-1) == size_k && perm.size(-1) == size_k),
                "Unexpected g_idx.size(-1) = ", g_idx.size(-1),
                " and perm.size(-1) = ", perm.size(-1),
                ", where size_k = ", size_k);
  } else {
    g_idx = torch::empty({0}, options);
    perm = torch::empty({0}, options);
    a_tmp = torch::empty({0}, options);
  }
  bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;

  if (has_act_order) {
    // Scratch for the permuted copy of A.
    a_tmp = torch::empty({size_m, size_k}, options);
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      // K is sharded (e.g. TP): group_size is resolved inside the kernel.
      group_size = 0;
    }

  } else {
    a_tmp = torch::empty({0}, options);
    if (num_groups > 1) {
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;  // channelwise (single scale group)
    }
  }

  // global_scale: mandatory for NVFP4 (FE2M1f weights + FE4M3fn scales),
  // forbidden otherwise.
  torch::Tensor global_scale;
  if (global_scale_or_none.has_value()) {
    global_scale = global_scale_or_none.value();
    TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
                "global_scale can only be used for nvfp4 format.");
  } else {
    global_scale = torch::empty({0}, options);
    TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
                "the global_scale parameter must be passed for nvfp4 format.");
  }

  bool has_bias = b_bias_or_none.has_value();
  torch::Tensor b_bias;
  if (has_bias) {
    b_bias = b_bias_or_none.value();
    TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
    TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
    TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
    TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
  } else {
    b_bias = torch::empty({0}, options);
  }

  torch::Tensor b_zeros;
  if (b_zeros_or_none.has_value()) {
    b_zeros = b_zeros_or_none.value();
    TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
    TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
  } else {
    b_zeros = torch::empty({0}, options);
  }
  bool has_zp = b_zeros.size(-1) > 0;
  // Zero-point presence constrains which weight formats are legal:
  // asymmetric (u4/u8) requires zp, the rest must not have one.
  if (has_zp) {
    TORCH_CHECK(
        b_type == vllm::kU4 || b_type == vllm::kU8,
        "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
  } else {
    TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
                    b_type == vllm::kS4 || b_type == vllm::kS8 ||
                    b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
                "b_type must be uint4b8, uint8b128, int4, int8, "
                "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
                b_type.str());
  }

  if (has_zp && is_zp_float) {
    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
                "Computation type must be float16 (half) when using float zero "
                "points.");
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
    if (is_zp_float) {
      // Float zero points are stored unpacked: [num_groups, size_n].
      TORCH_CHECK(b_zeros.size(1) == size_n,
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not size_n = ", size_n);
      TORCH_CHECK(num_groups == b_zeros.size(0),
                  "b_zeros dim 0 = ", b_zeros.size(0),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
    } else {
      // Integer zero points are bit-packed like B: [num_groups, n/pack_factor].
      TORCH_CHECK(b_zeros.size(0) == num_groups,
                  "b_zeros dim 0 = ", b_zeros.size(0),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not size_n / pack_factor = ", size_n / pack_factor);
    }
  }

  // Verify workspace size (one lock slot per SM is the minimum).
  TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
              "size_n = ", size_n, ", is not divisible by min_thread_n = ",
              MARLIN_NAMESPACE_NAME::min_thread_n);

  int min_workspace_size = sms;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();

  TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
              "scalar type of a_scales must be float");
  TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
              "scalar type of global_scale must be the same with c");
  if (a_type.size_bits() == 16) {
    TORCH_CHECK(
        a.scalar_type() == c.scalar_type(),
        "scalar type of a must be the same with c for 16 bit activation");
  }

  // All checks passed: run the kernel (defined earlier in this file).
  marlin::marlin_mm(
      a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
      b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
      global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
      perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
      workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
      has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
      use_atomic_add, use_fp32_reduce, is_zp_float);

  return c;
}
|
||||
|
||||
#endif  // closes an #if guard opened earlier in this file (outside this chunk)

// Register the CUDA implementation of gptq_marlin_gemm with the torch library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
|
||||
357
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
Normal file
357
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
Normal file
@@ -0,0 +1,357 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace marlin {
|
||||
|
||||
// Repacks GPTQ-format quantized weights (row-packed uint32 words) into the
// Marlin tile layout expected by the GEMM kernel.
//
// Template params:
//   num_threads - threads per block (launched with marlin::repack_threads)
//   num_bits    - bits per quantized element (4 or 8)
//   has_perm    - apply an activation-order permutation (g_idx) while
//                 repacking; only supported for 16-bit activations
//                 (static_assert below forbids has_perm && is_a_8bit)
//   is_a_8bit   - target layout for 8-bit activations (halves the n-tile,
//                 doubles the k-tile)
//
// Launch expectations (grounded in the code below): 1-D grid; each block
// processes a contiguous range of k-tiles and sweeps all n-tiles per k-tile.
// Uses dynamic shared memory (`extern __shared__`) as a `repack_stages`-deep
// cp.async double-buffer, so the host must opt in via cudaFuncSetAttribute
// (done in the CALL_IF dispatch macro). Requires SM80+ for cp.async.
template <int const num_threads, int const num_bits, bool const has_perm,
          bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  // Quantized elements per packed 32-bit word.
  constexpr int pack_factor = 32 / num_bits;

  // 8-bit-activation layout trades n-tile width for k-tile depth.
  constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
  constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
  int k_tiles = size_k / target_tile_k_size;
  int n_tiles = size_n / target_tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  auto start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  // int4s needed to hold one k-tile worth of 32-bit perm indices.
  constexpr int perm_size = target_tile_k_size / 4;

  // Shared memory layout: [perm indices (if has_perm)] [pipeline stages].
  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  // Packed 32-bit words per k-tile column.
  constexpr int tile_ints = target_tile_k_size / pack_factor;

  // Each fetching thread copies one int4 (4 packed words along n).
  constexpr int stage_n_threads = target_tile_n_size / 4;
  // With a perm, rows are gathered individually (one per source k); without,
  // whole packed words are copied.
  constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  // Stage this k-tile's perm indices into shared memory (all threads sync).
  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  // Issue async global->shared copies for tile (k_tile_id, n_tile_id) into
  // pipeline slot `pipe`. Always fences so cp_async_wait counts stay aligned,
  // even for out-of-range (no-op) tiles.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * target_tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        auto k_id = threadIdx.x / stage_n_threads;
        auto n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        // Gather the packed word holding permuted source row src_k.
        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        auto k_id = threadIdx.x / stage_n_threads;
        auto n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * target_tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  // Unpack one staged tile from shared memory and write it back to global in
  // the Marlin (tensor-core fragment) order. Only the first 4 warps take part.
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    auto warp_id = threadIdx.x / 32;
    auto th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    // Lane -> mma fragment coordinates (column / starting row).
    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);

    // Row offsets of the 4 elements a lane owns within a 16x16 mma fragment.
    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;

    constexpr int sh_stride = target_tile_n_size;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    // 8 unpacked elements per lane: vals[0..3] and vals[4..7] are the two
    // column halves (or k-halves for is_a_8bit).
    uint32_t vals[8];

    if constexpr (has_perm) {
      static_assert(!is_a_8bit);
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        // Position of the permuted element inside its packed word.
        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        if constexpr (is_a_8bit) {
          b1_vals[i] =
              sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
        } else {
          b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
          b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
        }
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        if constexpr (is_a_8bit)
          vals[4 + i] =
              (b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
        else
          vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }

    // Packed 32-bit words per output tile.
    constexpr int tile_size =
        target_tile_k_size * target_tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Interleaving orders below match the dequantization layout from:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (!is_a_8bit && num_bits == 4) {
      int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else if constexpr (is_a_8bit && num_bits == 4) {
      int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      // 8-bit elements: two packed words per lane.
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        const int ii = is_a_8bit ? i : pack_idx[i];
        res1 |= vals[ii] << (i * 8);
        res2 |= vals[4 + ii] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  // Prime the pipeline: issue the first `repack_stages - 1` fetches.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    // Steady state: overlap the fetch of tile (n + stages - 1) with the
    // repack of tile n, cycling through the pipeline slots.
    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
// Dispatch macro: continues an `if (false) {} CALL_IF(...) ... else { ... }`
// chain, selecting the kernel template instantiation that matches the runtime
// (num_bits, has_perm, is_a_8bit) triple. Before launching it raises the
// dynamic shared-memory limit via cudaFuncSetAttribute, since the kernel's
// cp.async pipeline may need more than the default 48 KB.
// (No // comments are possible inside the continuation lines — they would
// swallow the trailing backslashes.)
#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT)                           \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM &&               \
           is_a_8bit == IS_A_8BIT) {                                     \
    cudaFuncSetAttribute(                                                \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
                                          HAS_PERM, IS_A_8BIT>,          \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);    \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,  \
                                      HAS_PERM, IS_A_8BIT>               \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(    \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);          \
  }
|
||||
|
||||
// Host wrapper: repack a GPTQ-packed weight tensor into the Marlin layout.
//
// Args:
//   b_q_weight - int32 tensor [size_k / pack_factor, size_n], GPTQ packing
//   perm       - int32 activation-order permutation of length size_k, or an
//                empty tensor for no act_order
//   num_bits   - 4 or 8 (pack_factor = 32 / num_bits)
//   is_a_8bit  - produce the layout used with 8-bit activations (only
//                supported without a perm — enforced by the dispatch chain)
// Returns: int32 tensor [size_k / tile_size, size_n * tile_size / pack_factor]
//          on the same device as b_q_weight.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits, bool is_a_8bit) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  // One block per SM; each block strides over its share of k-tiles.
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Dispatch chain: the empty `if (false)` head lets every CALL_IF expand to
  // an `else if`. Unsupported combos (e.g. has_perm with is_a_8bit) fall
  // through to the final else and raise.
  if (false) {
  }
  CALL_IF(4, false, false)
  CALL_IF(4, true, false)
  CALL_IF(8, false, false)
  CALL_IF(8, true, false)

  CALL_IF(4, false, true)
  CALL_IF(8, false, true)

  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
  }

  return out;
}
|
||||
|
||||
// Register the CUDA implementation of gptq_marlin_repack with the torch
// library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
|
||||
43
csrc/quantization/gptq_marlin/kernel.h
Normal file
43
csrc/quantization/gptq_marlin/kernel.h
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
// Allow the enclosing build to override the namespace (defaults to `marlin`).
#ifndef MARLIN_NAMESPACE_NAME
  #define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

// Shared parameter list for every Marlin kernel instantiation, so the
// declaration here and the definitions in the generated sm*_kernel_*.cu files
// stay in sync. (Kept as a macro; `//` comments inside would swallow the
// continuation backslashes.)
#define MARLIN_KERNEL_PARAMS                                                  \
  const int4 *__restrict__ A, const int4 *__restrict__ B,                     \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                         \
      const int4 *__restrict__ b_bias_ptr,                                    \
      const float *__restrict__ a_scales_ptr,                                 \
      const int4 *__restrict__ scales_ptr,                                    \
      const uint16_t *__restrict__ global_scale_ptr,                          \
      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,         \
      int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
      bool has_bias, bool use_atomic_add, bool use_fp32_reduce,               \
      int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
// Forward declaration of the main GEMM kernel; one instantiation exists per
// (dtype combo, tile shape, pipeline/group config). Definitions live in the
// generated kernel sources.
template <const vllm::ScalarTypeId a_type_id,  // A ScalarType id
          const vllm::ScalarTypeId b_type_id,  // B ScalarType id
          const vllm::ScalarTypeId c_type_id,  // C ScalarType id
          const vllm::ScalarTypeId s_type_id,  // B_SCALE ScalarType id
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const bool m_block_size_8,  // whether m_block_size == 8
                                      // only works when thread_m_blocks == 1
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks,  // number of consecutive 16x16 blocks
                                   // with a separate quantization scale
          const bool is_zp_float   // is zero point of float16 type?
          >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}
|
||||
131
csrc/quantization/gptq_marlin/marlin.cuh
Normal file
131
csrc/quantization/gptq_marlin/marlin.cuh
Normal file
@@ -0,0 +1,131 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

// Minimum / maximum thread-tile sizes along n and k accepted by the GEMM
// launcher's config validation.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;

// Marlin operates on 16x16 tiles.
static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack params
static constexpr int repack_stages = 8;  // cp.async pipeline depth for repack

static constexpr int repack_threads = 256;  // block size for repack kernels

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;  // 16x64 repack tile
|
||||
|
||||
// Helpers

// Fixed-size POD vector of n elements, indexable from device code. Used to
// give register arrays a value-type interface.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

// Four packed 32-bit ints (the granularity of many Marlin loads/stores).
using I4 = Vec<int, 4>;
|
||||
|
||||
// Integer ceiling division: smallest q such that q * denom >= numer
// (for non-negative numer and positive denom).
constexpr int div_ceil(int numer, int denom) {
  return (numer + denom - 1) / denom;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
// Predicated cp.async.ca (cached at all levels) copy of BYTES bytes from
// global to shared memory. The three public wrappers below differed only in
// the transfer size, so the PTX body lives in this single template.
// Requires SM80+ (guarded by the surrounding __CUDA_ARCH__ check).
template <int BYTES>
__device__ inline void cp_async_ca_pred_impl(void* smem_ptr,
                                             const void* glob_ptr,
                                             bool pred = true) {
  static_assert(BYTES == 4 || BYTES == 8 || BYTES == 16,
                "cp.async supports only 4/8/16-byte transfers");
  // Convert the generic shared-memory pointer to the 32-bit shared-window
  // address form expected by cp.async.
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.ca.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// 4-byte predicated async copy, global -> shared.
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  cp_async_ca_pred_impl<4>(smem_ptr, glob_ptr, pred);
}

// 8-byte predicated async copy, global -> shared.
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  cp_async_ca_pred_impl<8>(smem_ptr, glob_ptr, pred);
}

// 16-byte predicated async copy, global -> shared.
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  cp_async_ca_pred_impl<16>(smem_ptr, glob_ptr, pred);
}
|
||||
|
||||
// Predicated 16-byte async copy, global -> shared, using cp.async.cg
// (cache-global: bypasses L1, caches in L2 only) — unlike the `_ca_` variants
// above, which cache at all levels. Requires SM80+.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  // Convert the generic shared-memory pointer to the 32-bit shared-window
  // address form expected by cp.async.
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}
|
||||
|
||||
// Unconditional 16-byte async copy, global -> shared (cp.async.cg, L2-only
// caching). Requires SM80+.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}
|
||||
|
||||
// Commit all previously issued cp.async operations into a new group; pair
// with cp_async_wait<n>() to wait on outstanding groups.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
|
||||
|
||||
// Block the calling thread until at most n committed cp.async groups remain
// in flight (n == 0 waits for all of them).
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
149
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
Normal file
149
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
Normal file
@@ -0,0 +1,149 @@
|
||||
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include "marlin.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Maps a vllm scalar-type id to the concrete CUDA types and register
// fragments the Marlin kernels use for that compute type. Only the
// specializations below are valid; the primary template is intentionally
// empty so an unsupported id fails to compile.
template <long scalar_type_id>
class MarlinScalarType {};
|
||||
|
||||
// fp16 compute type.
template <>
class MarlinScalarType<vllm::kFloat16.id()> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;
  using scalar_t4 = half2;
  using scalar_32bit_t = half2;  // 32-bit register unit (2 x fp16)

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;  // fp32 accumulators
  using FragS = Vec<half2, 1>;  // quantization scales
  using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;  // fp8 (e4m3) scale pair
  using FragZP = Vec<half2, 4>;  // zero points

  // Scalar conversion helpers so kernel code can be written generically over
  // the compute type.
  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }

  static __host__ __device__ float2 inline num22float2(const half2 x) {
    return __half22float2(x);
  }
};
|
||||
|
||||
// bf16 compute type.
template <>
class MarlinScalarType<vllm::kBFloat16.id()> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;
  using scalar_t4 = nv_bfloat162;
  using scalar_32bit_t = nv_bfloat162;  // 32-bit register unit (2 x bf16)

  // Fragment layouts mirror the fp16 specialization above.
  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

  // bf16 arithmetic intrinsics require SM80+, so the conversion helpers are
  // compiled out of device passes for older architectures.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }

  static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
    return __bfloat1622float2(x);
  }
#endif
};
|
||||
|
||||
// fp8 (e4m3) compute type. Only num22float2 is provided; the remaining
// conversion helpers of the 16-bit specializations are not needed on the fp8
// path.
template <>
class MarlinScalarType<vllm::kFE4M3fn.id()> {
 public:
  using scalar_t = __nv_fp8_e4m3;
  using scalar_t2 = __nv_fp8x2_e4m3;
  using scalar_t4 = __nv_fp8x4_e4m3;
  using scalar_32bit_t = __nv_fp8x4_e4m3;  // 32-bit register unit (4 x fp8)

  using FragA = Vec<__nv_fp8x4_e4m3, 4>;
  using FragB = Vec<__nv_fp8x4_e4m3, 2>;
  using FragC = Vec<float, 4>;  // fp32 accumulators
  using FragZP = Vec<__nv_fp8x2_e4m3, 4>;

  static __host__ __device__
      float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
    return (float2)x;
  }
};
|
||||
|
||||
// int8 compute type.
template <>
class MarlinScalarType<vllm::kS8.id()> {
 public:
  using scalar_t = int8_t;
  using scalar_t2 = int16_t;  // 2 x int8 packed
  using scalar_t4 = int32_t;  // 4 x int8 packed
  using scalar_32bit_t = int32_t;

  using FragA = Vec<int32_t, 4>;
  using FragB = Vec<int32_t, 2>;
  using FragC = Vec<float, 4>;  // fp32 accumulators
  using FragZP = Vec<int16_t, 4>;
};
|
||||
|
||||
// Convenience mapping from the concrete scalar C++ type back to the trait
// class above, so templated kernels can look up fragments without threading
// scalar-type ids through every signature.
template <typename scalar_t>
class MarlinScalarType2 {};

template <>
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};

template <>
class MarlinScalarType2<nv_bfloat16>
    : public MarlinScalarType<vllm::kBFloat16.id()> {};

template <>
class MarlinScalarType2<__nv_fp8_e4m3>
    : public MarlinScalarType<vllm::kFE4M3fn.id()> {};

template <>
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
106
csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu
Normal file
106
csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu
Normal file
@@ -0,0 +1,106 @@
|
||||
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>

#include "marlin.cuh"

#include "core/registration.h"
|
||||
|
||||
// for only non-zp format (like gptq)
// Re-encodes each packed int4 weight for the int4*fp8 Marlin path, assuming
// an implicit zero point of 8. Launch: one thread per int32 (8 packed
// nibbles); gridDim.x * 32 must equal qweight.numel().
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
    // qweight: (size_k * size_n // 8,)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output) {
  int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
  int32_t new_val = 0;

#pragma unroll
  for (int32_t i = 0; i < 8; i++) {
    int32_t single_val = val & 0xF;
    // Per-nibble remap around the implicit zero point of 8:
    //   w >= 8  ->  w - 8   (maps to 0..7)
    //   w <  8  ->  15 - w  (maps to 8..15)
    single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
    new_val |= single_val << (i * 4);
    val >>= 4;
  }

  output[blockIdx.x * 32 + threadIdx.x] = new_val;
}
|
||||
|
||||
// for awq format only (with zp and with awq weight layout)
// Re-encodes AWQ-packed int4 weights against their per-group zero points.
// Launch: grid (size_k / 32, size_n / 8), 32 threads per block; each thread
// handles one int32 at row (blockIdx.x * 32 + threadIdx.x), packed column
// blockIdx.y.
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
    // AWQ qweight: (size_k, size_n // 8)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output,
    // AWQ zeros: (size_k // group_size, size_n // 8)
    const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
    int32_t group_size) {
  int32_t val =
      qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
  // Zero points for this row's group; same packed-nibble layout as qweight.
  int32_t zero =
      qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
             blockIdx.y];
  int32_t new_val = 0;

#pragma unroll
  for (int32_t i = 0; i < 8; i++) {
    int32_t single_val = val & 0xF;
    int32_t single_zero = zero & 0xF;

    // NOTE(review): the w < zp branch reuses the non-zp kernel's `15 - w`
    // form (which assumes zp == 8) rather than an expression involving
    // single_zero — confirm this is the intended encoding for arbitrary
    // zero points.
    single_val =
        single_val >= single_zero ? single_val - single_zero : 15 - single_val;
    new_val |= single_val << (i * 4);
    val >>= 4;
    zero >>= 4;
  }

  output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}
|
||||
|
||||
// Re-encodes packed int4 weights (GPTQ layout without zeros, or AWQ layout
// with zeros) into the encoding expected by the Marlin int4*fp8 kernels.
//
// qweight:        int32 CUDA tensor of packed 4-bit weights (8 per int32).
// qzeros_or_none: optional AWQ zero points, (size_k / group_size, size_n / 8);
//                 when absent a fixed zero point of 8 (GPTQ) is assumed.
// inplace:        when true the transform overwrites qweight; otherwise a
//                 freshly allocated tensor is returned.
// Returns the transformed tensor (qweight itself when inplace).
torch::Tensor marlin_int4_fp8_preprocess(
    torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
    bool inplace) {
  TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
  TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
              "qweight.dtype != torch.int32");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
  // Launch on the current torch stream so the transform is ordered with
  // surrounding ops; the original launched on the default stream, which can
  // race with work queued on the active stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);

  if (!qzeros_or_none.has_value()) {
    // GPTQ path: flat 1D indexing, 256 nibbles (= 32 int32) per block.
    TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
                "qweight.numel() * 8 % 256 != 0");

    int blocks = qweight.numel() * 8 / 256;
    marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32, 0, stream>>>(
        (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
  } else {
    // AWQ path: 2D grid over (k rows / 32, packed n columns).
    int32_t size_k = qweight.size(0);
    int32_t size_n = qweight.size(1) * 8;
    torch::Tensor qzeros = qzeros_or_none.value();

    TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
    TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
    TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
                "qweight.dtype != torch.int32");
    TORCH_CHECK(device_of(qweight) == device_of(qzeros),
                "qzeros is not on the same device with qweight");

    int32_t group_size = qweight.size(0) / qzeros.size(0);
    TORCH_CHECK(qweight.size(1) == qzeros.size(1),
                "qweight.size(1) != qzeros.size(1)");
    TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
                "qweight.size(0) % qzeros.size(0) != 0");
    TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");

    dim3 blocks(size_k / 32, size_n / 8);
    marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32, 0, stream>>>(
        (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
        (const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
  }

  return output;
}
|
||||
|
||||
// Register the preprocess op under the CUDA dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}
|
||||
2177
csrc/quantization/gptq_marlin/marlin_template.h
Normal file
2177
csrc/quantization/gptq_marlin/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user