From 4d643f6c7a291c86de64a9e52eca526b2d99775d Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Thu, 22 May 2025 10:48:59 +0800 Subject: [PATCH] [1/2] Support Qserve (#6457) Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: sleepcoo --- sgl-kernel/CMakeLists.txt | 2 + .../benchmark/bench_qserve_w4a8_gemm.py | 198 +++++ sgl-kernel/csrc/common_extension.cc | 13 + .../csrc/gemm/qserve_w4a8_per_chn_gemm.cu | 710 ++++++++++++++++ .../csrc/gemm/qserve_w4a8_per_group_gemm.cu | 795 ++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 21 + sgl-kernel/python/sgl_kernel/__init__.py | 2 + sgl-kernel/python/sgl_kernel/gemm.py | 44 + .../tests/test_qserve_w4a8_per_chn_gemm.py | 118 +++ .../tests/test_qserve_w4a8_per_group_gemm.py | 183 ++++ 10 files changed, 2086 insertions(+) create mode 100644 sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py create mode 100644 sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu create mode 100644 sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu create mode 100644 sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py create mode 100644 sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index a2858d3ec..71f77d51b 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -203,6 +203,8 @@ set(SOURCES "csrc/gemm/per_tensor_quant_fp8.cu" "csrc/gemm/per_token_group_quant_8bit.cu" "csrc/gemm/per_token_quant_fp8.cu" + "csrc/gemm/qserve_w4a8_per_chn_gemm.cu" + "csrc/gemm/qserve_w4a8_per_group_gemm.cu" "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" diff --git a/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py new file mode 100644 index 000000000..18fa4869d --- /dev/null +++ b/sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py @@ -0,0 +1,198 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import ( + int8_scaled_mm, + qserve_w4a8_per_chn_gemm, + qserve_w4a8_per_group_gemm, +) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"], + line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", 
"Qserve_W4A8_Per_Group"], + styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], + ylabel="ms", + plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + # For W8A8 + a = to_int8(torch.randn((M, K), device="cuda") * 5) + b = to_int8(torch.randn((N, K), device="cuda").t() * 5) + a_fp16 = a.to(torch.float16) + b_fp16 = b.to(torch.float16) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + + # For Qserve W4A8 per channel + a_qserve_chn = a + # two int4s pack into one int8 + b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5) + # b_qserve_chn = b.t().contiguous() + scale_a_qserve_chn = scale_a.to(torch.float16) + scale_b_qserve_chn = scale_b.to(torch.float16) + szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16) + a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16) + + # For Qserve W4A8 per group + group_size = 128 + assert K % group_size == 0, "K must be divisible by group_size" + a_qserve_group = a + # two int4s pack into one int8 + b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5) + # b_qserve_group = b.t().contiguous() + scale_a_qserve_group = scale_a.to(torch.float16) + scale_b_qserve_group = scale_b.to(torch.float16) + scale_i8_b_qserve_group = to_int8( + torch.randn((K // group_size, N), device="cuda", dtype=torch.float16) + ) + zero_i8_b_qserve_group = to_int8( + torch.randn((K // group_size, N), device="cuda", dtype=torch.float16) + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "FP16": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a_fp16, b_fp16), + quantiles=quantiles, + ) + if provider == "W8A8": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16), + quantiles=quantiles, + ) + if provider == "Qserve_W4A8_Per_Channel": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: qserve_w4a8_per_chn_gemm( + a_qserve_chn, + b_qserve_chn, + scale_b_qserve_chn, + scale_a_qserve_chn, + szero_b_qserve_chn, + a_sum_qserve_chn, + ), + quantiles=quantiles, + ) + if provider == "Qserve_W4A8_Per_Group": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: qserve_w4a8_per_group_gemm( + a_qserve_group, + b_qserve_group, + zero_i8_b_qserve_group, + scale_i8_b_qserve_group, + scale_b_qserve_group, + scale_a_qserve_group, + ), + quantiles=quantiles, + ) + + return ms, max_ms, min_ms + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, + show_plots=True, + save_path="bench_qserve_w4a8_gemm_res", + N=N, + K=K, + ) + + 
print("Benchmark finished!") diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 649bf4297..d83944b56 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -265,6 +265,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()"); m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace); + + /* + * From QServe + */ + m.def( + "qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, " + "Tensor _a_ssums, Tensor! _out_feats) -> ()"); + m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm); + + m.def( + "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, " + "Tensor _ascales, Tensor! _out_feats) -> ()"); + m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu b/sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu new file mode 100644 index 000000000..79180737f --- /dev/null +++ b/sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu @@ -0,0 +1,710 @@ +// Implemented by Haotian Tang and Shang Yang. +// @article{lin2024qserve, +// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving}, +// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and +// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024} +// } +// @article{yang2025lserve, +// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention}, +// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and +// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025} +// } + +// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu + +#include +#include +#include +#include + +#include "utils.h" + +#define OP_M 16 +#define OP_N 8 +#define OP_K 32 +#define INTRIN_M 16 +#define INTRIN_N 16 +#define INTRIN_K 32 +#define WARP_SIZE 32 +#define SMEM_PAD_A 0 +#define SMEM_PAD_B 0 +#define PACK_SIZE 16 +#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4) +#define L2_CACHEHINT(size) ".L2::" #size "B" +#else +#define L2_CACHEHINT(size) +#endif + +#define KERNEL_LAUNCH_CODE \ + constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ + constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? 
(CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \ + constexpr int kSmemByteSize = \ + ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) * \ + sizeof(int8_t); \ + if (kSmemByteSize >= 99 * 1024) { \ + printf( \ + "This kernel requires %d Bytes of shared memory, which exceeds " \ + "device limit.\n", \ + kSmemByteSize); \ + return; \ + } \ + int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ + int num_blocks_n = num_out_channels / CTA_N / 1; \ + const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \ + const int tile_shift = 1 << log_tile; \ + dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \ + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ + auto kernel_func = dense_kernel0; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats, num_in_feats, num_out_channels, num_in_channels); + +template +__inline__ __host__ __device__ int get_log_tile(int n) { + if (N >= 8 && n >= 6) + return 3; + else if (N >= 4 && n >= 3) + return 2; + else if (N >= 2 && n >= 2) + return 1; + else + return 0; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) { + return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1))); +} + +__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) { + uint32_t smem_int_ptr; + + asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, " + "smem_ptr; }\n" + : "=r"(smem_int_ptr) + : "l"(ptr)); + + return smem_int_ptr; +} + +__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) { + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3]) + : "r"(addr)); +} + +__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) { + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3]) + : "r"(addr)); +} + +// function from lmdeploy +__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) { + const int cp_size = 16; + asm volatile("{" + " .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;" + "}" ::"r"((int)mask), + "r"(smem_int_ptr), + "l"(src), + "n"(cp_size)); +} + +__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=r"(((int*)C_warp)[0]), "=r"(((int*)C_warp)[1]), "=r"(((int*)C_warp)[2]), "=r"(((int*)C_warp)[3]) + : "r"(((unsigned*)A_shared_warp)[0]), + "r"(((unsigned*)A_shared_warp)[1]), + "r"(((unsigned*)A_shared_warp)[2]), + 
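+      /* the m16n8k32 s8 MMA consumes four 32-bit A fragments ({%4..%7}) and two B fragments ({%8,%9}), accumulating into four int32 C registers */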
"r"(((unsigned*)A_shared_warp)[3]), + "r"(((unsigned*)B_shared_warp)[0]), + "r"(((unsigned*)B_shared_warp)[1]), + "r"(((int*)C_warp)[0]), + "r"(((int*)C_warp)[1]), + "r"(((int*)C_warp)[2]), + "r"(((int*)C_warp)[3])); +} + +template +__device__ __inline__ void global_to_share_one_stage_A( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask, + bool* preds) { + constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + int8_t* dst_hoisted = dst; + int8_t* src_hoisted = src + global_iter_k * CTA_K; + + if (mask) { +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + + void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol); + uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols); + // *dst_ptr = *src_ptr; + if constexpr (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, preds[global_iter]); + } else { + if (preds[global_iter]) *(uint4*)dst_ptr = *src_ptr; + } + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_B( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask) { + constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE; + constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE; + constexpr int warps_per_row = CTA_K / 32; + constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row; + constexpr int kSmemCol = CTA_K; + int8_t* dst_hoisted = dst; + int8_t* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; + +#pragma unroll + for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) { + void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE); + uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE); + if constexpr (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, mask); + } else { + if (mask) *(uint4*)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_zeros( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask) { + constexpr int threads_needed = CTA_N / PACK_SIZE / 1; + constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; + constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used; + constexpr int threads_per_row = CTA_N / PACK_SIZE; + constexpr int kSmemCol = CTA_N; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int g_idx = global_iter_k * CTA_K / G; + + void* dst_ptr = (void*)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE); + if (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + } else { + if (local_mask) { + *(uint4*)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void +share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) { + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE; + + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { + int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16); + int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3; + void* addr_ptr = (void*)(src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE); + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_B( + int8_t* src, + int8_t* dst, + int8_t* zeros, + int8_t* scales_i8, + int warp_offset_m, + int warp_offset_n, + int k_0_0, + int k_0_1, + int shared_iters) { + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { + uint4 loaded = + *((uint4*)(src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol + k_0_1 * INTRIN_K + + threadIdx.x); + + auto ptr = (uint32_t*)dst + shared_iter * 8; + ptr[0] = loaded.x & 0x0F0F0F0F; + ptr[4] = (loaded.x & 0xF0F0F0F0) >> 4; + ptr[2] = loaded.y & 0x0F0F0F0F; + ptr[6] = (loaded.y & 0xF0F0F0F0) >> 4; + ptr[1] = loaded.z & 0x0F0F0F0F; + ptr[5] = (loaded.z & 0xF0F0F0F0) >> 4; + ptr[3] = loaded.w & 0x0F0F0F0F; + ptr[7] = (loaded.w & 0xF0F0F0F0) >> 4; + } +} + +template +__global__ void dense_kernel0( + int8_t* __restrict__ A, + int8_t* __restrict__ B, + half2* __restrict__ wscales, + half* __restrict__ ascales, + half2* __restrict__ w_szs, + half* __restrict__ a_ssums, + half* __restrict__ C, + int M, + int64_t N, + int64_t K) { + constexpr int SPLITK = 1; + constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; + constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K; + constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; + constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE; + constexpr int SLICES = CTA_K / WARP_K; + int num_blocks_n = (N + CTA_N - 1) / CTA_N; + int num_blocks_m = (M + CTA_M - 1) / CTA_M; + + int blockIdx_n = blockIdx.x; + int blockIdx_m = blockIdx.y; + const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M); + const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile); + blockIdx_n = block_idx_mapping.x; + blockIdx_m = block_idx_mapping.y; + + int C_warp[CTA_M * CTA_N / CTA_SIZE_MN]; + constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; + constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; + constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; + constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2; + constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES; + constexpr int kSmemSizeB = kSmemSizeBPerStage * 
STAGES; + + constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1; + constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1; + constexpr int kSmemSizeScales = CTA_N * STAGES; + + extern __shared__ int8_t mem_shared[]; + int8_t* A_shared = mem_shared; + + int8_t* B_shared = mem_shared + kSmemSizeA; + int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB; + int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + + int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE]; + int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE]; + constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K; + constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int A_threads_per_row = CTA_K / PACK_SIZE; + + constexpr int B_warps_per_row = CTA_K / 32; + constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row; + + int cta_offset_m = blockIdx_m * CTA_M; + int cta_offset_n = blockIdx_n * CTA_N; + int warp_mn = threadIdx.y % NUM_WARPS_MN; + int slice_id = threadIdx.y / NUM_WARPS_MN; + int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M; + int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N; + int warp_offset_k = slice_id * WARP_K; + + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++) + C_warp[i] = 0; + + int gemm_iters = (K + CTA_K - 1) / CTA_K; + int k_0_0_ld = 0; + int k_0_0 = 0; + constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1; + int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row); + int A_hoisted_col = (threadIdx.x % A_threads_per_row); + int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3; + + int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE; + int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + + (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE; + int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE; + int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + + (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE; + + bool A_g2s_preds[A_total_global_iters]; +#pragma unroll + for (int i = 0; i < A_total_global_iters; i++) { + A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M; + } + + int* C_shared = reinterpret_cast(mem_shared); + +#pragma unroll + for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) { + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + 0, + true, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); + + if constexpr (STAGES > 1) __pipeline_commit(); + } + if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2); + __syncthreads(); + + share_to_reg_one_stage_A( + A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M); + share_to_reg_one_stage_B( + B_shared + warp_offset_k * PACK_SIZE, + B_shared_warp_[0], + zeros_shared, + scales_i8_shared, + warp_offset_m, + warp_offset_n, + 0, + 0, + WARP_N / 32); + constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K; + + for (; k_0_0 < gemm_iters; ++k_0_0, 
++k_0_0_ld) { + int ld_stage = k_0_0_ld % STAGES; + int compute_stage = k_0_0 % STAGES; + int8_t* A_shared_this_compute_stage; + int8_t* B_shared_this_compute_stage; + int8_t* zeros_shared_this_compute_stage; + int8_t* scales_i8_shared_this_compute_stage; + + for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) { + A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k; + B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE; + zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N; + scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N; + + share_to_reg_one_stage_A( + A_shared_this_compute_stage, + A_shared_warp_[(iter_k + 1) % 2], + warp_offset_m, + warp_offset_n, + (iter_k + 1) % SHARED_K_ITERS, + WARP_M / INTRIN_M); + share_to_reg_one_stage_B( + B_shared_this_compute_stage, + B_shared_warp_[(iter_k + 1) % 2], + zeros_shared_this_compute_stage, + scales_i8_shared_this_compute_stage, + warp_offset_m, + warp_offset_n, + k_0_0 + (iter_k == SHARED_K_ITERS - 1), + (iter_k + 1) % SHARED_K_ITERS, + WARP_N / 32); + int8_t* A_shared_warp = A_shared_warp_[iter_k % 2]; + int8_t* B_shared_warp = B_shared_warp_[iter_k % 2]; + + for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) { + for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) { + mma_m16n8k32( + (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8), + (void*)(A_shared_warp + i_0_3 * 16), + (void*)(B_shared_warp + j_0_4 * 16)); + mma_m16n8k32( + (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4), + (void*)(A_shared_warp + i_0_3 * 16), + (void*)(B_shared_warp + j_0_4 * 16 + 8)); + } + } + + if (iter_k < SHARED_K_ITERS - 1) { + if constexpr (STAGES == 1) __syncthreads(); + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + ld_stage * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, + B_shared_hoisted + ld_stage * kSmemSizeBPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters); + } + + if (iter_k == SHARED_K_ITERS - 2) { + if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) { + __syncthreads(); + } + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + ld_stage * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k + 1, + k_0_0_ld < gemm_iters, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, + B_shared_hoisted + ld_stage * kSmemSizeBPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k + 1, + k_0_0_ld < gemm_iters); + if constexpr (STAGES > 1) { + __pipeline_commit(); + __pipeline_wait_prior(STAGES - 2); + } + compute_stage = (k_0_0 + 1) % STAGES; + __syncthreads(); + } + } + } + + __pipeline_commit(); + __pipeline_wait_prior(0); + __syncthreads(); + + if constexpr (SLICES > 1) { +#pragma unroll + for (int z = 0; z < SLICES; ++z) { + if (slice_id == z) { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) { + if (z > 0) { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + 
+ (threadIdx.x % 4) * 2]; + } + C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + }; + } + } + } + __syncthreads(); + } + if (slice_id == 0) { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + + (threadIdx.x % 4) * 2]; + }; + } + } + } + } + + int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4); + int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2; + if (slice_id == 0) { + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { + int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M; + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { + int col_wb_1 = col_wb_thd + ax1_0_1 * 16; + int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8; + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) { + int row_wb = row_wb_1 + (local_id % 4) / 2 * 8; + if (row_wb < M) { + int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); + float2 wscale = __half22float2(*(wscales + col_wb / 2)); + float2 w_sz = __half22float2(*(w_szs + col_wb / 2)); + float ascale = __half2float(ascales[row_wb]); + float a_ssum = __half2float(a_ssums[row_wb]); + float2 psums = + make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); + psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum; + psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum; + *reinterpret_cast(C + row_wb * N + col_wb) = __float22half2_rn(psums); + } + }; + } + } + } +} +#else +template +__global__ void dense_kernel0( + int8_t* __restrict__ A, + int8_t* __restrict__ B, + half2* __restrict__ wscales, + half* __restrict__ ascales, + half2* __restrict__ w_szs, + half* __restrict__ a_ssums, + half* __restrict__ C, + int M, + int64_t N, + int64_t K) { + // Not implemented for SM < 800 + assert(false); + return; +} +#endif + +void qserve_w4a8_per_chn_gemm( + const torch::Tensor& _in_feats, + const torch::Tensor& _kernel, + const torch::Tensor& _wscales, + const torch::Tensor& _ascales, + const torch::Tensor& _w_szs, + const torch::Tensor& _a_ssums, + torch::Tensor& _out_feats) { + // Check input tensor + TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor"); + TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor"); + TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous"); + TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8"); + // Check kernel tensor + TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor"); + TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor"); + TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous"); + TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8"); + // Check output tensor + TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor"); + TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be 
contiguous"); + TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half"); + + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + // Check matmul shape + TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)"); + TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats"); + + // Check _ascales + TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor"); + TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous"); + TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half"); + TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements"); + + // Check _wscales + TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor"); + TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous"); + TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half"); + TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements"); + + // Check _w_szs + TORCH_CHECK(_w_szs.is_cuda(), "_w_szs must be a CUDA tensor"); + TORCH_CHECK(_w_szs.is_contiguous(), "_w_szs must be contiguous"); + TORCH_CHECK(_w_szs.scalar_type() == torch::kHalf, "_w_szs must be half"); + TORCH_CHECK(_w_szs.numel() == num_out_channels, "_w_szs must have num_out_channels elements"); + + // Check _a_ssums + TORCH_CHECK(_a_ssums.is_cuda(), "_a_ssums must be a CUDA tensor"); + TORCH_CHECK(_a_ssums.is_contiguous(), "_a_ssums must be contiguous"); + TORCH_CHECK(_a_ssums.scalar_type() == torch::kHalf, "_a_ssums must be half"); + TORCH_CHECK(_a_ssums.numel() == num_in_feats, "_a_ssums must have num_in_feats elements"); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto w_szs = reinterpret_cast(_w_szs.data_ptr()); + auto a_ssums = reinterpret_cast(_a_ssums.data_ptr()); + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device()); + + auto sm_version = getSMVersion(); + if (sm_version >= 80) { + constexpr int G = 128; + + if (num_out_feats > 256) { + constexpr int CTA_M = 128; + constexpr int CTA_N = 128; + constexpr int CTA_K = 64; + constexpr int WARP_M = 64; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int STAGES = 3; + KERNEL_LAUNCH_CODE + } else if (num_out_feats >= 128) { + constexpr int CTA_M = 64; + constexpr int CTA_N = 64; + constexpr int CTA_K = 64; + constexpr int WARP_M = 32; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } else { + constexpr int CTA_M = 32; + constexpr int CTA_N = 64; + constexpr int CTA_K = 128; + constexpr int WARP_M = 32; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int STAGES = 3; + KERNEL_LAUNCH_CODE + } + } else { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No implemented qserve_w4a8_per_chn_gemm for current compute capability: ", sm_version); + } + return; +} diff --git a/sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu b/sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu new file mode 100644 index 000000000..a99a203e8 --- /dev/null +++ b/sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu @@ -0,0 +1,795 @@ +// 
Implemented by Haotian Tang and Shang Yang. +// @article{lin2024qserve, +// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving}, +// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and +// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024} +// } +// @article{yang2025lserve, +// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention}, +// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and +// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025} +// } + +// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu + +#include +#include +#include +#include + +#include "utils.h" + +#define OP_M 16 +#define OP_N 8 +#define OP_K 32 +#define INTRIN_M 16 +#define INTRIN_N 16 +#define INTRIN_K 32 +#define WARP_SIZE 32 +#define SMEM_PAD_A 0 +#define SMEM_PAD_B 0 +#define PACK_SIZE 16 +#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4) +#define L2_CACHEHINT(size) ".L2::" #size "B" +#else +#define L2_CACHEHINT(size) +#endif + +#define KERNEL_LAUNCH_CODE \ + constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ + constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \ + constexpr int kSmemByteSize = \ + ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) * \ + sizeof(int8_t); \ + if (kSmemByteSize >= 99 * 1024) { \ + printf( \ + "This kernel requires %d Bytes of shared memory, which exceeds " \ + "device limit.\n", \ + kSmemByteSize); \ + return; \ + } \ + int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ + int num_blocks_n = num_out_channels / CTA_N / 1; \ + const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \ + const int tile_shift = 1 << log_tile; \ + dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \ + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ + auto kernel_func = dense_kernel0; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, \ + kernel, \ + zeros, \ + scales_i8, \ + wscales, \ + ascales, \ + out_feats, \ + num_in_feats, \ + num_out_channels, \ + num_in_channels); + +template +__inline__ __host__ __device__ int get_log_tile(int n) { + if (N >= 8 && n >= 6) + return 3; + else if (N >= 4 && n >= 3) + return 2; + else if (N >= 2 && n >= 2) + return 1; + else + return 0; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) { + return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1))); +} + +__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) { + uint32_t smem_int_ptr; + + asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, " + "smem_ptr; }\n" + : "=r"(smem_int_ptr) + : "l"(ptr)); + + return smem_int_ptr; +} + +__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) { + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]), + "=r"(((unsigned*)(shared_warp + 
(ax0_0 * 16)))[2]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3]) + : "r"(addr)); +} + +__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) { + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]), + "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3]) + : "r"(addr)); +} + +// function from lmdeploy +__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) { + const int cp_size = 16; + asm volatile("{" + " .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;" + "}" ::"r"((int)mask), + "r"(smem_int_ptr), + "l"(src), + "n"(cp_size)); +} + +__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=r"(((int*)C_warp)[0]), "=r"(((int*)C_warp)[1]), "=r"(((int*)C_warp)[2]), "=r"(((int*)C_warp)[3]) + : "r"(((unsigned*)A_shared_warp)[0]), + "r"(((unsigned*)A_shared_warp)[1]), + "r"(((unsigned*)A_shared_warp)[2]), + "r"(((unsigned*)A_shared_warp)[3]), + "r"(((unsigned*)B_shared_warp)[0]), + "r"(((unsigned*)B_shared_warp)[1]), + "r"(((int*)C_warp)[0]), + "r"(((int*)C_warp)[1]), + "r"(((int*)C_warp)[2]), + "r"(((int*)C_warp)[3])); +} + +template +__device__ __inline__ void global_to_share_one_stage_A( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask, + bool* preds) { + constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + int8_t* dst_hoisted = dst; + int8_t* src_hoisted = src + global_iter_k * CTA_K; + + if (mask) { +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol); + uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols); + if constexpr (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, preds[global_iter]); + } else { + if (preds[global_iter]) *(uint4*)dst_ptr = *src_ptr; + } + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_B( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask) { + constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE; + constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE; + constexpr int warps_per_row = CTA_K / 32; + constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row; + constexpr int kSmemCol = CTA_K; + int8_t* dst_hoisted = dst; + int8_t* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; + +#pragma unroll + for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) { + void* 
dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE); + uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE); + if constexpr (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, mask); + } else { + if (mask) *(uint4*)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_zeros( + int8_t* src, + int8_t* dst, + int global_ncols, + int cta_offset_m, + int cta_offset_n, + int global_iter_k, + int shared_iter_k, + bool mask) { + constexpr int threads_needed = CTA_N / PACK_SIZE / 1; + constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; + constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used; + constexpr int threads_per_row = CTA_N / PACK_SIZE; + constexpr int kSmemCol = CTA_N; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int g_idx = global_iter_k * CTA_K / G; + + void* dst_ptr = (void*)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE); + if (STAGES > 1) { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + } else { + if (local_mask) { + *(uint4*)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void +share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) { + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE; + + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { + int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16); + int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3; + void* addr_ptr = (void*)(src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE); + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_B( + int8_t* src, + int8_t* dst, + int8_t* zeros, + int8_t* scales_i8, + int warp_offset_m, + int warp_offset_n, + int k_0_0, + int k_0_1, + int shared_iters) { + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { + uint4 loaded = + *((uint4*)(src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol + k_0_1 * INTRIN_K + + threadIdx.x); + uint32_t loaded_0 = loaded.x & 0x0F0F0F0F; + uint32_t loaded_4 = (loaded.x & 0xF0F0F0F0) >> 4; + uint32_t loaded_2 = loaded.y & 0x0F0F0F0F; + uint32_t loaded_6 = (loaded.y & 0xF0F0F0F0) >> 4; + uint32_t loaded_1 = loaded.z & 0x0F0F0F0F; + uint32_t loaded_5 = (loaded.z & 0xF0F0F0F0) >> 4; + uint32_t loaded_3 = loaded.w & 0x0F0F0F0F; + uint32_t loaded_7 = (loaded.w & 0xF0F0F0F0) >> 4; + + auto ptr = (uint32_t*)dst + shared_iter * 8; + int scales_zeros_offset = warp_offset_n + (threadIdx.x / 4) * 4 + shared_iter * 32; + uint32_t packed_scales = *reinterpret_cast(scales_i8 + scales_zeros_offset); + uint32_t packed_zeros = *reinterpret_cast(zeros + scales_zeros_offset); + + uint32_t scale_0 = packed_scales & 0xFF; + uint32_t zero_point_0 = __byte_perm(packed_zeros, 0, 0x00000000); + uint32_t ptr_0 = loaded_0 * scale_0; + uint32_t ptr_1 = loaded_1 * scale_0; + ptr[0] = __vadd4(ptr_0, zero_point_0); + ptr[1] = __vadd4(ptr_1, zero_point_0); + + uint32_t scale_1 = 
(packed_scales & 0xFF00) >> 8; + uint32_t zero_point_1 = __byte_perm(packed_zeros, 0, 0x00001111); + uint32_t ptr_2 = loaded_2 * scale_1; + uint32_t ptr_3 = loaded_3 * scale_1; + ptr[2] = __vadd4(ptr_2, zero_point_1); + ptr[3] = __vadd4(ptr_3, zero_point_1); + + uint32_t scale_2 = (packed_scales & 0xFF0000) >> 16; + uint32_t zero_point_2 = __byte_perm(packed_zeros, 0, 0x00002222); + uint32_t ptr_4 = loaded_4 * scale_2; + uint32_t ptr_5 = loaded_5 * scale_2; + ptr[4] = __vadd4(ptr_4, zero_point_2); + ptr[5] = __vadd4(ptr_5, zero_point_2); + + uint32_t scale_3 = (packed_scales & 0xFF000000) >> 24; + uint32_t zero_point_3 = __byte_perm(packed_zeros, 0, 0x00003333); + uint32_t ptr_6 = loaded_6 * scale_3; + uint32_t ptr_7 = loaded_7 * scale_3; + ptr[6] = __vadd4(ptr_6, zero_point_3); + ptr[7] = __vadd4(ptr_7, zero_point_3); + } +} + +template +__global__ void dense_kernel0( + int8_t* __restrict__ A, + int8_t* __restrict__ B, + int8_t* __restrict__ zeros, + int8_t* __restrict__ scales_i8, + half2* __restrict__ wscales, + half* __restrict__ ascales, + half* __restrict__ C, + int M, + int64_t N, + int64_t K) { + constexpr int SPLITK = 1; + constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; + constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K; + constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; + constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE; + constexpr int SLICES = CTA_K / WARP_K; + int num_blocks_n = (N + CTA_N - 1) / CTA_N; + int num_blocks_m = (M + CTA_M - 1) / CTA_M; + + int blockIdx_n = blockIdx.x; + int blockIdx_m = blockIdx.y; + const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M); + const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile); + blockIdx_n = block_idx_mapping.x; + blockIdx_m = block_idx_mapping.y; + + int C_warp[CTA_M * CTA_N / CTA_SIZE_MN]; + constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; + constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; + constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; + constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2; + constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES; + constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES; + + constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1; + constexpr int scales_per_load = G < CTA_K ? 
CTA_K / G : 1; + constexpr int kSmemSizeScales = CTA_N * STAGES; + + extern __shared__ int8_t mem_shared[]; + int8_t* A_shared = mem_shared; + + int8_t* B_shared = mem_shared + kSmemSizeA; + int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB; + int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + + int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE]; + int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE]; + constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE; + constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K; + constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int A_threads_per_row = CTA_K / PACK_SIZE; + + constexpr int B_warps_per_row = CTA_K / 32; + constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row; + + int cta_offset_m = blockIdx_m * CTA_M; + int cta_offset_n = blockIdx_n * CTA_N; + int warp_mn = threadIdx.y % NUM_WARPS_MN; + int slice_id = threadIdx.y / NUM_WARPS_MN; + int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M; + int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N; + int warp_offset_k = slice_id * WARP_K; + + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++) + C_warp[i] = 0; + + int gemm_iters = (K + CTA_K - 1) / CTA_K; + + int k_0_0_ld = 0; + int k_0_0 = 0; + constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1; + int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row); + int A_hoisted_col = (threadIdx.x % A_threads_per_row); + int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3; + + int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE; + int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + + (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE; + int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE; + int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + + (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE; + + bool A_g2s_preds[A_total_global_iters]; +#pragma unroll + for (int i = 0; i < A_total_global_iters; i++) { + A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M; + } + + int* C_shared = reinterpret_cast(mem_shared); + +#pragma unroll + for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) { + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + 0, + true, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); + global_to_share_one_stage_zeros( + zeros, zeros_shared + (k_0_0_ld)*CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters); + global_to_share_one_stage_zeros( + scales_i8, + scales_i8_shared + (k_0_0_ld)*CTA_N, + N, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + 0, + k_0_0_ld < gemm_iters); + + if constexpr (STAGES > 1) __pipeline_commit(); + } + if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2); + __syncthreads(); + + share_to_reg_one_stage_A( + A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M); + share_to_reg_one_stage_B( + B_shared + warp_offset_k * PACK_SIZE, + B_shared_warp_[0], + 
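+      /* zeros_shared / scales_i8_shared (next two args) drive the on-the-fly per-group dequant: share_to_reg_one_stage_B scales each unpacked int4 nibble by its packed int8 group scale and adds the packed zero point via __vadd4 */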
zeros_shared, + scales_i8_shared, + warp_offset_m, + warp_offset_n, + 0, + 0, + WARP_N / 32); + constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K; + + for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) { + int ld_stage = k_0_0_ld % STAGES; + int compute_stage = k_0_0 % STAGES; + int8_t* A_shared_this_compute_stage; + int8_t* B_shared_this_compute_stage; + int8_t* zeros_shared_this_compute_stage; + int8_t* scales_i8_shared_this_compute_stage; + + for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) { + A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k; + B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE; + zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N; + scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N; + + share_to_reg_one_stage_A( + A_shared_this_compute_stage, + A_shared_warp_[(iter_k + 1) % 2], + warp_offset_m, + warp_offset_n, + (iter_k + 1) % SHARED_K_ITERS, + WARP_M / INTRIN_M); + share_to_reg_one_stage_B( + B_shared_this_compute_stage, + B_shared_warp_[(iter_k + 1) % 2], + zeros_shared_this_compute_stage, + scales_i8_shared_this_compute_stage, + warp_offset_m, + warp_offset_n, + k_0_0 + (iter_k == SHARED_K_ITERS - 1), + (iter_k + 1) % SHARED_K_ITERS, + WARP_N / 32); + int8_t* A_shared_warp = A_shared_warp_[iter_k % 2]; + int8_t* B_shared_warp = B_shared_warp_[iter_k % 2]; + + for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) { + for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) { + mma_m16n8k32( + (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8), + (void*)(A_shared_warp + i_0_3 * 16), + (void*)(B_shared_warp + j_0_4 * 16)); + mma_m16n8k32( + (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4), + (void*)(A_shared_warp + i_0_3 * 16), + (void*)(B_shared_warp + j_0_4 * 16 + 8)); + } + } + + if (iter_k < SHARED_K_ITERS - 1) { + if constexpr (STAGES == 1) __syncthreads(); + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + ld_stage * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, + B_shared_hoisted + ld_stage * kSmemSizeBPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters); + } + + if (iter_k == SHARED_K_ITERS - 2) { + if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) { + __syncthreads(); + } + global_to_share_one_stage_A( + A_hoisted, + A_shared_hoisted + ld_stage * kSmemSizeAPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k + 1, + k_0_0_ld < gemm_iters, + A_g2s_preds); + global_to_share_one_stage_B( + B_hoisted, + B_shared_hoisted + ld_stage * kSmemSizeBPerStage, + K, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k + 1, + k_0_0_ld < gemm_iters); + global_to_share_one_stage_zeros( + zeros, + zeros_shared + (ld_stage)*CTA_N, + N, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters); + global_to_share_one_stage_zeros( + scales_i8, + scales_i8_shared + (ld_stage)*CTA_N, + N, + cta_offset_m, + cta_offset_n, + k_0_0_ld, + iter_k, + k_0_0_ld < gemm_iters); + if constexpr (STAGES > 1) { + __pipeline_commit(); + __pipeline_wait_prior(STAGES - 2); + } + compute_stage = (k_0_0 + 1) % STAGES; + __syncthreads(); + } + } + } + __pipeline_commit(); + __pipeline_wait_prior(0); + __syncthreads(); + + if constexpr (SLICES > 1) { +#pragma unroll + for (int z = 0; z < SLICES; ++z) { + if (slice_id == z) { 
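+        // K-slice reduction: slice z adds its partial C tile into C_shared (accumulating over earlier slices when z > 0); slice 0 then reloads the summed tile for the epilogue below.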
+#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) { + if (z > 0) { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + + (threadIdx.x % 4) * 2]; + } + C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + }; + } + } + } + __syncthreads(); + } + if (slice_id == 0) { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared + [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + + (threadIdx.x % 4) * 2]; + }; + } + } + } + } + + int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4); + int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2; + if (slice_id == 0) { + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { + int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M; + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { + int col_wb_1 = col_wb_thd + ax1_0_1 * 16; + int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8; + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) { + int row_wb = row_wb_1 + (local_id % 4) / 2 * 8; + if (row_wb < M) { + int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); + float2 wscale = __half22float2(*(wscales + col_wb / 2)); + float ascale = __half2float(ascales[row_wb]); + float2 psums = + make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); + psums.x *= wscale.x * ascale; + psums.y *= wscale.y * ascale; + *reinterpret_cast(C + row_wb * N + col_wb) = __float22half2_rn(psums); + } + }; + } + } + } +} +#else +template +__global__ void dense_kernel0( + int8_t* __restrict__ A, + int8_t* __restrict__ B, + int8_t* __restrict__ zeros, + int8_t* __restrict__ scales_i8, + half2* __restrict__ wscales, + half* __restrict__ ascales, + half* __restrict__ C, + int M, + int64_t N, + int64_t K) { + // Not implemented for SM < 800 + assert(false); + return; +} +#endif + +void qserve_w4a8_per_group_gemm( + const torch::Tensor& _in_feats, + const torch::Tensor& _kernel, + const torch::Tensor& _zeros, + const torch::Tensor& _scales_i8, + const torch::Tensor& _wscales, + const torch::Tensor& _ascales, + torch::Tensor& _out_feats) { + // Check input tensor + TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor"); + TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor"); + TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous"); + TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8"); + // Check kernel tensor + TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor"); + 
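// _kernel holds the int4 weights packed two per int8 byte, so it is expected to have shape (N, K / 2).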
+
+void qserve_w4a8_per_group_gemm(
+    const torch::Tensor& _in_feats,
+    const torch::Tensor& _kernel,
+    const torch::Tensor& _zeros,
+    const torch::Tensor& _scales_i8,
+    const torch::Tensor& _wscales,
+    const torch::Tensor& _ascales,
+    torch::Tensor& _out_feats) {
+  // Check input tensor
+  TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor");
+  TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor");
+  TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous");
+  TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8");
+  // Check kernel tensor
+  TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor");
+  TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor");
+  TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous");
+  TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8");
+  // Check output tensor
+  TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor");
+  TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be contiguous");
+  TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half");
+
+  int num_in_feats = _in_feats.size(0);
+  int num_in_channels = _in_feats.size(1);
+  int num_out_feats = _out_feats.size(-2);
+  int num_out_channels = _out_feats.size(-1);
+
+  // Check matmul shape
+  TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)");
+  TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats");
+
+  // Check _ascales
+  TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor");
+  TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous");
+  TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half");
+  TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements");
+
+  // Check _wscales
+  TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor");
+  TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous");
+  TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half");
+  TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements");
+
+  // Check _scales_i8
+  TORCH_CHECK(_scales_i8.is_cuda(), "_scales_i8 must be a CUDA tensor");
+  TORCH_CHECK(_scales_i8.dim() == 2, "_scales_i8 must be a 2D tensor");
+  TORCH_CHECK(_scales_i8.is_contiguous(), "_scales_i8 must be contiguous");
+  TORCH_CHECK(_scales_i8.scalar_type() == torch::kInt8, "_scales_i8 must be int8");
+  TORCH_CHECK(num_in_channels % _scales_i8.size(0) == 0, "num_in_channels must be divisible by _scales_i8.size(0)");
+  TORCH_CHECK(num_out_channels == _scales_i8.size(1), "num_out_channels must be equal to _scales_i8.size(1)");
+
+  // Check _zeros
+  TORCH_CHECK(_zeros.is_cuda(), "_zeros must be a CUDA tensor");
+  TORCH_CHECK(_zeros.dim() == 2, "_zeros must be a 2D tensor");
+  TORCH_CHECK(_zeros.is_contiguous(), "_zeros must be contiguous");
+  TORCH_CHECK(_zeros.scalar_type() == torch::kInt8, "_zeros must be int8");
+  TORCH_CHECK(num_in_channels % _zeros.size(0) == 0, "num_in_channels must be divisible by _zeros.size(0)");
+  TORCH_CHECK(num_out_channels == _zeros.size(1), "num_out_channels must be equal to _zeros.size(1)");
+
+  // Check group size
+  auto group_size = num_in_channels / _scales_i8.size(0);
+  TORCH_CHECK(group_size == 128, "group_size must be 128");
+
+  auto in_feats = reinterpret_cast<int8_t*>(_in_feats.data_ptr());
+  auto kernel = reinterpret_cast<int8_t*>(_kernel.data_ptr());
+  auto zeros = reinterpret_cast<int8_t*>(_zeros.data_ptr());
+  auto scales_i8 = reinterpret_cast<int8_t*>(_scales_i8.data_ptr());
+  auto wscales = reinterpret_cast<half2*>(_wscales.data_ptr());
+  auto ascales = reinterpret_cast<half*>(_ascales.data_ptr());
+  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
+  auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device());
+  auto sm_version = getSMVersion();
+  if (sm_version >= 80) {
+    constexpr int G = 128;
+
+    // Pick a tile configuration based on the problem shape.
+    if (num_out_feats > 128) {
+      constexpr int CTA_M = 128;
+      constexpr int CTA_N = 64;
+      constexpr int CTA_K = 64;
+      constexpr int WARP_M = 64;
+      constexpr int WARP_N = 32;
+      constexpr int WARP_K = 64;
+      constexpr int STAGES = 4;
+      KERNEL_LAUNCH_CODE
+    } else if (num_out_feats >= 128) {  // only num_out_feats == 128 reaches this branch
+      if (num_in_channels <= 4096) {
+        constexpr int CTA_M = 64;
+        constexpr int CTA_N = 64;
+        constexpr int CTA_K = 64;
+        constexpr int WARP_M = 32;
+        constexpr int WARP_N = 32;
+        constexpr int WARP_K = 64;
+        constexpr int STAGES = 4;
+        KERNEL_LAUNCH_CODE
+      } else {
+        constexpr int CTA_M = 64;
+        constexpr int CTA_N = 64;
+        constexpr int CTA_K = 128;
+        constexpr int WARP_M = 32;
+        constexpr int WARP_N = 32;
+        constexpr int WARP_K = 64;
+        constexpr int STAGES = 3;
+        KERNEL_LAUNCH_CODE
+      }
+    } else {
+      constexpr int CTA_M = 32;
+      constexpr int CTA_N = 64;
+      constexpr int CTA_K = 128;
+      constexpr int WARP_M = 32;
+      constexpr int WARP_N = 32;
+      constexpr int WARP_K = 64;
+      constexpr int STAGES = 3;
+      KERNEL_LAUNCH_CODE
+    }
+  } else {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, "qserve_w4a8_per_group_gemm is not implemented for current compute capability: ", sm_version);
+  }
+  return;
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 658f6950e..b5e376dc8 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -404,3 +404,24 @@ void convert_vertical_slash_indexes_mergehead(
  * From XGrammar
  */
 void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
+
+/*
+ * From QServe
+ */
+void qserve_w4a8_per_chn_gemm(
+    const torch::Tensor& _in_feats,
+    const torch::Tensor& _kernel,
+    const torch::Tensor& _wscales,
+    const torch::Tensor& _ascales,
+    const torch::Tensor& _w_szs,
+    const torch::Tensor& _a_ssums,
+    torch::Tensor& _out_feats);
+
+void qserve_w4a8_per_group_gemm(
+    const torch::Tensor& _in_feats,
+    const torch::Tensor& _kernel,
+    const torch::Tensor& _zeros,
+    const torch::Tensor& _scales_i8,
+    const torch::Tensor& _wscales,
+    const torch::Tensor& _ascales,
+    torch::Tensor& _out_feats);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 70b5cdc77..ec97fa4b5 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -36,6 +36,8 @@ from sgl_kernel.gemm import (
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
+    qserve_w4a8_per_chn_gemm,
+    qserve_w4a8_per_group_gemm,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index 7035519c2..113542436 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -197,3 +197,47 @@ def scaled_fp4_quant(
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale
+
+
+def qserve_w4a8_per_chn_gemm(
+    in_feats: torch.Tensor,
+    kernel: torch.Tensor,
+    wscales: torch.Tensor,
+    ascales: torch.Tensor,
+    w_szs: torch.Tensor,
+    a_ssums: torch.Tensor,
+    out_feats: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # Shapes (as exercised by the tests in this PR): in_feats int8 (M, K); kernel
+    # packed int4-in-int8 (N, K // 2); wscales/w_szs fp16 per-channel scale and
+    # scale*zero (numel N); ascales/a_ssums fp16 per-token scale and row sum (numel M).
+    if out_feats is None:
+        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
+        out_feats = torch.empty(
+            (in_feats.shape[0], kernel.shape[0]),
+            device=in_feats.device,
+            dtype=torch.float16,
+        )
+    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
+        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
+    )
+    return out_feats
+
+
+def qserve_w4a8_per_group_gemm(
+    in_feats: torch.Tensor,
+    kernel: torch.Tensor,
+    zeros: torch.Tensor,
+    scales_i8: torch.Tensor,
+    wscales: torch.Tensor,
+    ascales: torch.Tensor,
+    out_feats: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # Shapes (enforced by the C++ checks): in_feats int8 (M, K); kernel packed
+    # int4-in-int8 (N, K // 2); zeros/scales_i8 int8 (K // 128, N); wscales fp16
+    # (numel N); ascales fp16 (numel M).
+    if out_feats is None:
+        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
+        out_feats = torch.empty(
+            (in_feats.shape[0], kernel.shape[0]),
+            device=in_feats.device,
+            dtype=torch.float16,
+        )
+    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
+        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
+    )
+    return out_feats
diff --git a/sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py b/sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
new file mode 100644
index 000000000..9410710d7
--- /dev/null
+++ b/sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
@@ -0,0 +1,118 @@
+import pytest
+import torch
+from sgl_kernel import qserve_w4a8_per_chn_gemm
+
+
+# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
+def convert_to_qserve_format(qweight, scale, zero):
+    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
+    in_features = qweight.shape[1]
+    out_features = qweight.shape[0]
+    assert in_features % 32 == 0, "Input features must be divisible by 32"
+    assert out_features % 32 == 0, "Output features must be divisible by 32"
+
+    # ---- Repack the weight ---- #
+    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
+    qweight_unpack_reorder = (
+        qweight.reshape(
+            out_features // 32,
+            2,
+            2,
+            8,
+            in_features // 32,
+            2,
+            4,
+            4,
+        )
+        .permute(0, 4, 3, 6, 1, 5, 2, 7)
+        .contiguous()
+    )
+    qweight_unpack_reorder = (
+        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
+        .contiguous()
+        .to(torch.int8)
+    )
+    # The permutes above interleave adjacent int4 pairs as [16, 0, 17, 1, ...]
+    qweight_unpack_repacked = (
+        qweight_unpack_reorder[..., 1] << 4
+    ) + qweight_unpack_reorder[..., 0]
+    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
+        out_features // 32, in_features // 32, 32, 16
+    )
+    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
+        out_features, in_features // 2
+    ).contiguous()
+
+    # ---- Pack the scales ---- #
+    scale = scale.reshape(out_features).to(torch.float16).contiguous()
+    szero = zero.reshape(out_features).to(torch.float16).contiguous() * scale
+
+    return qweight_unpack_repacked, scale, szero
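+
+
+# A sanity helper for the nibble packing used above (illustrative only;
+# `_unpack_int4_pairs` is not part of the kernel API). Two uint4 values are packed
+# per int8 as (hi << 4) + lo, so they can be recovered with shifts and masks:
+def _unpack_int4_pairs(packed):
+    # Widen first so the wrapped int8 values map back to 0..255.
+    p = packed.to(torch.int16) & 0xFF
+    lo = p & 0xF  # qweight_unpack_reorder[..., 0]
+    hi = p >> 4  # qweight_unpack_reorder[..., 1]
+    return lo, hi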
+
+
+# INT4 asymmetric quantization
+def asym_quantize_tensor(tensor):
+    tensor_min = tensor.min(dim=-1, keepdim=True)[0]
+    tensor_max = tensor.max(dim=-1, keepdim=True)[0]
+    q_min = 0
+    q_max = 15
+    tensor_scale = (tensor_max - tensor_min) / (q_max - q_min)
+    tensor_zero = q_min - torch.round(tensor_min / tensor_scale)
+    tensor_q = torch.clamp(
+        torch.round(tensor / tensor_scale) + tensor_zero, q_min, q_max
+    ).to(torch.int8)
+    return tensor_q, tensor_scale.to(torch.float16), tensor_zero.to(torch.int8)
+
+
+# INT8 symmetric quantization
+def sym_quantize_tensor(tensor):
+    tensor_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
+    tensor_q = torch.clamp(torch.round(tensor / tensor_scale), -128, 127).to(torch.int8)
+    return tensor_q, tensor_scale.to(torch.float16)
+
+
+def torch_w4a8_per_chn_gemm(a, b, a_scale, b_scale, b_zero, out_dtype):
+    # Reference: dequantize b with its zero point, then apply both scales.
+    o = torch.matmul(
+        a.to(torch.float16), (b.to(torch.float16) - b_zero.to(torch.float16)).t()
+    )
+    o = o * a_scale.view(-1, 1) * b_scale.view(1, -1)
+    return o.to(out_dtype)
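+
+
+# Why the kernel interface takes `a_sum` and a fused `szero` (= zero * scale):
+# expanding the reference above,
+#   out[m, n] = a_scale[m] * b_scale[n] * sum_k a_q[m, k] * (b_q[n, k] - zero[n])
+#             = a_scale[m] * b_scale[n] * (a_q @ b_q.T)[m, n]
+#               - (a_scale[m] * sum_k a_q[m, k]) * (b_scale[n] * zero[n]),
+# and a_scale[m] * sum_k a_q[m, k] ~= sum_k a[m, k], so passing the per-token
+# activation row sums plus the precomputed szero lets the kernel fold the zero
+# point into a single rank-1 correction after the int8 matmul.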
+
+
+def _test_accuracy_once(M, N, K, out_dtype, device):
+    # multiply by 0.01 to avoid overflow
+    a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
+    b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01
+
+    # symmetrically quantize a
+    a_q, a_scale = sym_quantize_tensor(a)
+    # asymmetrically quantize b
+    b_q, b_scale, b_zero = asym_quantize_tensor(b)
+    # convert to qserve format
+    b_q_format, b_scale_format, b_szero_format = convert_to_qserve_format(
+        b_q, b_scale, b_zero
+    )
+
+    # compute the sum of every row of a
+    a_sum = a.sum(dim=-1, keepdim=True).to(torch.float16)
+    out = qserve_w4a8_per_chn_gemm(
+        a_q, b_q_format, b_scale_format, a_scale, b_szero_format, a_sum
+    )
+    ref_out = torch_w4a8_per_chn_gemm(a_q, b_q, a_scale, b_scale, b_zero, out_dtype)
+    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-2)
+
+
+@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("out_dtype", [torch.float16])
+def test_accuracy(M, N, K, out_dtype):
+    _test_accuracy_once(M, N, K, out_dtype, "cuda")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py b/sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
new file mode 100644
index 000000000..fc26a2e60
--- /dev/null
+++ b/sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
@@ -0,0 +1,183 @@
+import pytest
+import torch
+from sgl_kernel import qserve_w4a8_per_group_gemm
+
+
+# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
+def convert_to_qserve_format(qweight, chn_scale, scale_i8, zero_i8, group_size):
+    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
+    in_features = qweight.shape[1]
+    out_features = qweight.shape[0]
+    assert in_features % 32 == 0, "Input features must be divisible by 32"
+    assert out_features % 32 == 0, "Output features must be divisible by 32"
+    assert group_size == 128, "Group size must be 128"
+    assert (
+        in_features % group_size == 0
+    ), "Input features must be divisible by group_size"
+
+    # ---- Repack the weight ---- #
+    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
+    qweight_unpack_reorder = (
+        qweight.reshape(
+            out_features // 32,
+            2,
+            2,
+            8,
+            in_features // 32,
+            2,
+            4,
+            4,
+        )
+        .permute(0, 4, 3, 6, 1, 5, 2, 7)
+        .contiguous()
+    )
+    qweight_unpack_reorder = (
+        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
+        .contiguous()
+        .to(torch.int8)
+    )
+    # The permutes above interleave adjacent int4 pairs as [16, 0, 17, 1, ...]
+    qweight_unpack_repacked = (
+        qweight_unpack_reorder[..., 1] << 4
+    ) + qweight_unpack_reorder[..., 0]
+    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
+        out_features // 32, in_features // 32, 32, 16
+    )
+    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
+        out_features, in_features // 2
+    )
+
+    # ---- Pack the scales ---- #
+    chn_scale = chn_scale.reshape(out_features)
+
+    scale_i8 = (
+        scale_i8.reshape(out_features, in_features // group_size)
+        .transpose(0, 1)
+        .contiguous()
+    )
+    scale_i8 = scale_i8.reshape(in_features // group_size, out_features // 32, 32)
+    scale_i8 = (
+        scale_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
+        .transpose(-2, -1)
+        .contiguous()
+    )
+    scale_i8 = scale_i8.reshape(in_features // group_size, out_features).contiguous()
+
+    # ---- Pack the zeros ---- #
+    # Negate so the kernel can fold the zero-point term in with a multiply-add.
+    zero_i8 = -zero_i8
+
+    zero_i8 = (
+        zero_i8.reshape(out_features, in_features // group_size)
+        .transpose(0, 1)
+        .contiguous()
+    )
+    zero_i8 = zero_i8.reshape(in_features // group_size, out_features // 32, 32)
+    # For the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ...
+    # following the layout the tensor core gemm requires.
+    zero_i8 = (
+        zero_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
+        .transpose(-2, -1)
+        .contiguous()
+    )
+    zero_i8 = (
+        zero_i8.reshape(in_features // group_size, out_features).contiguous() * scale_i8
+    )
+
+    return qweight_unpack_repacked, chn_scale, scale_i8, zero_i8
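+
+
+# The 4x8 transpose above is what produces the 0, 8, 16, 24, 1, 9, ... ordering; a
+# quick illustration of that permutation on raw indices (runs at import time,
+# negligible cost):
+_perm = torch.arange(32).reshape(4, 8).transpose(-2, -1).reshape(-1)
+assert _perm[:8].tolist() == [0, 8, 16, 24, 1, 9, 17, 25]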
+
+
+# Progressive group INT4 quantization: first symmetric per-channel INT8, then
+# asymmetric UINT4 per group of 128 on top of the INT8 values.
+def progressive_group_quantize_tensor(tensor, group_size):
+    assert group_size == 128, "Group size must be 128"
+    assert (
+        tensor.shape[-1] % group_size == 0
+    ), "Input features must be divisible by group_size"
+    # Channel scale
+    # NOTE(HandH1998): use a protective quantization range (119 instead of 127)
+    chn_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 119
+    tensor_i8 = torch.clamp(torch.round(tensor / chn_scale), -119, 119)
+
+    # Group scale
+    tensor_i8 = tensor_i8.reshape(-1, group_size)
+    tensor_i8_min = tensor_i8.min(dim=-1, keepdim=True)[0]
+    tensor_i8_max = tensor_i8.max(dim=-1, keepdim=True)[0]
+    q_min = 0
+    q_max = 15
+    scale_i8 = torch.round((tensor_i8_max - tensor_i8_min) / (q_max - q_min))
+    zero_i8 = q_min - torch.round(tensor_i8_min / scale_i8)
+    tensor_q = (
+        torch.clamp(torch.round(tensor_i8 / scale_i8) + zero_i8, q_min, q_max)
+        .reshape(tensor.shape[0], -1)
+        .to(torch.int8)
+    )
+    return (
+        tensor_q,
+        chn_scale.to(torch.float16),
+        scale_i8.reshape(tensor.shape[0], -1).to(torch.int8),
+        zero_i8.reshape(tensor.shape[0], -1).to(torch.int8),
+    )
+
+
+# INT8 symmetric quantization
+def sym_quantize_tensor(tensor):
+    tensor_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
+    tensor_q = torch.clamp(torch.round(tensor / tensor_scale), -128, 127).to(torch.int8)
+    return tensor_q, tensor_scale.to(torch.float16)
+
+
+def torch_w4a8_per_group_gemm(
+    a, b, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
+):
+    assert group_size == 128, "Group size must be 128"
+    # Reference: dequantize b group-wise back to INT8 precision, then apply the
+    # per-channel and per-token scales after the matmul.
+    b_dq = (
+        b.reshape(-1, group_size).to(torch.float32)
+        - b_zero_i8.reshape(-1, 1).to(torch.float32)
+    ) * b_scale_i8.reshape(-1, 1).to(torch.float32)
+    b_dq = b_dq.reshape(b.shape[0], b.shape[1])
+    o = torch.matmul(a.to(torch.float32), b_dq.t())
+    o = o * a_scale.view(-1, 1) * b_chn_scale.view(1, -1)
+    return o.to(out_dtype)
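+
+
+# The two quantization levels compose as w ~= chn_scale * scale_i8 * (q - zero_i8),
+# which is what the reference above evaluates: the inner (q - zero_i8) * scale_i8
+# recovers the INT8 value, and chn_scale (together with the activation scale) is
+# applied once per output element after the matmul.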
+
+
+def _test_accuracy_once(M, N, K, group_size, out_dtype, device):
+    # multiply by 0.01 to avoid overflow
+    a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
+    b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01
+
+    # symmetrically quantize a
+    a_q, a_scale = sym_quantize_tensor(a)
+    # progressively (per-channel, then per-group) quantize b
+    b_q, b_chn_scale, b_scale_i8, b_zero_i8 = progressive_group_quantize_tensor(
+        b, group_size
+    )
+    # convert to qserve format
+    b_q_format, b_chn_scale_format, b_scale_i8_format, b_zero_i8_format = (
+        convert_to_qserve_format(b_q, b_chn_scale, b_scale_i8, b_zero_i8, group_size)
+    )
+
+    out = qserve_w4a8_per_group_gemm(
+        a_q,
+        b_q_format,
+        b_zero_i8_format,
+        b_scale_i8_format,
+        b_chn_scale_format,
+        a_scale,
+    )
+    ref_out = torch_w4a8_per_group_gemm(
+        a_q, b_q, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
+    )
+    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-5)
+
+
+@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("group_size", [128])
+@pytest.mark.parametrize("out_dtype", [torch.float16])
+def test_accuracy(M, N, K, group_size, out_dtype):
+    _test_accuracy_once(M, N, K, group_size, out_dtype, "cuda")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])