[1/2] Support Qserve (#6457)
Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -203,6 +203,8 @@ set(SOURCES
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu"
|
||||
"csrc/gemm/per_token_group_quant_8bit.cu"
|
||||
"csrc/gemm/per_token_quant_fp8.cu"
|
||||
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
|
||||
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
|
||||
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import (
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate *tensor* to the signed 8-bit range [-128, 127] and round to int8."""
    clamped = tensor.clamp(min=-128, max=127)
    return clamped.round().to(dtype=torch.int8)
|
||||
|
||||
|
||||
# Per-model GEMM weight shapes to sweep.  Each entry is ([K, N], tp_split_dim),
# where tp_split_dim selects which dimension is sharded under tensor
# parallelism (0 -> K, 1 -> N); prepare_shapes() divides that dimension by the
# TP size and __main__ unpacks each row as (K, N, model_name).
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="ms",
        plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (M=batch_size, K) x (K, N) GEMM for the selected provider.

    Returns (median, max, min) latencies in milliseconds — do_bench yields
    (median, min, max) for quantiles [0.5, 0.2, 0.8] and the order is swapped
    on return.  Inputs are random; only speed is measured, not accuracy.
    """
    M = batch_size
    # For W8A8
    # Random int8 activations (M, K); b is a (K, N) transposed view of an
    # (N, K) int8 weight.
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    a_fp16 = a.to(torch.float16)
    b_fp16 = b.to(torch.float16)
    # Per-row (M,) and per-output-channel (N,) scales for int8_scaled_mm.
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)

    # For Qserve W4A8 per channel
    a_qserve_chn = a
    # two int4s pack into one int8
    b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    # b_qserve_chn = b.t().contiguous()
    scale_a_qserve_chn = scale_a.to(torch.float16)
    scale_b_qserve_chn = scale_b.to(torch.float16)
    # Per-channel weight scaled-zero term (N,) and per-row activation sums
    # (M,); the kernel epilogue computes psum * wscale * ascale - w_sz * a_ssum.
    szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
    a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)

    # For Qserve W4A8 per group
    group_size = 128
    assert K % group_size == 0, "K must be divisible by group_size"
    a_qserve_group = a
    # two int4s pack into one int8
    b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
    # b_qserve_group = b.t().contiguous()
    scale_a_qserve_group = scale_a.to(torch.float16)
    scale_b_qserve_group = scale_b.to(torch.float16)
    # int8 per-group scales and zero points, shaped (K // group_size, N).
    scale_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )
    zero_i8_b_qserve_group = to_int8(
        torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
    )

    quantiles = [0.5, 0.2, 0.8]
    if provider == "FP16":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.matmul(a_fp16, b_fp16),
            quantiles=quantiles,
        )
    if provider == "W8A8":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    if provider == "Qserve_W4A8_Per_Channel":
        # NOTE: weight scales precede activation scales in this kernel's
        # positional signature (wscales, ascales, w_szs, a_ssums).
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: qserve_w4a8_per_chn_gemm(
                a_qserve_chn,
                b_qserve_chn,
                scale_b_qserve_chn,
                scale_a_qserve_chn,
                szero_b_qserve_chn,
                a_sum_qserve_chn,
            ),
            quantiles=quantiles,
        )
    if provider == "Qserve_W4A8_Per_Group":
        # Positional order: zeros, int8 group scales, then weight/activation
        # fp16 scales (matches the registered schema _zeros, _scales_i8,
        # _wscales, _ascales).
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: qserve_w4a8_per_group_gemm(
                a_qserve_group,
                b_qserve_group,
                zero_i8_b_qserve_group,
                scale_i8_b_qserve_group,
                scale_b_qserve_group,
                scale_a_qserve_group,
            ),
            quantiles=quantiles,
        )

    return ms, max_ms, min_ms
|
||||
|
||||
|
||||
def prepare_shapes(args):
    """Expand (model, tp_size) combinations into [K, N, model_name] rows.

    For every requested model and tensor-parallel size, the dimension flagged
    by ``tp_split_dim`` in WEIGHT_SHAPES is integer-divided by the TP size,
    then the model name is appended to the (mutated deep-copied) shape list.
    """
    shapes = []
    for model_name, tp in itertools.product(args.models, args.tp_sizes):
        assert model_name in WEIGHT_SHAPES
        for entry, split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            entry[split_dim] //= tp
            entry.append(model_name)
            shapes.append(entry)
    return shapes
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: pick which model weight shapes to sweep and at which TP degrees.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    cli.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parsed = cli.parse_args()

    # Run one full batch-size sweep per derived (K, N) weight shape.
    for K, N, model_name in prepare_shapes(parsed):
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_qserve_w4a8_gemm_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
|
||||
@@ -265,6 +265,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
*/
|
||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||
|
||||
/*
|
||||
* From QServe
|
||||
*/
|
||||
m.def(
|
||||
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
|
||||
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
|
||||
|
||||
m.def(
|
||||
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
710
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu
Normal file
710
sgl-kernel/csrc/gemm/qserve_w4a8_per_chn_gemm.cu
Normal file
@@ -0,0 +1,710 @@
|
||||
// Implemented by Haotian Tang and Shang Yang.
|
||||
// @article{lin2024qserve,
|
||||
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
|
||||
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
|
||||
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
|
||||
// }
|
||||
// @article{yang2025lserve,
|
||||
// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention},
|
||||
// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and
|
||||
// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025}
|
||||
// }
|
||||
|
||||
// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Tile/fragment geometry shared by all kernel instantiations.
#define OP_M 16
#define OP_N 8
#define OP_K 32
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
// Emit the ".L2::<size>B" cache hint on cp.async only when the toolkit
// supports it (CUDA 11.4+).  The previous check,
//   (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4),
// compared the minor version regardless of the major version and therefore
// silently dropped the hint on CUDA 12.0-12.3.  Use a proper
// (major, minor) >= (11, 4) comparison instead.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
|
||||
|
||||
// Shared launch boilerplate for each (CTA_*, WARP_*, STAGES, G) template
// configuration:
//   1. computes the dynamic shared-memory budget (staged A tile + packed
//      half-size B tile, plus scale storage that grows when the quant group
//      G is smaller than CTA_K) and bails out if it exceeds ~99 KB;
//   2. rasterizes CTAs through a swizzled tile mapping (get_log_tile) so
//      neighboring blocks share L2 lines;
//   3. opts the kernel into the larger dynamic smem limit and launches
//      dense_kernel0 on `stream`.
// NOTE: a comment cannot be placed inside the continuation lines below
// without breaking the macro, so all documentation lives here.
#define KERNEL_LAUNCH_CODE \
  constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
  constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
  constexpr int kSmemByteSize = \
      ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) * \
      sizeof(int8_t); \
  if (kSmemByteSize >= 99 * 1024) { \
    printf( \
        "This kernel requires %d Bytes of shared memory, which exceeds " \
        "device limit.\n", \
        kSmemByteSize); \
    return; \
  } \
  int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
  int num_blocks_n = num_out_channels / CTA_N / 1; \
  const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
  const int tile_shift = 1 << log_tile; \
  dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
  dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
  auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
  cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
  kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>( \
      in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats, num_in_feats, num_out_channels, num_in_channels);
|
||||
|
||||
// Pick log2 of the CTA rasterization tile width (capped by template N):
// grids with more row-blocks get wider swizzle tiles for better L2 reuse.
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
  if (N >= 8 && n >= 6) return 3;
  if (N >= 4 && n >= 3) return 2;
  if (N >= 2 && n >= 2) return 1;
  return 0;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// Map the launched (blockIdx.x, blockIdx.y) onto swizzled (n, m) block
// coordinates: consecutive x indices cover a 2^log_tile-tall column of
// m-blocks before advancing in n, improving L2 locality across CTAs.
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
  return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
|
||||
|
||||
// Convert a generic-address-space pointer into the 32-bit shared-memory
// address form required by ldmatrix / cp.async PTX operands
// (cvta.to.shared followed by a truncating u64 -> u32 conversion).
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) {
  uint32_t smem_int_ptr;

  asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
      "smem_ptr; }\n"
      : "=r"(smem_int_ptr)
      : "l"(ptr));

  return smem_int_ptr;
}
|
||||
|
||||
// Warp-collective load of four 8x8 b16 matrices from shared memory (`addr`)
// into registers: fills 16 bytes (four 32-bit regs) of `shared_warp` at
// fragment index `ax0_0`.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3])
      : "r"(addr));
}
|
||||
|
||||
// Same as ldmatrix_m8n8_x4_b16 but with the .trans qualifier: each 8x8 b16
// matrix is transposed while being loaded from shared memory.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3])
      : "r"(addr));
}
|
||||
|
||||
// function from lmdeploy
// Predicated asynchronous 16-byte copy (cp.async.cg) from global `src` to the
// shared-memory address `smem_int_ptr`.  The PTX predicate `p` makes the copy
// a no-op when `mask` is false; L2_CACHEHINT optionally adds an L2 prefetch
// hint on toolkits that support it.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) {
  const int cp_size = 16;
  asm volatile("{"
               " .reg .pred p;"
               " setp.ne.b32 p, %0, 0;"
               " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
               "}" ::"r"((int)mask),
               "r"(smem_int_ptr),
               "l"(src),
               "n"(cp_size));
}
|
||||
|
||||
// One tensor-core MMA on a 16x8x32 tile: s8 x s8 -> s32, accumulating in
// place into the four 32-bit values of C_warp (C is both input and output
// operand of the PTX instruction).
__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) {
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
      : "=r"(((int*)C_warp)[0]), "=r"(((int*)C_warp)[1]), "=r"(((int*)C_warp)[2]), "=r"(((int*)C_warp)[3])
      : "r"(((unsigned*)A_shared_warp)[0]),
        "r"(((unsigned*)A_shared_warp)[1]),
        "r"(((unsigned*)A_shared_warp)[2]),
        "r"(((unsigned*)A_shared_warp)[3]),
        "r"(((unsigned*)B_shared_warp)[0]),
        "r"(((unsigned*)B_shared_warp)[1]),
        "r"(((int*)C_warp)[0]),
        "r"(((int*)C_warp)[1]),
        "r"(((int*)C_warp)[2]),
        "r"(((int*)C_warp)[3]));
}
|
||||
|
||||
// Copy one K-slice of the int8 activation tile A from global to shared
// memory, 16 bytes per thread per iteration.  With STAGES > 1 the copies are
// issued as cp.async; with a single stage they are plain predicated 16-byte
// stores.  `preds` carries per-iteration row-bounds predicates (row < M)
// precomputed by the caller; the work is split into SHARED_K_ITERS chunks so
// copies can be interleaved with compute.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(
    int8_t* src,
    int8_t* dst,
    int global_ncols,      // leading dimension (K) of the global A matrix
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,     // which CTA_K-wide K slice to copy
    int shared_iter_k,     // which chunk of this stage's copies to issue now
    bool mask,
    bool* preds) {
  constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
  constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int threads_per_row = CTA_K / PACK_SIZE;
  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
  int8_t* dst_hoisted = dst;
  int8_t* src_hoisted = src + global_iter_k * CTA_K;

  if (mask) {
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
      int global_iter = shared_iter_k * partial_global_iters + _global_iter;

      void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
      uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
      // *dst_ptr = *src_ptr;
      if constexpr (STAGES > 1) {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, preds[global_iter]);
      } else {
        if (preds[global_iter]) *(uint4*)dst_ptr = *src_ptr;
      }
    }
  }
}
|
||||
|
||||
// Copy one K-slice of the packed int4 weight tile B (two 4-bit weights per
// byte, 32 output channels per packed row of PACK_SIZE-byte vectors) from
// global to shared memory.  Unlike the A copy there is no per-row bounds
// predicate: a single `mask` gates the whole stage.  cp.async is used when
// STAGES > 1, otherwise a direct predicated 16-byte store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int global_ncols,   // leading dimension (K) of the packed global B
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
  constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
  constexpr int warps_per_row = CTA_K / 32;
  constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
  constexpr int kSmemCol = CTA_K;
  int8_t* dst_hoisted = dst;
  int8_t* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;

#pragma unroll
  for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) {
    void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
    uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
    if constexpr (STAGES > 1) {
      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
      cp_async_cg_A(addr, src_ptr, mask);
    } else {
      if (mask) *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Stage one row of per-group quantization parameters (zeros or int8 scales,
// CTA_N bytes) from global to shared memory.  Only the first `threads_used`
// threads of the CTA participate; group index g_idx advances once per CTA_K
// of K when the quant group size G divides CTA_K.
// NOTE(review): `mask & (...)` is a bitwise AND of two bools — equivalent to
// && here since both operands are 0/1, but && would be clearer.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_zeros(
    int8_t* src,
    int8_t* dst,
    int global_ncols,   // leading dimension (N) of the global zeros/scales
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
  constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
  constexpr int threads_per_row = CTA_N / PACK_SIZE;
  constexpr int kSmemCol = CTA_N;
  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
  int g_idx = global_iter_k * CTA_K / G;

  void* dst_ptr = (void*)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
  uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
  // Runtime (non-constexpr) branch, unlike the A/B copies; same effect since
  // STAGES is a template constant.
  if (STAGES > 1) {
    uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
    cp_async_cg_A(addr, src_ptr, local_mask);
  } else {
    if (local_mask) {
      *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Load `shared_iters` 16x32 int8 fragments of A from shared memory into
// registers with ldmatrix, reapplying the XOR swizzle (col ^ (row/2) & 3)
// that was used when the tile was written so each thread reads the lane it
// stored.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ __inline__ void
share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
  int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;

  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
    int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
    // `^` binds looser than `&`, so this is ld_col ^ ((ld_row / 2) & 3).
    int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
    void* addr_ptr = (void*)(src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
    uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
    ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
  }
}
|
||||
|
||||
// Load packed int4 weights from shared memory and unpack them in registers:
// each loaded uint4 carries 32 weights; the low nibble of every byte goes to
// register slots 0..3 and the high nibble (shifted down, still unsigned
// 0..15) to slots 4..7.  The `zeros` and `scales_i8` parameters are accepted
// for interface parity but are not referenced in this per-channel variant —
// dequantization happens in the kernel epilogue instead.
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void share_to_reg_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int8_t* zeros,       // unused here (per-group variant interface)
    int8_t* scales_i8,   // unused here (per-group variant interface)
    int warp_offset_m,
    int warp_offset_n,
    int k_0_0,
    int k_0_1,
    int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
#pragma unroll
  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
    uint4 loaded =
        *((uint4*)(src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol + k_0_1 * INTRIN_K +
          threadIdx.x);

    auto ptr = (uint32_t*)dst + shared_iter * 8;
    ptr[0] = loaded.x & 0x0F0F0F0F;
    ptr[4] = (loaded.x & 0xF0F0F0F0) >> 4;
    ptr[2] = loaded.y & 0x0F0F0F0F;
    ptr[6] = (loaded.y & 0xF0F0F0F0) >> 4;
    ptr[1] = loaded.z & 0x0F0F0F0F;
    ptr[5] = (loaded.z & 0xF0F0F0F0) >> 4;
    ptr[3] = loaded.w & 0x0F0F0F0F;
    ptr[7] = (loaded.w & 0xF0F0F0F0) >> 4;
  }
}
|
||||
|
||||
// QServe W4A8 per-channel GEMM kernel (SM80+): C(half, MxN) from
// A(int8, MxK) x B(packed int4, two weights per byte, grouped by 32 output
// channels).  Multi-stage cp.async software pipeline feeds s8 tensor-core
// MMAs; the epilogue dequantizes per output element as
//     out = psum * wscale * ascale - w_sz * a_ssum
// using per-channel weight scales/scaled-zeros and per-row activation
// scales/sums.  Optional split-K across warps (SLICES) is reduced through
// shared memory before write-back.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,        // (M, K) int8 activations
    int8_t* __restrict__ B,        // packed int4 weights, channel-interleaved
    half2* __restrict__ wscales,   // per-channel weight scales, read 2 at a time
    half* __restrict__ ascales,    // per-row activation scales
    half2* __restrict__ w_szs,     // per-channel weight scale*zero terms
    half* __restrict__ a_ssums,    // per-row activation sums
    half* __restrict__ C,          // (M, N) half output
    int M,
    int64_t N,
    int64_t K) {
  constexpr int SPLITK = 1;
  constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
  constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
  constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
  constexpr int SLICES = CTA_K / WARP_K;  // split-K factor within the CTA
  int num_blocks_n = (N + CTA_N - 1) / CTA_N;
  int num_blocks_m = (M + CTA_M - 1) / CTA_M;

  // Undo the swizzled grid rasterization set up by the launcher.
  int blockIdx_n = blockIdx.x;
  int blockIdx_m = blockIdx.y;
  const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
  blockIdx_n = block_idx_mapping.x;
  blockIdx_m = block_idx_mapping.y;

  // Per-thread int32 accumulators.
  int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
  constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;  // /2: packed int4
  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;

  constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
  constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
  constexpr int kSmemSizeScales = CTA_N * STAGES;

  // Dynamic shared memory layout: [A stages | B stages | zeros | i8 scales].
  extern __shared__ int8_t mem_shared[];
  int8_t* A_shared = mem_shared;

  int8_t* B_shared = mem_shared + kSmemSizeA;
  int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
  int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;

  // Double-buffered register fragments (ping-pong across k sub-iterations).
  int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
  int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
  constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_threads_per_row = CTA_K / PACK_SIZE;

  constexpr int B_warps_per_row = CTA_K / 32;
  constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;

  int cta_offset_m = blockIdx_m * CTA_M;
  int cta_offset_n = blockIdx_n * CTA_N;
  int warp_mn = threadIdx.y % NUM_WARPS_MN;
  int slice_id = threadIdx.y / NUM_WARPS_MN;
  int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
  int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
  int warp_offset_k = slice_id * WARP_K;

  for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
    C_warp[i] = 0;

  int gemm_iters = (K + CTA_K - 1) / CTA_K;
  int k_0_0_ld = 0;  // stage currently being loaded
  int k_0_0 = 0;     // stage currently being computed
  constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
  // Hoisted per-thread source/destination coordinates for the A copy,
  // including the XOR swizzle used to avoid shared-memory bank conflicts.
  int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
  int A_hoisted_col = (threadIdx.x % A_threads_per_row);
  int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;

  int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
  int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                             (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
  int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
  int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                      (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;

  // Row-bound predicates for the (possibly partial) last M tile.
  bool A_g2s_preds[A_total_global_iters];
#pragma unroll
  for (int i = 0; i < A_total_global_iters; i++) {
    A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
  }

  // Reused after the main loop for the split-K reduction.
  int* C_shared = reinterpret_cast<int*>(mem_shared);

  // Prologue: prefetch STAGES-1 stages (or 1 when unpipelined).
#pragma unroll
  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
    global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        A_hoisted,
        A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage,
        K,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        true,
        A_g2s_preds);
    global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        B_hoisted, B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);

    if constexpr (STAGES > 1) __pipeline_commit();
  }
  if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2);
  __syncthreads();

  // Preload the first register fragments for stage 0.
  share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
      A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
  share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
      B_shared + warp_offset_k * PACK_SIZE,
      B_shared_warp_[0],
      zeros_shared,
      scales_i8_shared,
      warp_offset_m,
      warp_offset_n,
      0,
      0,
      WARP_N / 32);
  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

  // Main pipelined loop over CTA_K-wide K slices.
  for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
    int ld_stage = k_0_0_ld % STAGES;
    int compute_stage = k_0_0 % STAGES;
    int8_t* A_shared_this_compute_stage;
    int8_t* B_shared_this_compute_stage;
    int8_t* zeros_shared_this_compute_stage;
    int8_t* scales_i8_shared_this_compute_stage;

    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
      A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
      B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
      zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N;
      scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N;

      // Prefetch the next k sub-iteration's fragments into the other buffer.
      share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
          A_shared_this_compute_stage,
          A_shared_warp_[(iter_k + 1) % 2],
          warp_offset_m,
          warp_offset_n,
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_M / INTRIN_M);
      share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
          B_shared_this_compute_stage,
          B_shared_warp_[(iter_k + 1) % 2],
          zeros_shared_this_compute_stage,
          scales_i8_shared_this_compute_stage,
          warp_offset_m,
          warp_offset_n,
          k_0_0 + (iter_k == SHARED_K_ITERS - 1),
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_N / 32);
      int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
      int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];

      // Tensor-core compute on the current fragments; each B fragment holds
      // two 16-column halves (offsets +0 and +8).
      for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
        for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16));
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16 + 8));
        }
      }

      // Overlap: keep filling the load stage while earlier sub-iterations
      // compute.
      if (iter_k < SHARED_K_ITERS - 1) {
        if constexpr (STAGES == 1) __syncthreads();
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
      }

      // Second-to-last sub-iteration: issue the final chunk of this load
      // stage, commit the pipeline, and rotate the compute stage.
      if (iter_k == SHARED_K_ITERS - 2) {
        if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) {
          __syncthreads();
        }
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters);
        if constexpr (STAGES > 1) {
          __pipeline_commit();
          __pipeline_wait_prior(STAGES - 2);
        }
        compute_stage = (k_0_0 + 1) % STAGES;
        __syncthreads();
      }
    }
  }

  // Drain the async-copy pipeline before reusing shared memory.
  // NOTE(review): called unconditionally, unlike the STAGES > 1 guards above
  // — presumably harmless for STAGES == 1, but confirm intent.
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  // Split-K reduction: each K-slice warp adds its partials into C_shared in
  // turn, then slice 0 reads back the combined sums.
  if constexpr (SLICES > 1) {
#pragma unroll
    for (int z = 0; z < SLICES; ++z) {
      if (slice_id == z) {
#pragma unroll
        for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
          for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
              if (z > 0) {
                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared
                    [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                     (threadIdx.x % 4) * 2];
              }
              C_shared
                  [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                   ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                   (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
            };
          }
        }
      }
      __syncthreads();
    }
    if (slice_id == 0) {
#pragma unroll
      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
            C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared
                [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                 ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                 (threadIdx.x % 4) * 2];
          };
        }
      }
    }
  }

  // Epilogue (slice 0 only): dequantize int32 partial sums and store half C.
  // Each thread writes two adjacent columns at a time, so wscales/w_szs are
  // fetched as half2.
  int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
  int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
  if (slice_id == 0) {
    for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
      int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
      for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
        int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
        int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
          int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
          if (row_wb < M) {  // guard the partial last M tile
            int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
            float2 wscale = __half22float2(*(wscales + col_wb / 2));
            float2 w_sz = __half22float2(*(w_szs + col_wb / 2));
            float ascale = __half2float(ascales[row_wb]);
            float a_ssum = __half2float(a_ssums[row_wb]);
            float2 psums =
                make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
            // Per-channel dequant: psum * wscale * ascale - w_sz * a_ssum.
            psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum;
            psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum;
            *reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
          }
        };
      }
    }
  }
}
|
||||
#else
|
||||
// Fallback instantiation compiled when __CUDA_ARCH__ < 800: the real kernel in
// the other branch of this #if relies on Ampere+ features, so this stub only
// traps at runtime if it is ever launched. Same signature as the real kernel
// (per-channel variant: wscales/ascales plus w_szs/a_ssums correction terms).
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half2* __restrict__ w_szs,
    half* __restrict__ a_ssums,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  // Not implemented for SM < 800
  assert(false);
  return;
}
|
||||
#endif
|
||||
|
||||
// Host launcher for the QServe W4A8 per-channel GEMM.
//
// Per the device epilogue in this file, the kernel computes
//   out[m, n] = (sum_k A[m, k] * W[n, k]) * wscales[n] * ascales[m]
//               - w_szs[n] * a_ssums[m]
// where A (_in_feats) is int8 activations, W (_kernel) is the quantized
// weight in int8 storage, and the per-channel / per-token vectors are fp16.
//
// Args:
//   _in_feats    [M, K] int8, contiguous, CUDA.
//   _kernel      [N, ...] int8, contiguous, CUDA. Only size(0) == N is
//                checked here; the packed K-dim layout is presumably
//                validated by the caller — TODO confirm.
//   _wscales     [N] fp16 per-output-channel weight scales.
//   _ascales     [M] fp16 per-token activation scales.
//   _w_szs       [N] fp16 per-channel (weight scale * zero) correction terms.
//   _a_ssums     [M] fp16 per-token activation sums.
//   _out_feats   [M, N] fp16 output, written in place.
//
// Tile configuration (CTA_*/WARP_*/STAGES) is selected from the batch size
// and spliced in via the file-local KERNEL_LAUNCH_CODE macro, which reads the
// locals declared below (in_feats, kernel, wscales, ..., stream).
void qserve_w4a8_per_chn_gemm(
    const torch::Tensor& _in_feats,
    const torch::Tensor& _kernel,
    const torch::Tensor& _wscales,
    const torch::Tensor& _ascales,
    const torch::Tensor& _w_szs,
    const torch::Tensor& _a_ssums,
    torch::Tensor& _out_feats) {
  // Check input tensor
  TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor");
  TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor");
  TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous");
  TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8");
  // Check kernel tensor
  TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor");
  TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor");
  TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous");
  TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8");
  // Check output tensor
  TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor");
  TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be contiguous");
  TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half");

  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  // size(-2)/size(-1) so _out_feats may carry leading batch dims that
  // flatten to the same M.
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  // Check matmul shape
  TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)");
  TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats");

  // Check _ascales
  TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor");
  TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous");
  TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half");
  TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements");

  // Check _wscales
  TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor");
  TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous");
  TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half");
  TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements");

  // Check _w_szs
  TORCH_CHECK(_w_szs.is_cuda(), "_w_szs must be a CUDA tensor");
  TORCH_CHECK(_w_szs.is_contiguous(), "_w_szs must be contiguous");
  TORCH_CHECK(_w_szs.scalar_type() == torch::kHalf, "_w_szs must be half");
  TORCH_CHECK(_w_szs.numel() == num_out_channels, "_w_szs must have num_out_channels elements");

  // Check _a_ssums
  TORCH_CHECK(_a_ssums.is_cuda(), "_a_ssums must be a CUDA tensor");
  TORCH_CHECK(_a_ssums.is_contiguous(), "_a_ssums must be contiguous");
  TORCH_CHECK(_a_ssums.scalar_type() == torch::kHalf, "_a_ssums must be half");
  TORCH_CHECK(_a_ssums.numel() == num_in_feats, "_a_ssums must have num_in_feats elements");

  // Raw pointers consumed by KERNEL_LAUNCH_CODE. wscales/w_szs are read as
  // half2 in the epilogue (two adjacent output channels per read), which
  // assumes num_out_channels is even — TODO confirm against callers.
  auto in_feats = reinterpret_cast<int8_t*>(_in_feats.data_ptr<int8_t>());
  auto kernel = reinterpret_cast<int8_t*>(_kernel.data_ptr<int8_t>());
  auto w_szs = reinterpret_cast<half2*>(_w_szs.data_ptr());
  auto a_ssums = reinterpret_cast<half*>(_a_ssums.data_ptr());
  auto wscales = reinterpret_cast<half2*>(_wscales.data_ptr());
  auto ascales = reinterpret_cast<half*>(_ascales.data_ptr());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device());

  auto sm_version = getSMVersion();
  if (sm_version >= 80) {
    constexpr int G = 128;

    // Tile-size heuristic keyed on batch size M (num_out_feats): large
    // batches get the big 128x128 CTA, small batches a shallower M tile
    // with a deeper K tile.
    if (num_out_feats > 256) {
      constexpr int CTA_M = 128;
      constexpr int CTA_N = 128;
      constexpr int CTA_K = 64;
      constexpr int WARP_M = 64;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 3;
      KERNEL_LAUNCH_CODE
    } else if (num_out_feats >= 128) {
      constexpr int CTA_M = 64;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 64;
      constexpr int WARP_M = 32;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 4;
      KERNEL_LAUNCH_CODE
    } else {
      constexpr int CTA_M = 32;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 128;
      constexpr int WARP_M = 32;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 3;
      KERNEL_LAUNCH_CODE
    }
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "No implemented qserve_w4a8_per_chn_gemm for current compute capability: ", sm_version);
  }
  return;
}
|
||||
795
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu
Normal file
795
sgl-kernel/csrc/gemm/qserve_w4a8_per_group_gemm.cu
Normal file
@@ -0,0 +1,795 @@
|
||||
// Implemented by Haotian Tang and Shang Yang.
|
||||
// @article{lin2024qserve,
|
||||
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
|
||||
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
|
||||
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
|
||||
// }
|
||||
// @article{yang2025lserve,
|
||||
// title={LServe: Efficient Long-sequence LLM Serving with Unified Sparse Attention},
|
||||
// author={Yang*, Shang and Guo*, Junxian and Tang, Haotian and Hu, Qinghao and Xiao, Guangxuan and Tang, Jiaming and
|
||||
// Lin, Yujun and Liu, Zhijian and Lu, Yao and Han, Song}, year={2025}
|
||||
// }
|
||||
|
||||
// Adapted from https://github.com/mit-han-lab/omniserve/blob/main/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Tensor-core tile shape used by the m16n8k32 s8 MMA path (see
// mma_m16n8k32 below).
#define OP_M 16
#define OP_N 8
#define OP_K 32
// Register-fragment tile shape each warp iterates over (two m16n8k32 ops
// cover one 16x16 INTRIN tile).
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
// Extra shared-memory padding per row for A/B; currently 0 (the A path uses
// XOR swizzling instead of padding to avoid bank conflicts).
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
// int8 elements moved per 128-bit (uint4) vectorized copy.
#define PACK_SIZE 16
||||
// Append an L2 cache-hint suffix to cp.async when the CUDA toolkit supports
// it (nvcc >= 11.4). The previous check `MAJOR >= 11 && MINOR >= 4` wrongly
// disabled the hint on CUDA 12.0-12.3 (minor version resets at a major bump),
// so compare (major, minor) lexicographically instead.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
|
||||
|
||||
// Expands inside qserve_w4a8_per_group_gemm once the CTA_*/WARP_*/STAGES/G
// constants are in scope. It: (1) computes the dynamic shared-memory budget
// (A + packed B tiles for all stages, plus the zeros/scales rows) and bails
// out with a message if it exceeds ~99 KB; (2) derives the swizzled 2-D grid
// via get_log_tile; (3) raises the kernel's dynamic smem limit and launches
// dense_kernel0 on `stream` with the pointer locals of the caller
// (in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats).
// NOTE: line comments cannot be added inside the macro body (// would
// swallow the continuation backslash), so it is documented here instead.
#define KERNEL_LAUNCH_CODE                                                                                  \
  constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);                         \
  constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
  constexpr int kSmemByteSize =                                                                             \
      ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) *     \
      sizeof(int8_t);                                                                                       \
  if (kSmemByteSize >= 99 * 1024) {                                                                         \
    printf(                                                                                                 \
        "This kernel requires %d Bytes of shared memory, which exceeds "                                    \
        "device limit.\n",                                                                                  \
        kSmemByteSize);                                                                                     \
    return;                                                                                                 \
  }                                                                                                         \
  int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M;                                                   \
  int num_blocks_n = num_out_channels / CTA_N / 1;                                                          \
  const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M);                                \
  const int tile_shift = 1 << log_tile;                                                                     \
  dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift);                  \
  dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                             \
  auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;                 \
  cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);            \
  kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(                                    \
      in_feats,                                                                                             \
      kernel,                                                                                               \
      zeros,                                                                                                \
      scales_i8,                                                                                            \
      wscales,                                                                                              \
      ascales,                                                                                              \
      out_feats,                                                                                            \
      num_in_feats,                                                                                         \
      num_out_channels,                                                                                     \
      num_in_channels);
||||
|
||||
// Pick log2 of the block-swizzle tile width (capped by N) for a grid with n
// CTA rows: returns 3, 2, 1 or 0 for n >= 6, >= 3, >= 2, else 0. Because the
// thresholds are nested, taking the last condition that holds is equivalent
// to the first-match if/else chain.
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
  int log_tile = 0;
  if (N >= 2 && n >= 2) log_tile = 1;
  if (N >= 4 && n >= 3) log_tile = 2;
  if (N >= 8 && n >= 6) log_tile = 3;
  return log_tile;
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// Undo the launch-time block swizzle: tile_shift consecutive blockIdx.x
// values map to one logical column and tile_shift logical rows, improving L2
// locality. Returns (logical_n, logical_m).
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
  const int tile_mask = (1 << log_tile) - 1;
  uint2 mapped;
  mapped.x = blockIdx_x >> log_tile;
  mapped.y = (blockIdx_y << log_tile) + (blockIdx_x & tile_mask);
  return mapped;
}
|
||||
|
||||
// Convert a generic-address shared-memory pointer into the 32-bit
// shared-window address (cvta.to.shared.u64 then truncate to u32) that
// ldmatrix / cp.async operands require.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) {
  uint32_t smem_int_ptr;

  asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
      "smem_ptr; }\n"
      : "=r"(smem_int_ptr)
      : "l"(ptr));

  return smem_int_ptr;
}
|
||||
|
||||
// Warp-collective ldmatrix.x4: loads four 8x8 b16 tiles from shared memory
// (32-bit shared address `addr`) into the 4 x 32-bit register fragment
// starting at shared_warp + ax0_0 * 16.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3])
      : "r"(addr));
}
|
||||
|
||||
// Transposed variant of ldmatrix_m8n8_x4_b16 (ldmatrix ... .trans): loads
// four 8x8 b16 tiles transposed into the register fragment at
// shared_warp + ax0_0 * 16.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr) {
  __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
      "{%0, %1, %2, %3}, [%4];"
      : "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[0]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[1]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[2]),
        "=r"(((unsigned*)(shared_warp + (ax0_0 * 16)))[3])
      : "r"(addr));
}
|
||||
|
||||
// function from lmdeploy
// Predicated 16-byte async copy global -> shared (cp.async.cg) with an
// optional L2 cache hint (see L2_CACHEHINT above). When `mask` is false the
// PTX predicate is off and the copy is skipped entirely.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4* __restrict__ src, bool mask) {
  const int cp_size = 16;  // bytes per cp.async transaction (one uint4)
  asm volatile("{"
               " .reg .pred p;"
               " setp.ne.b32 p, %0, 0;"
               " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
               "}" ::"r"((int)mask),
               "r"(smem_int_ptr),
               "l"(src),
               "n"(cp_size));
}
|
||||
|
||||
// One mma.sync.aligned.m16n8k32 tensor-core op: s8 A fragment (4 regs) x
// s8 B fragment (2 regs), accumulated in place into the 4 x s32 C fragment.
__device__ __inline__ void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp) {
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
      : "=r"(((int*)C_warp)[0]), "=r"(((int*)C_warp)[1]), "=r"(((int*)C_warp)[2]), "=r"(((int*)C_warp)[3])
      : "r"(((unsigned*)A_shared_warp)[0]),
        "r"(((unsigned*)A_shared_warp)[1]),
        "r"(((unsigned*)A_shared_warp)[2]),
        "r"(((unsigned*)A_shared_warp)[3]),
        "r"(((unsigned*)B_shared_warp)[0]),
        "r"(((unsigned*)B_shared_warp)[1]),
        "r"(((int*)C_warp)[0]),
        "r"(((int*)C_warp)[1]),
        "r"(((int*)C_warp)[2]),
        "r"(((int*)C_warp)[3]));
}
|
||||
|
||||
// Copies one CTA_M x CTA_K int8 A tile (k-tile index global_iter_k) from
// global memory into the shared-memory stage at `dst`, 16 bytes per thread
// per iteration. The work is split across SHARED_K_ITERS calls so the copy
// overlaps compute; this call issues only its `shared_iter_k` portion.
// `preds[i]` guards each row against the M boundary (set up by the caller).
// With STAGES > 1 the copy is asynchronous (cp.async); otherwise it is a
// plain predicated uint4 store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask,
    bool* preds) {
  constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
  constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int threads_per_row = CTA_K / PACK_SIZE;
  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
  // src/dst were pre-offset ("hoisted") by the caller to this thread's slot;
  // only the k-tile offset is applied here.
  int8_t* dst_hoisted = dst;
  int8_t* src_hoisted = src + global_iter_k * CTA_K;

  if (mask) {
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) {
      int global_iter = shared_iter_k * partial_global_iters + _global_iter;
      void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
      uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
      if constexpr (STAGES > 1) {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, preds[global_iter]);
      } else {
        if (preds[global_iter]) *(uint4*)dst_ptr = *src_ptr;
      }
    }
  }
}
|
||||
|
||||
// Copies one packed-int4 B tile (k-tile index global_iter_k) from global
// memory into the shared-memory stage at `dst`. B is laid out in groups of
// 32 output channels (PACK_SIZE bytes each), so strides are scaled by
// PACK_SIZE; unlike the A path there is no per-row boundary predicate — the
// whole copy is gated by `mask` only (N is presumably a multiple of CTA_N —
// the launcher divides num_out_channels by CTA_N exactly). With STAGES > 1
// the copy is asynchronous (cp.async); otherwise a plain uint4 store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
  constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
  constexpr int warps_per_row = CTA_K / 32;
  constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
  constexpr int kSmemCol = CTA_K;
  // src/dst were pre-offset by the caller; only the k-tile offset is applied.
  int8_t* dst_hoisted = dst;
  int8_t* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;

#pragma unroll
  for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) {
    void* dst_ptr = (void*)(dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
    uint4* src_ptr = (uint4*)(src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
    if constexpr (STAGES > 1) {
      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
      cp_async_cg_A(addr, src_ptr, mask);
    } else {
      if (mask) *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Copies one row of CTA_N per-group int8 values (zero points, or the int8
// second-level scales — the caller invokes this twice) for the group that
// k-tile global_iter_k falls into (g_idx = global_iter_k * CTA_K / G) into
// the shared-memory stage at `dst`. Only the first `threads_used` threads of
// the CTA participate, each moving 16 bytes. With STAGES > 1 the copy is
// asynchronous (cp.async); otherwise a plain uint4 store.
//
// Fix vs. original: the STAGES dispatch used a runtime `if` on the
// compile-time constant, inconsistent with global_to_share_one_stage_A/B;
// it is now `if constexpr` like its siblings (behavior unchanged).
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_zeros(
    int8_t* src,
    int8_t* dst,
    int global_ncols,
    int cta_offset_m,
    int cta_offset_n,
    int global_iter_k,
    int shared_iter_k,
    bool mask) {
  constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
  constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
  constexpr int threads_per_row = CTA_N / PACK_SIZE;
  constexpr int kSmemCol = CTA_N;
  // Bitwise & on bools is intentional here (both operands already evaluated).
  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
  int g_idx = global_iter_k * CTA_K / G;  // quantization-group row index

  void* dst_ptr = (void*)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
  uint4* src_ptr = (uint4*)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
  if constexpr (STAGES > 1) {
    uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
    cp_async_cg_A(addr, src_ptr, local_mask);
  } else {
    if (local_mask) {
      *(uint4*)dst_ptr = *src_ptr;
    }
  }
}
|
||||
|
||||
// Loads this warp's A fragments for k-slice k_0_1 from shared memory into
// registers via ldmatrix, one 16x32 INTRIN tile per shared_iter. The column
// index is XOR-swizzled to undo the store-side swizzle and avoid bank
// conflicts; note `&` binds tighter than `^`, so this is
// ld_col ^ ((ld_row / 2) & 3) — intended.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ __inline__ void
share_to_reg_one_stage_A(int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
  // Column (in PACK_SIZE units) this lane reads: 16 lanes per 8x8 tile row.
  int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;

  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
    int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
    int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
    void* addr_ptr = (void*)(src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
    uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
    ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
  }
}
|
||||
|
||||
// Loads this warp's B fragments for k-slice k_0_1 from shared memory and
// dequantizes them on the fly: each uint4 holds 32 packed 4-bit weights,
// which are split into low/high nibbles (eight uint32 lanes of 4 x u4 each)
// and then mapped to int8 as w8 = w4 * scale_i8 + zero_i8 using SIMD-in-word
// byte multiply/add (__vadd4). scales_i8/zeros hold one byte per output
// channel for the current quantization group; __byte_perm replicates the
// channel's byte across all four lanes of the word.
// NOTE(review): the nibble-index interleave (0,4,2,6 from .x/.y and 1,5,3,7
// from .z/.w) matches the fragment ordering expected by mma_m16n8k32 —
// presumably fixed by the offline weight packing; do not reorder.
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void share_to_reg_one_stage_B(
    int8_t* src,
    int8_t* dst,
    int8_t* zeros,
    int8_t* scales_i8,
    int warp_offset_m,
    int warp_offset_n,
    int k_0_0,
    int k_0_1,
    int shared_iters) {
  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
#pragma unroll
  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) {
    // 16 bytes = 32 packed int4 weights for this lane (uint4-typed pointer
    // arithmetic: offsets are in 16-byte units).
    uint4 loaded =
        *((uint4*)(src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol + k_0_1 * INTRIN_K +
          threadIdx.x);
    // Split each 32-bit word into its low nibbles (even fragment) and high
    // nibbles (odd fragment), 4 weights per uint32 lane.
    uint32_t loaded_0 = loaded.x & 0x0F0F0F0F;
    uint32_t loaded_4 = (loaded.x & 0xF0F0F0F0) >> 4;
    uint32_t loaded_2 = loaded.y & 0x0F0F0F0F;
    uint32_t loaded_6 = (loaded.y & 0xF0F0F0F0) >> 4;
    uint32_t loaded_1 = loaded.z & 0x0F0F0F0F;
    uint32_t loaded_5 = (loaded.z & 0xF0F0F0F0) >> 4;
    uint32_t loaded_3 = loaded.w & 0x0F0F0F0F;
    uint32_t loaded_7 = (loaded.w & 0xF0F0F0F0) >> 4;

    auto ptr = (uint32_t*)dst + shared_iter * 8;
    // Four consecutive output channels' scale/zero bytes for this lane group.
    int scales_zeros_offset = warp_offset_n + (threadIdx.x / 4) * 4 + shared_iter * 32;
    uint32_t packed_scales = *reinterpret_cast<uint32_t*>(scales_i8 + scales_zeros_offset);
    uint32_t packed_zeros = *reinterpret_cast<uint32_t*>(zeros + scales_zeros_offset);

    // Channel 0: byte-multiply by its scale, then add the replicated zero.
    uint32_t scale_0 = packed_scales & 0xFF;
    uint32_t zero_point_0 = __byte_perm(packed_zeros, 0, 0x00000000);
    uint32_t ptr_0 = loaded_0 * scale_0;
    uint32_t ptr_1 = loaded_1 * scale_0;
    ptr[0] = __vadd4(ptr_0, zero_point_0);
    ptr[1] = __vadd4(ptr_1, zero_point_0);

    // Channel 1.
    uint32_t scale_1 = (packed_scales & 0xFF00) >> 8;
    uint32_t zero_point_1 = __byte_perm(packed_zeros, 0, 0x00001111);
    uint32_t ptr_2 = loaded_2 * scale_1;
    uint32_t ptr_3 = loaded_3 * scale_1;
    ptr[2] = __vadd4(ptr_2, zero_point_1);
    ptr[3] = __vadd4(ptr_3, zero_point_1);

    // Channel 2.
    uint32_t scale_2 = (packed_scales & 0xFF0000) >> 16;
    uint32_t zero_point_2 = __byte_perm(packed_zeros, 0, 0x00002222);
    uint32_t ptr_4 = loaded_4 * scale_2;
    uint32_t ptr_5 = loaded_5 * scale_2;
    ptr[4] = __vadd4(ptr_4, zero_point_2);
    ptr[5] = __vadd4(ptr_5, zero_point_2);

    // Channel 3.
    uint32_t scale_3 = (packed_scales & 0xFF000000) >> 24;
    uint32_t zero_point_3 = __byte_perm(packed_zeros, 0, 0x00003333);
    uint32_t ptr_6 = loaded_6 * scale_3;
    uint32_t ptr_7 = loaded_7 * scale_3;
    ptr[6] = __vadd4(ptr_6, zero_point_3);
    ptr[7] = __vadd4(ptr_7, zero_point_3);
  }
}
|
||||
|
||||
// QServe W4A8 per-group GEMM main kernel (SM80+).
//
// C[M, N] (fp16) = dequant(A[M, K] int8 @ B packed-int4), where B's 4-bit
// weights are expanded to int8 with per-group (G) int8 scales / zero points
// during the shared->register stage, accumulated in int32 m16n8k32 tensor-core
// tiles, and finally scaled by per-channel wscales and per-token ascales in
// the epilogue. Uses a STAGES-deep cp.async software pipeline; CTA_K / WARP_K
// slices split K inside the CTA and are reduced through shared memory.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,
    int8_t* __restrict__ B,
    int8_t* __restrict__ zeros,
    int8_t* __restrict__ scales_i8,
    half2* __restrict__ wscales,
    half* __restrict__ ascales,
    half* __restrict__ C,
    int M,
    int64_t N,
    int64_t K) {
  constexpr int SPLITK = 1;
  constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
  constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
  constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
  constexpr int SLICES = CTA_K / WARP_K;  // intra-CTA K slices
  int num_blocks_n = (N + CTA_N - 1) / CTA_N;
  int num_blocks_m = (M + CTA_M - 1) / CTA_M;

  // Undo the launch-time block swizzle (see KERNEL_LAUNCH_CODE).
  int blockIdx_n = blockIdx.x;
  int blockIdx_m = blockIdx.y;
  const int log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
  blockIdx_n = block_idx_mapping.x;
  blockIdx_m = block_idx_mapping.y;

  // Per-thread int32 accumulator fragments.
  int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
  constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;  // int4: 2 weights/byte
  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;

  constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
  constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
  constexpr int kSmemSizeScales = CTA_N * STAGES;

  // Dynamic shared memory layout: [A stages | B stages | zeros | i8 scales].
  extern __shared__ int8_t mem_shared[];
  int8_t* A_shared = mem_shared;

  int8_t* B_shared = mem_shared + kSmemSizeA;
  int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
  int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;

  // Double-buffered register fragments (ping-pong across k sub-iterations).
  int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
  int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
  constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
  constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
  constexpr int A_threads_per_row = CTA_K / PACK_SIZE;

  constexpr int B_warps_per_row = CTA_K / 32;
  constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;

  int cta_offset_m = blockIdx_m * CTA_M;
  int cta_offset_n = blockIdx_n * CTA_N;
  int warp_mn = threadIdx.y % NUM_WARPS_MN;
  int slice_id = threadIdx.y / NUM_WARPS_MN;
  int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
  int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
  int warp_offset_k = slice_id * WARP_K;

  for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
    C_warp[i] = 0;

  int gemm_iters = (K + CTA_K - 1) / CTA_K;

  int k_0_0_ld = 0;  // k-tile being loaded (runs STAGES-1 ahead)
  int k_0_0 = 0;     // k-tile being computed
  constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
  // "Hoisted" per-thread base offsets for the global->shared A copy; the
  // shared-side column is XOR-swizzled to avoid bank conflicts.
  int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
  int A_hoisted_col = (threadIdx.x % A_threads_per_row);
  int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;

  int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
  int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                             (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
  int8_t* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
  int8_t* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE +
                      (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;

  // M-boundary predicates for the A copy, computed once.
  bool A_g2s_preds[A_total_global_iters];
#pragma unroll
  for (int i = 0; i < A_total_global_iters; i++) {
    A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
  }

  // After the main loop, shared memory is reused as an int32 scratchpad for
  // the cross-slice reduction.
  int* C_shared = reinterpret_cast<int*>(mem_shared);

  // Prologue: fill the first STAGES-1 pipeline stages.
#pragma unroll
  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
    global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        A_hoisted,
        A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage,
        K,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        true,
        A_g2s_preds);
    global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
        B_hoisted, B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
    global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
        zeros, zeros_shared + (k_0_0_ld)*CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters);
    global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
        scales_i8,
        scales_i8_shared + (k_0_0_ld)*CTA_N,
        N,
        cta_offset_m,
        cta_offset_n,
        k_0_0_ld,
        0,
        k_0_0_ld < gemm_iters);

    if constexpr (STAGES > 1) __pipeline_commit();
  }
  if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2);
  __syncthreads();

  // Preload the first register fragments for stage 0.
  share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
      A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
  share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
      B_shared + warp_offset_k * PACK_SIZE,
      B_shared_warp_[0],
      zeros_shared,
      scales_i8_shared,
      warp_offset_m,
      warp_offset_n,
      0,
      0,
      WARP_N / 32);
  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

  // Main pipelined loop over k-tiles.
  for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
    int ld_stage = k_0_0_ld % STAGES;
    int compute_stage = k_0_0 % STAGES;
    int8_t* A_shared_this_compute_stage;
    int8_t* B_shared_this_compute_stage;
    int8_t* zeros_shared_this_compute_stage;
    int8_t* scales_i8_shared_this_compute_stage;

    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
      A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
      B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
      zeros_shared_this_compute_stage = zeros_shared + (compute_stage)*CTA_N;
      scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage)*CTA_N;

      // Prefetch the NEXT k sub-iteration's fragments into the other
      // register buffer while computing on the current one.
      share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
          A_shared_this_compute_stage,
          A_shared_warp_[(iter_k + 1) % 2],
          warp_offset_m,
          warp_offset_n,
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_M / INTRIN_M);
      share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
          B_shared_this_compute_stage,
          B_shared_warp_[(iter_k + 1) % 2],
          zeros_shared_this_compute_stage,
          scales_i8_shared_this_compute_stage,
          warp_offset_m,
          warp_offset_n,
          k_0_0 + (iter_k == SHARED_K_ITERS - 1),
          (iter_k + 1) % SHARED_K_ITERS,
          WARP_N / 32);
      int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
      int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];

      // Tensor-core compute: two m16n8k32 ops cover each 16x16 tile.
      for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
        for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16));
          mma_m16n8k32(
              (void*)(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
              (void*)(A_shared_warp + i_0_3 * 16),
              (void*)(B_shared_warp + j_0_4 * 16 + 8));
        }
      }

      // Overlap: issue this sub-iteration's share of the NEXT stage's
      // global->shared copies.
      if (iter_k < SHARED_K_ITERS - 1) {
        if constexpr (STAGES == 1) __syncthreads();
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
      }

      // One sub-iteration before the stage ends: issue the final copy slice
      // plus the next stage's zeros/scales rows, then rotate the pipeline.
      if (iter_k == SHARED_K_ITERS - 2) {
        if constexpr (STAGES == 1 && SHARED_K_ITERS > 2) {
          __syncthreads();
        }
        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            A_hoisted,
            A_shared_hoisted + ld_stage * kSmemSizeAPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters,
            A_g2s_preds);
        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
            B_hoisted,
            B_shared_hoisted + ld_stage * kSmemSizeBPerStage,
            K,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k + 1,
            k_0_0_ld < gemm_iters);
        global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            zeros,
            zeros_shared + (ld_stage)*CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
        global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            scales_i8,
            scales_i8_shared + (ld_stage)*CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            iter_k,
            k_0_0_ld < gemm_iters);
        if constexpr (STAGES > 1) {
          __pipeline_commit();
          __pipeline_wait_prior(STAGES - 2);
        }
        compute_stage = (k_0_0 + 1) % STAGES;
        __syncthreads();
      }
    }
  }
  // Drain the async-copy pipeline before reusing shared memory.
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  // Cross-slice reduction: each K-slice in turn adds its partials onto what
  // the previous slice left in C_shared (slice 0 just stores).
  if constexpr (SLICES > 1) {
#pragma unroll
    for (int z = 0; z < SLICES; ++z) {
      if (slice_id == z) {
#pragma unroll
        for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
          for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
              if (z > 0) {
                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared
                    [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                     (threadIdx.x % 4) * 2];
              }
              C_shared
                  [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                   ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                   (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
            };
          }
        }
      }
      __syncthreads();
    }
    // Slice 0 reads the fully-reduced sums back into its registers.
    if (slice_id == 0) {
#pragma unroll
      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
#pragma unroll
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
#pragma unroll
          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
            C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared
                [warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                 ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) +
                 (threadIdx.x % 4) * 2];
          };
        }
      }
    }
  }

  // Epilogue (slice 0 only): dequantize int32 sums with per-channel wscales
  // (read as half2, two adjacent columns at a time) and per-token ascales,
  // then store as half2. Rows are bounds-checked against M; columns are not
  // (N is presumably a multiple of CTA_N — see the launcher's grid math).
  int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
  int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
  if (slice_id == 0) {
    for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
      int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
      for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
        int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
        int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
          int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
          if (row_wb < M) {
            int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
            float2 wscale = __half22float2(*(wscales + col_wb / 2));
            float ascale = __half2float(ascales[row_wb]);
            float2 psums =
                make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
            psums.x *= wscale.x * ascale;
            psums.y *= wscale.y * ascale;
            *reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
          }
        };
      }
    }
  }
}
|
||||
#else
|
||||
// Fallback instantiation compiled when the architecture-guarded path above is
// excluded (SM < 800). It exists only so the translation unit still links on
// older targets; launching it is a programming error and traps immediately.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(
    int8_t* __restrict__ A,          // int8 activations
    int8_t* __restrict__ B,          // packed 4-bit weights (two per byte)
    int8_t* __restrict__ zeros,      // per-group zero points
    int8_t* __restrict__ scales_i8,  // per-group int8 scales
    half2* __restrict__ wscales,     // per-channel weight scales, paired
    half* __restrict__ ascales,      // per-token activation scales
    half* __restrict__ C,            // fp16 output
    int M,
    int64_t N,
    int64_t K) {
  // Not implemented for SM < 800
  assert(false);
  return;
}
|
||||
#endif
|
||||
|
||||
// Host-side entry point for the QServe W4A8 per-group (group_size = 128) GEMM.
// Validates every tensor argument (device, layout, dtype, shape), then selects
// a CTA/warp tile configuration from the problem shape and launches the kernel
// via the KERNEL_LAUNCH_CODE macro (presumably expanding to a dense_kernel0
// launch defined earlier in this file -- confirm against the macro). The fp16
// result is written in place into _out_feats.
void qserve_w4a8_per_group_gemm(
    const torch::Tensor& _in_feats,   // [M, K] int8 activations
    const torch::Tensor& _kernel,     // packed int8 weights, size(0) == N
    const torch::Tensor& _zeros,      // [K/G, N] int8 per-group zero points
    const torch::Tensor& _scales_i8,  // [K/G, N] int8 per-group scales
    const torch::Tensor& _wscales,    // [N] fp16 per-channel weight scales
    const torch::Tensor& _ascales,    // [M] fp16 per-token activation scales
    torch::Tensor& _out_feats) {      // [M, N] fp16 output, written in place
  // Check input tensor
  TORCH_CHECK(_in_feats.is_cuda(), "_in_feats must be a CUDA tensor");
  TORCH_CHECK(_in_feats.dim() == 2, "_in_feats must be a 2D tensor");
  TORCH_CHECK(_in_feats.is_contiguous(), "_in_feats must be contiguous");
  TORCH_CHECK(_in_feats.scalar_type() == torch::kInt8, "_in_feats must be int8");
  // Check kernel tensor
  TORCH_CHECK(_kernel.is_cuda(), "_kernel must be a CUDA tensor");
  TORCH_CHECK(_kernel.dim() == 2, "_kernel must be a 2D tensor");
  TORCH_CHECK(_kernel.is_contiguous(), "_kernel must be contiguous");
  TORCH_CHECK(_kernel.scalar_type() == torch::kInt8, "_kernel must be int8");
  // Check output tensor
  TORCH_CHECK(_out_feats.is_cuda(), "_out_feats must be a CUDA tensor");
  TORCH_CHECK(_out_feats.is_contiguous(), "_out_feats must be contiguous");
  TORCH_CHECK(_out_feats.scalar_type() == torch::kHalf, "_out_feats must be half");

  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  // Check matmul shape
  TORCH_CHECK(num_out_channels == _kernel.size(0), "num_out_channels must be equal to _kernel.size(0)");
  TORCH_CHECK(num_in_feats == num_out_feats, "num_in_feats must be equal to num_out_feats");

  // Check _ascales
  TORCH_CHECK(_ascales.is_cuda(), "_ascales must be a CUDA tensor");
  TORCH_CHECK(_ascales.is_contiguous(), "_ascales must be contiguous");
  TORCH_CHECK(_ascales.scalar_type() == torch::kHalf, "_ascales must be half");
  TORCH_CHECK(_ascales.numel() == num_in_feats, "_ascales must have num_in_feats elements");

  // Check _wscales
  TORCH_CHECK(_wscales.is_cuda(), "_wscales must be a CUDA tensor");
  TORCH_CHECK(_wscales.is_contiguous(), "_wscales must be contiguous");
  TORCH_CHECK(_wscales.scalar_type() == torch::kHalf, "_wscales must be half");
  TORCH_CHECK(_wscales.numel() == num_out_channels, "_wscales must have num_out_channels elements");

  // Check _scales_i8
  TORCH_CHECK(_scales_i8.is_cuda(), "_scales_i8 must be a CUDA tensor");
  TORCH_CHECK(_scales_i8.dim() == 2, "_scales_i8 must be a 2D tensor");
  TORCH_CHECK(_scales_i8.is_contiguous(), "_scales_i8 must be contiguous");
  TORCH_CHECK(_scales_i8.scalar_type() == torch::kInt8, "_scales_i8 must be int8");
  TORCH_CHECK(num_in_channels % _scales_i8.size(0) == 0, "num_in_channels must be divisible by _scales_i8.size(0)");
  TORCH_CHECK(num_out_channels == _scales_i8.size(1), "num_out_channels must be equal to _scales_i8.size(1)");

  // Check _zeros
  TORCH_CHECK(_zeros.is_cuda(), "_zeros must be a CUDA tensor");
  TORCH_CHECK(_zeros.dim() == 2, "_zeros must be a 2D tensor");
  TORCH_CHECK(_zeros.is_contiguous(), "_zeros must be contiguous");
  TORCH_CHECK(_zeros.scalar_type() == torch::kInt8, "_zeros must be int8");
  TORCH_CHECK(num_in_channels % _zeros.size(0) == 0, "num_in_channels must be divisible by _zeros.size(0)");
  TORCH_CHECK(num_out_channels == _zeros.size(1), "num_out_channels must be equal to _zeros.size(1)");

  // Check group size
  // Group count is inferred from the scale tensor; only G = 128 is compiled below.
  auto group_size = num_in_channels / _scales_i8.size(0);
  TORCH_CHECK(group_size == 128, "group_size must be 128");

  auto in_feats = reinterpret_cast<int8_t*>(_in_feats.data_ptr<int8_t>());
  auto kernel = reinterpret_cast<int8_t*>(_kernel.data_ptr<int8_t>());
  auto zeros = reinterpret_cast<int8_t*>(_zeros.data_ptr<int8_t>());
  auto scales_i8 = reinterpret_cast<int8_t*>(_scales_i8.data_ptr<int8_t>());
  auto wscales = reinterpret_cast<half2*>(_wscales.data_ptr());
  auto ascales = reinterpret_cast<half*>(_ascales.data_ptr());
  // auto options =
  //     torch::TensorOptions().dtype(torch::kHalf).device(_in_feats.device());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto stream = at::cuda::getCurrentCUDAStream(_in_feats.get_device());
  auto sm_version = getSMVersion();
  if (sm_version >= 80) {
    constexpr int G = 128;  // must match the runtime group_size check above

    // Tile configurations selected by output-row count (batch) and K depth.
    if (num_out_feats > 128) {
      constexpr int CTA_M = 128;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 64;
      constexpr int WARP_M = 64;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 4;
      KERNEL_LAUNCH_CODE
    } else if (num_out_feats >= 128) {
      // NOTE(review): given the `> 128` branch above, this branch fires only
      // when num_out_feats == 128.
      if (num_in_channels <= 4096) {
        constexpr int CTA_M = 64;
        constexpr int CTA_N = 64;
        constexpr int CTA_K = 64;
        constexpr int WARP_M = 32;
        constexpr int WARP_N = 32;
        constexpr int WARP_K = 64;
        constexpr int STAGES = 4;
        KERNEL_LAUNCH_CODE
      } else {
        // Deep K: widen CTA_K and drop a pipeline stage.
        constexpr int CTA_M = 64;
        constexpr int CTA_N = 64;
        constexpr int CTA_K = 128;
        constexpr int WARP_M = 32;
        constexpr int WARP_N = 32;
        constexpr int WARP_K = 64;
        constexpr int STAGES = 3;
        KERNEL_LAUNCH_CODE
      }
    } else {
      // Small batch (< 128 rows).
      constexpr int CTA_M = 32;
      constexpr int CTA_N = 64;
      constexpr int CTA_K = 128;
      constexpr int WARP_M = 32;
      constexpr int WARP_N = 32;
      constexpr int WARP_K = 64;
      constexpr int STAGES = 3;
      KERNEL_LAUNCH_CODE
    }
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "No implemented qserve_w4a8_per_group_gemm for current compute capability: ", sm_version);
  }
  return;
}
|
||||
@@ -404,3 +404,24 @@ void convert_vertical_slash_indexes_mergehead(
|
||||
* From XGrammar
|
||||
*/
|
||||
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
|
||||
|
||||
/*
 * From QServe
 */
// W4A8 GEMM with per-output-channel weight quantization.
// Implemented in csrc/gemm/qserve_w4a8_per_chn_gemm.cu.
void qserve_w4a8_per_chn_gemm(
    const torch::Tensor& _in_feats,  // int8 activations
    const torch::Tensor& _kernel,    // packed int8 weights
    const torch::Tensor& _wscales,   // per-channel weight scales (half)
    const torch::Tensor& _ascales,   // per-token activation scales (half)
    const torch::Tensor& _w_szs,     // per-channel scaled zero points (half)
    const torch::Tensor& _a_ssums,   // per-token activation row sums (half)
    torch::Tensor& _out_feats);      // half output, written in place

// W4A8 GEMM with per-group (group_size = 128) weight quantization.
// Implemented in csrc/gemm/qserve_w4a8_per_group_gemm.cu.
void qserve_w4a8_per_group_gemm(
    const torch::Tensor& _in_feats,   // int8 activations
    const torch::Tensor& _kernel,     // packed int8 weights
    const torch::Tensor& _zeros,      // per-group zero points (int8, pre-scaled)
    const torch::Tensor& _scales_i8,  // per-group int8 scales
    const torch::Tensor& _wscales,    // per-channel weight scales (half)
    const torch::Tensor& _ascales,    // per-token activation scales (half)
    torch::Tensor& _out_feats);       // half output, written in place
|
||||
|
||||
@@ -36,6 +36,8 @@ from sgl_kernel.gemm import (
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
|
||||
@@ -197,3 +197,47 @@ def scaled_fp4_quant(
|
||||
)
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 GEMM with per-channel weight quantization.

    Dispatches to the CUDA op, which fills ``out_feats`` in place and leaves
    the return value as that same tensor. When ``out_feats`` is omitted, a
    fresh ``(M, N)`` buffer is allocated on the input's device.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        rows, cols = in_feats.shape[0], kernel.shape[0]
        out_feats = torch.empty((rows, cols), device=in_feats.device, dtype=torch.float16)
    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
    )
    return out_feats
|
||||
|
||||
|
||||
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 GEMM with per-group (group_size = 128) weight quantization.

    Dispatches to the CUDA op, which fills ``out_feats`` in place and leaves
    the return value as that same tensor. When ``out_feats`` is omitted, a
    fresh ``(M, N)`` buffer is allocated on the input's device.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        rows, cols = in_feats.shape[0], kernel.shape[0]
        out_feats = torch.empty((rows, cols), device=in_feats.device, dtype=torch.float16)
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats
|
||||
|
||||
118
sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
Normal file
118
sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import qserve_w4a8_per_chn_gemm
|
||||
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
|
||||
def convert_to_qserve_format(qweight, scale, zero):
    """Repack a UINT4 (stored in int8) weight into QServe's on-device layout.

    Returns the nibble-packed weight of shape (out_features, in_features // 2),
    the fp16 per-channel scale, and the fp16 scaled zero point (zero * scale).
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    out_features, in_features = qweight.shape
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"

    # Split into 32x32 tiles reordered for the kernel's load pattern:
    # M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    tiled = qweight.reshape(
        out_features // 32, 2, 2, 8, in_features // 32, 2, 4, 4
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous()
    # Move the pair axis last so two 4-bit codes can be fused into one byte.
    tiled = tiled.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous().to(torch.int8)
    # Interleaved nibble packing: [16, 0, 17, 1, ...]
    packed = (tiled[..., 1] << 4) + tiled[..., 0]
    packed = packed.reshape(out_features // 32, in_features // 32, 32, 16)
    packed = packed.reshape(out_features, in_features // 2).contiguous()

    # Per-channel fp16 scale and scaled zero point.
    scale = scale.reshape(out_features).to(torch.float16).contiguous()
    szero = zero.reshape(out_features).to(torch.float16).contiguous() * scale

    return packed, scale, szero
|
||||
|
||||
|
||||
# INT4 Quantization
|
||||
def asym_quantize_tensor(tensor):
    """Asymmetric per-row UINT4 quantization.

    Returns (codes int8 in [0, 15], fp16 scale, int8 zero point), one
    scale/zero per row (last dim reduced with keepdim).
    """
    lo = tensor.min(dim=-1, keepdim=True)[0]
    hi = tensor.max(dim=-1, keepdim=True)[0]
    # 4-bit unsigned target range is [0, 15].
    step = (hi - lo) / 15
    zero_point = -torch.round(lo / step)
    codes = (torch.round(tensor / step) + zero_point).clamp(0, 15).to(torch.int8)
    return codes, step.to(torch.float16), zero_point.to(torch.int8)
|
||||
|
||||
|
||||
# INT8 Quantization
|
||||
def sym_quantize_tensor(tensor):
    """Symmetric per-row INT8 quantization; returns (int8 codes, fp16 scale)."""
    step = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    codes = torch.round(tensor / step).clamp(-128, 127).to(torch.int8)
    return codes, step.to(torch.float16)
|
||||
|
||||
|
||||
def torch_w4a8_per_chn_gemm(a, b, a_scale, b_scale, b_zero, out_dtype):
|
||||
print(a.shape)
|
||||
print(b.shape)
|
||||
print(b_zero.shape)
|
||||
o = torch.matmul(
|
||||
a.to(torch.float16), (b.to(torch.float16) - b_zero.to(torch.float16)).t()
|
||||
)
|
||||
o = o * a_scale.view(-1, 1) * b_scale.view(1, -1)
|
||||
return o.to(out_dtype)
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Compare the CUDA per-channel W4A8 kernel against the torch reference."""
    # Small magnitudes keep accumulated values comfortably in range.
    act = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
    wgt = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01

    # Activations: symmetric INT8. Weights: asymmetric UINT4 per channel.
    act_q, act_scale = sym_quantize_tensor(act)
    wgt_q, wgt_scale, wgt_zero = asym_quantize_tensor(wgt)
    packed_w, packed_scale, packed_szero = convert_to_qserve_format(
        wgt_q, wgt_scale, wgt_zero
    )

    # Per-row activation sums, part of the kernel's input contract.
    act_row_sums = act.sum(dim=-1, keepdim=True).to(torch.float16)
    out = qserve_w4a8_per_chn_gemm(
        act_q, packed_w, packed_scale, act_scale, packed_szero, act_row_sums
    )
    ref = torch_w4a8_per_chn_gemm(act_q, wgt_q, act_scale, wgt_scale, wgt_zero, out_dtype)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
# All N/K values are multiples of 32, as required by convert_to_qserve_format.
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """End-to-end accuracy sweep of the per-channel W4A8 kernel on CUDA."""
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
183
sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
Normal file
183
sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import qserve_w4a8_per_group_gemm
|
||||
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
|
||||
def convert_to_qserve_format(qweight, chn_scale, scale_i8, zero_i8, group_size):
    """Repack a per-group UINT4 weight into QServe's on-device layout.

    Returns the nibble-packed weight (out_features, in_features // 2), the
    per-channel scale, the channel-reordered per-group int8 scale, and the
    per-group zero point (negated and pre-multiplied by the group scale).
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    out_features, in_features = qweight.shape
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"
    assert group_size == 128, "Group size must be 128"
    assert (
        in_features % group_size == 0
    ), "Input features must be divisible by group_size"
    num_groups = in_features // group_size

    def _reorder_channels(t):
        # (out, groups) -> (groups, out) with each 32-channel tile interleaved
        # as 0, 8, 16, 24, 1, 9, 17, 25, ... to match the tensor-core gemm.
        t = t.reshape(out_features, num_groups).transpose(0, 1).contiguous()
        t = t.reshape(num_groups, out_features // 32, 4, 8).transpose(-2, -1).contiguous()
        return t.reshape(num_groups, out_features).contiguous()

    # ---- Repack the weight (same tiling as the per-channel path) ---- #
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    tiled = qweight.reshape(
        out_features // 32, 2, 2, 8, in_features // 32, 2, 4, 4
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous()
    tiled = tiled.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous().to(torch.int8)
    # Interleaved nibble packing: [16, 0, 17, 1, ...]
    packed = (tiled[..., 1] << 4) + tiled[..., 0]
    packed = packed.reshape(out_features // 32, in_features // 32, 32, 16)
    packed = packed.reshape(out_features, in_features // 2)

    # ---- Pack the scales ---- #
    chn_scale = chn_scale.reshape(out_features)
    scale_i8 = _reorder_channels(scale_i8)

    # ---- Pack the zeros (negated and folded with the group scale) ---- #
    zero_i8 = _reorder_channels(-zero_i8) * scale_i8

    return packed, chn_scale, scale_i8, zero_i8
|
||||
|
||||
|
||||
# Progressive Group INT4 Quantization
|
||||
def progressive_group_quantize_tensor(tensor, group_size):
    """Progressive (two-level) quantization used by QServe W4A8 weights.

    Level 1: symmetric per-channel INT8 with a protective +/-119 range.
    Level 2: asymmetric per-group UINT4 on top of the INT8 codes.

    Args:
        tensor: 2D float weight whose last dim is divisible by ``group_size``.
        group_size: quantization group width; only 128 is supported.

    Returns:
        (tensor_q, chn_scale, scale_i8, zero_i8): UINT4 codes stored in int8,
        fp16 per-channel scale, and int8 per-group scale / zero point.
    """
    assert group_size == 128, "Group size must be 128"
    assert (
        tensor.shape[-1] % group_size == 0
    ), "Input features must be divisible by group_size"
    # Channel scale
    # NOTE(HandH1998): use protective quantization range
    # NOTE(review): an all-zero channel would still give chn_scale == 0; the
    # callers feed random data, so that degenerate case is not handled here.
    chn_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 119
    tensor_i8 = torch.clamp(torch.round(tensor / chn_scale), -119, 119)

    # Group scale
    tensor_i8 = tensor_i8.reshape(-1, group_size)
    tensor_i8_min = tensor_i8.min(dim=-1, keepdim=True)[0]
    tensor_i8_max = tensor_i8.max(dim=-1, keepdim=True)[0]
    q_min = 0
    q_max = 15
    scale_i8 = torch.round((tensor_i8_max - tensor_i8_min) / (q_max - q_min))
    # Fix: a flat or low-spread group rounds to a zero scale, and the divisions
    # below then produce inf/NaN. The smallest usable integer step is 1.
    scale_i8 = torch.clamp(scale_i8, min=1)
    zero_i8 = q_min - torch.round(tensor_i8_min / scale_i8)
    tensor_q = (
        torch.clamp(torch.round(tensor_i8 / scale_i8) + zero_i8, q_min, q_max)
        .reshape(tensor.shape[0], -1)
        .to(torch.int8)
    )
    return (
        tensor_q,
        chn_scale.to(torch.float16),
        scale_i8.reshape(tensor.shape[0], -1).to(torch.int8),
        zero_i8.reshape(tensor.shape[0], -1).to(torch.int8),
    )
|
||||
|
||||
|
||||
# INT8 Quantization
|
||||
def sym_quantize_tensor(tensor):
    """Symmetric per-row INT8 quantization; returns (int8 codes, fp16 scale)."""
    step = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    codes = torch.round(tensor / step).clamp(-128, 127).to(torch.int8)
    return codes, step.to(torch.float16)
|
||||
|
||||
|
||||
def torch_w4a8_per_group_gemm(
    a, b, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
):
    """Reference per-group W4A8 GEMM computed with plain torch ops."""
    assert group_size == 128, "Group size must be 128"
    # Dequantize the weight group by group: (code - zero) * group_scale.
    groups = b.reshape(-1, group_size).to(torch.float32)
    zeros = b_zero_i8.reshape(-1, 1).to(torch.float32)
    steps = b_scale_i8.reshape(-1, 1).to(torch.float32)
    weight = ((groups - zeros) * steps).reshape(b.shape[0], b.shape[1])
    # Accumulate in fp32, then apply per-token and per-channel scales.
    acc = torch.matmul(a.to(torch.float32), weight.t())
    acc = acc * a_scale.view(-1, 1) * b_chn_scale.view(1, -1)
    return acc.to(out_dtype)
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, group_size, out_dtype, device):
    """Compare the CUDA per-group W4A8 kernel against the torch reference."""
    # Small magnitudes keep accumulated values comfortably in range.
    act = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
    wgt = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01

    # Activations: symmetric INT8. Weights: progressive per-group UINT4.
    act_q, act_scale = sym_quantize_tensor(act)
    wgt_q, wgt_chn_scale, wgt_scale_i8, wgt_zero_i8 = progressive_group_quantize_tensor(
        wgt, group_size
    )
    packed_w, packed_chn_scale, packed_scale_i8, packed_zero_i8 = (
        convert_to_qserve_format(wgt_q, wgt_chn_scale, wgt_scale_i8, wgt_zero_i8, group_size)
    )

    out = qserve_w4a8_per_group_gemm(
        act_q,
        packed_w,
        packed_zero_i8,
        packed_scale_i8,
        packed_chn_scale,
        act_scale,
    )
    ref = torch_w4a8_per_group_gemm(
        act_q, wgt_q, act_scale, wgt_chn_scale, wgt_scale_i8, wgt_zero_i8, group_size, out_dtype
    )
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
# All N/K values are multiples of 32 and K is a multiple of the 128-wide group,
# as required by convert_to_qserve_format / progressive_group_quantize_tensor.
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, group_size, out_dtype):
    """End-to-end accuracy sweep of the per-group W4A8 kernel on CUDA."""
    _test_accuracy_once(M, N, K, group_size, out_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user