diff --git a/python/sglang/srt/layers/quantization/quant_utils.py b/python/sglang/srt/layers/quantization/quant_utils.py new file mode 100644 index 000000000..59a1b1fdc --- /dev/null +++ b/python/sglang/srt/layers/quantization/quant_utils.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py + +from typing import Optional + +import numpy +import torch +from sgl_kernel.scalar_type import ScalarType + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / 
(max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 739b2b1c2..1bc289519 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -250,6 +250,15 @@ set(SOURCES "csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/kvcacheio/transfer.cu" "csrc/common_extension.cc" + "csrc/moe/marlin_moe_wna16/ops.cu" + "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu" + "csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" + "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu" + "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 941d0f836..0aee1c1bf 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include "sgl_kernel_ops.h" - TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { /* * From csrc/allreduce @@ -209,6 +208,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()"); m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); + m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor"); + m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack); + + m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); + m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack); + /* * From csrc/speculative */ @@ -303,6 +308,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? " "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? 
gen) -> ()"); m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + m.def( + "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," + "Tensor sorted_token_ids," + "Tensor! expert_ids, Tensor! num_tokens_past_padded," + "Tensor! topk_weights, int moe_block_size, int top_k, " + "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "int size_m, int size_n, int size_k," + "bool is_full_k, bool use_atomic_add," + "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm); /* * From Sparse Flash Attention diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu new file mode 100644 index 000000000..63d9dd904 --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu @@ -0,0 +1,255 @@ +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "core/registration.h" +#include "gptq_marlin/marlin.cuh" +#include "kernel.h" + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async in awq_marlin_repack_kernel +#else + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < 
finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +#define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + awq_marlin_repack_kernel \ + <<>>(b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK(b_q_weight.size(0) == size_k, "b_q_weight.size(0) = ", b_q_weight.size(0), " is not size_k = ", size_k); + TORCH_CHECK( + (size_n / pack_factor) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", + b_q_weight.size(1), + ", size_n = ", + size_n, + ", pack_factor = ", + pack_factor); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); + torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); + + // Get ptrs + uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4) + CALL_IF(8) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + } + + return out; +} + +torch::Tensor +awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { + int const pack_factor = 32 / num_bits; + auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); + return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h new file mode 100644 index 000000000..9caf0795a --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#define SGLANG_IMPLIES(p, q) (!(p) || (q)) +#define _CONCAT(A, B) A##B +#define 
CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py new file mode 100644 index 000000000..833d074ea --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{'true' if has_act_order else 'false'}}, " + "{{'true' if has_zp else 'false'}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) + +# int8 with zero point case (sglang::kU8) is also supported, +# we don't add it to reduce wheel size. 
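# For illustration, one instantiation rendered by the loop below (compute dtype
# fp16, weight type sglang::kU4B8, thread config (128, 128, 256),
# thread_m_blocks=1, channelwise scales) would read:
#   template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8,
#       false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
# Types without a bias suffix (kU4) use zero points; kU4B8 and kU8B128 are the
# symmetric, bias-shifted variants.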
+SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + has_zp = "B" not in scalar_type + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): + + has_act_order = group_blocks == 0 + if has_zp and has_act_order: + continue + if thread_configs[2] == 256: + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + has_act_order=has_act_order, + has_zp=has_zp, + group_blocks=group_blocks, + is_zp_float=False, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh new file mode 100644 index 000000000..29db8c934 --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +namespace MARLIN_NAMESPACE_NAME { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. 
+static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { + return (a + b - 1) / b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..618c86f95 --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin_dtypes.cuh @@ -0,0 +1,83 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include +#include + +#include "marlin.cuh" + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +namespace MARLIN_NAMESPACE_NAME { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static 
__device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu new file mode 100644 index 000000000..209d13e06 --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu @@ -0,0 +1,333 @@ +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "gptq_marlin/marlin.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async in gptq_marlin_repack_kernel +#else + +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? 
tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr 
(num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + gptq_marlin_repack_kernel \ + <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +torch::Tensor +gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + + TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. 
Got = ", num_bits); + int const pack_factor = 32 / num_bits; + + // Verify B + TORCH_CHECK( + (size_k / pack_factor) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", + b_q_weight.size(0), + ", size_k = ", + size_k, + ", pack_factor = ", + pack_factor); + TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); + torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } + + return out; +} + +torch::Tensor gptq_marlin_repack_meta( + torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { + int const pack_factor = 32 / num_bits; + auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); + return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options); +} + +#endif + +// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { +// m.impl("gptq_marlin_repack", &gptq_marlin_repack); +// } + +// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) { +// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta); +// } + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h new file mode 100644 index 000000000..9b36b477d --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h @@ -0,0 +1,40 @@ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "gptq_marlin/marlin.cuh" +#include "gptq_marlin/marlin_dtypes.cuh" +#include "scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, const float 
*__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce + +namespace MARLIN_NAMESPACE_NAME { +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu new file mode 100644 index 000000000..1e3d923ae --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu @@ -0,0 +1,89 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void 
Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu new file mode 100644 index 000000000..513ddc2ed --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu @@ -0,0 +1,109 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( 
MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu new file mode 100644 index 000000000..eebe9d3da --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu @@ -0,0 +1,109 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( 
MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu new file mode 100644 index 000000000..9adc6623a --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu @@ -0,0 +1,89 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu new file mode 100644 index 000000000..66ca7e36a --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu @@ -0,0 +1,109 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( 
MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu new file mode 100644 index 000000000..21fdf0c1a --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu @@ -0,0 +1,109 @@ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( 
MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h new file mode 100644 index 000000000..14df0aee5 --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -0,0 +1,1804 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "gptq_marlin/marlin.cuh" +#include "gptq_marlin/marlin_dtypes.cuh" +#include "scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace MARLIN_NAMESPACE_NAME + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void +mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. 
+template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline typename ScalarType::FragB dequant(int q, typename ScalarType::FragB& frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline typename ScalarType::FragB dequant(int q, typename ScalarType::FragB& frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, typename ScalarType::FragB& frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
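As a host-side illustration (not part of the patch), the LOP3 immediate and the fp16 magic numbers used by the 4-bit dequantization above can be verified directly: the LOP3 lookup-table byte is obtained by evaluating the desired boolean function on the constants 0xF0, 0xCC, 0xAA, and the fp16 trick relies on 0x6400 | x encoding the half value 1024 + x, so subtracting 0x6408 (the half value 1032) leaves the signed int4 value x - 8.

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // LOP3 lookup table for (a & b) | c, derived exactly like the template
  // argument lop3<(0xf0 & 0xcc) | 0xaa> used above.
  const uint32_t lut = (0xF0 & 0xCC) | 0xAA;
  assert(lut == 0xEA);

  // fp16 "exponent trick": the bit pattern 0x6400 | x is the half value
  // 1024 + x (one mantissa ulp equals 1 at exponent 2^10), and 0x6408 is
  // the half value 1032, so the __hsub2 above yields x - 8 per lane.
  for (int x = 0; x < 16; ++x) {
    assert((1024.0 + x) - (1024.0 + 8.0) == x - 8);
  }
  std::printf("lop3 lut = 0x%X\n", lut);
  return 0;
}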
+ + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2( + *reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline typename ScalarType::FragB dequant(int q, typename ScalarType::FragB& frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, typename ScalarType::FragB& frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
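As a host-side illustration (not part of the patch), the 8-bit paths follow the same pattern with different constants: prmt builds half lanes 0x64xx, i.e. 1024 + b, and subtracting 0x6480 (the half value 1152) gives b - 128; the bf16 path builds the fp32 pattern 0x4B0000xx, i.e. 2^23 + b, and subtracts 8388736 = 2^23 + 128 before repacking the high halves as bf16.

#include <cassert>

int main() {
  for (int b = 0; b < 256; ++b) {
    // fp16 path: (0x6400 | b) encodes 1024 + b, and 0x6480 encodes 1152.
    assert((1024.0 + b) - 1152.0 == b - 128);
    // bf16 path: __byte_perm builds 2^23 + b as an fp32 value, and the
    // constant 8388736.f subtracted above is 2^23 + 128.
    assert((8388608.0 + b) - 8388736.0 == b - 128);
  }
  return 0;
}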
+template +__device__ inline void +scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub(typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void +sub_zp(typename ScalarType::FragB& frag_b, typename ScalarType::scalar_t2& frag_zp, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4( + typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. 
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const sglang::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = sglang::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 
8 : (16 * thread_m_blocks); + const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; + const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int zp_expert_stride = + is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); + int4* sh_block_topk_weights_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4; + scalar_t2* sh_block_topk_weights = reinterpret_cast(sh_block_topk_weights_int4); + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; +#pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = + reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); +#pragma unroll + for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = + reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size 
/ 4 + tid4]; + + if (mul_topk_weights) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + } + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
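As a host-side illustration (not part of the patch), the negative-lock hand-off described in the comment above can be exercised with std::atomic standing in for the PTX load/red instructions: one block publishes 1 - slice_count after zeroing the output, every other block of the slice spins until the value turns negative and increments it, and once all partial sums have been atomically accumulated the lock is back at 0 and ready for reuse. A simplified sketch with a scalar output:

#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

int main() {
  const int slice_count = 4;
  std::atomic<int> lock{0};
  std::atomic<int> out{123};  // stale result from a previous tile

  std::vector<std::thread> blocks;
  for (int slice_idx = 0; slice_idx < slice_count; ++slice_idx) {
    blocks.emplace_back([&, slice_idx] {
      const int partial = 1;  // this block's partial sum
      if (slice_idx == 0) {
        out.store(0);                 // zero the output first ...
        lock.store(1 - slice_count);  // ... then publish the negative count
      } else {
        while (lock.load() >= 0) {}   // wait_negative_and_add()
        lock.fetch_add(1);
      }
      out.fetch_add(partial);         // atomicAdd of the partial result
    });
  }
  for (auto& t : blocks) t.join();
  assert(lock.load() == 0);           // lock is reusable for the next slice
  assert(out.load() == slice_count);  // all partials landed, no stale data
  return 0;
}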
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 
2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. 
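As a host-side illustration (not part of the patch), the XOR remap can be checked in isolation: the addition in transform_a binds tighter than the ^, so the swizzle keeps every element in its row while permuting its int4 column by col ^ row, which spreads accesses that would otherwise hit the same column (and hence the same banks) across distinct columns. A small check with a hypothetical tile width of 8 int4 columns:

#include <cassert>

int main() {
  const int stride = 8;  // hypothetical int4 columns per row (a_gl_rd_delta_o)
  for (int col = 0; col < stride; ++col) {
    bool used[8] = {false};
    for (int row = 0; row < stride; ++row) {
      // Same expression shape as transform_a: '+' binds tighter than '^'.
      const int swizzled = stride * row + col ^ row;
      assert(swizzled / stride == row);   // rows are preserved
      const int new_col = swizzled % stride;
      assert(!used[new_col]);             // a fixed column fans out to
      used[new_col] = true;               // eight distinct swizzled columns
    }
  }
  return 0;
}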
+ int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh_new; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4* sh_red = sh_b; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ int a_remaining_load_count_in_slice = stages; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || a_remaining_load_count_in_slice > 0) { + a_remaining_load_count_in_slice--; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) sorted_row = sh_block_sorted_ids[row] / top_k; + int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
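As a host-side illustration (not part of the patch), the schedule behind the "stages - 2 active fetches" comment can be emulated serially: stages - 1 tiles are prefetched up front, and each main-loop iteration issues one new fetch and then waits until at most stages - 2 fetches remain in flight, which guarantees the slot about to be consumed has already landed. A simplified sketch assuming a hypothetical 4-stage pipeline over 8 k-tiles, ignoring the per-k interleaving inside a tile:

#include <cstdio>
#include <queue>

int main() {
  const int stages = 4, total_tiles = 8;
  std::queue<int> in_flight;  // tiles whose cp.async groups are still pending

  auto fetch = [&](int tile) {
    if (tile < total_tiles) in_flight.push(tile);
  };
  auto wait_for_stage = [&] {  // emulates cp_async_wait<stages - 2>()
    while (static_cast<int>(in_flight.size()) > stages - 2) {
      std::printf("tile %d is now resident in shared memory\n", in_flight.front());
      in_flight.pop();
    }
  };

  for (int i = 0; i < stages - 1; ++i) fetch(i);  // start_pipes()
  wait_for_stage();                               // tile 0 guaranteed resident
  for (int pipe = 0; pipe < total_tiles; ++pipe) {
    std::printf("compute on tile %d (%zu fetches in flight)\n", pipe, in_flight.size());
    fetch(pipe + stages - 1);  // refill the slot this tile occupied
    wait_for_stage();          // tile pipe + 1 guaranteed resident
  }
  return 0;
}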
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* 
sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. 
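As a host-side illustration (not part of the patch), much of matmul() below hinges on folding the zero point into the scale: (q - zp) * s is rewritten as fma(q, s, -(zp * s)), so zp * s is formed only when a new group's zero point is loaded (is_new_zp) and every fragment then needs just one __hfma2 via scale_and_sub. The identity, checked in fp32:

#include <cassert>
#include <cmath>

int main() {
  const float s = 0.017f;           // group scale
  const float zp = 7.0f;            // group zero point
  const float zp_times_s = zp * s;  // computed once per group (is_new_zp)
  for (int q = 0; q < 16; ++q) {
    const float reference = (static_cast<float>(q) - zp) * s;
    const float fused = std::fma(static_cast<float>(q), s, -zp_times_s);
    assert(std::fabs(fused - reference) < 1e-6f);
  }
  return 0;
}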
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant(zp_quant_0, frag_zp_0); + dequant(zp_quant_1, frag_zp_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + } +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant(b_quant_0, frag_b0); + dequant(b_quant_1, frag_b1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); + + } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); + } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + if (is_new_zp) frag_zpf[k2][j] = __hmul2(frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); + scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. 
+ auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue; + } + + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue; + } + + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && !has_zp) { + res = __hmul2(res, s[0]); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + int row = c_gl_wr / c_gl_stride; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; + scalar_t2 topk_weight_score; + if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; + if (use_atomic_add && slice_count > 1 || mul_topk_weights) { + scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + scalar_t2* sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { + scalar_t2 res = sh_red_half2[a]; + if (mul_topk_weights) { + res = __hmul2(res, topk_weight_score); + } + + if (use_atomic_add && slice_count > 1) { + atomicAdd(&C_half2[a], res); + } else { + C_half2[a] = res; + }; + } + } else { + C[true_idx] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
+ auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + fetch_col_scale_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + a_remaining_load_count_in_slice = 0; + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
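(Editorial aside, not part of the patch.) The main loop above is a software pipeline: start_pipes issues the first stages - 1 asynchronous tile fetches, and every iteration then computes on the oldest staged tile while the fetch that is stages - 1 tiles ahead is still in flight, so global loads overlap with the mma work. A minimal Python model of that schedule, with hypothetical fetch/compute callbacks and an illustrative stage count:

STAGES = 4  # illustrative depth, analogous to pipe_stages

def pipelined_loop(tiles, fetch, compute):
    in_flight = []
    # Prologue: start the first STAGES - 1 fetches before any compute.
    for t in tiles[: STAGES - 1]:
        in_flight.append(fetch(t))
    # Steady state: fetch one tile ahead, compute on the oldest tile in flight.
    for i in range(len(tiles)):
        if i + STAGES - 1 < len(tiles):
            in_flight.append(fetch(tiles[i + STAGES - 1]))
        compute(in_flight.pop(0))

# Usage example with trivial callbacks.
pipelined_loop(range(10), fetch=lambda t: t, compute=print)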
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && !has_zp) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + if (slice_row) a_remaining_load_count_in_slice = stages; + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu new file mode 100644 index 
000000000..9d5f7c26b --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -0,0 +1,1112 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "core/registration.h" +#include "kernel.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, + int size_m, + int size_k, + int top_k) {}; + +} // namespace marlin + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, + std::optional const& c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, + torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, + torch::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
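(Editorial aside, not part of the patch.) The permute_cols_kernel defined below applies the act_order column permutation while expanding each input token into its top_k routed rows: the destination row index comes from sorted_token_ids, the source row is roughly dest // top_k, and the K columns are gathered in perm order, with perm advanced per expert. A rough single-expert NumPy analogue, using a hypothetical helper name and skipping padded ids:

import numpy as np

def permute_cols_single_expert(a, perm, sorted_token_ids, top_k):
    # a: [M, K] activations, perm: [K] column order for one expert
    M, K = a.shape
    out = np.zeros((M * top_k, K), dtype=a.dtype)
    for dest in sorted_token_ids:
        if dest >= M * top_k:          # padding produced by moe_align_block_size
            continue
        out[dest] = a[dest // top_k, perm]
    return out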
+template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, + int size_m, + int size_k, + int top_k) { + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size); + int32_t block_sorted_ids[moe_block_size]; + int block_num_valid_tokens = 0; + int64_t old_expert_id = 0; + int64_t expert_id = 0; + int row_stride = size_k * sizeof(half) / 16; + + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + int4* tmp_block_sorted_ids = reinterpret_cast(block_sorted_ids); + for (int i = 0; i < moe_block_size / 4; i++) { + tmp_block_sorted_ids[i] = ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + } + for (int i = 0; i < moe_block_size; i++) { + if (block_sorted_ids[i] >= size_m * top_k) { + block_num_valid_tokens = i; + break; + }; + } + }; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int in_offset = (row / top_k) * row_stride; + int out_offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + in_offset); + half* out_half = reinterpret_cast(out_int4_ptr + out_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) { + old_expert_id = expert_id; + int tmp_expert_id = expert_ids_ptr[index]; + if (tmp_expert_id == -1) continue; + expert_id = tmp_expert_id; + perm_int_ptr += (expert_id - old_expert_id) * size_k; + read_moe_block_data(index); + + for (int i = 0; i < block_num_valid_tokens; i++) + permute_row(block_sorted_ids[i]); + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const& th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales 
= tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + + // shm size for block_sorted_ids/block_topk_weights + // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) + int sh_block_meta_size = tb_m * 4 * 2; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + sh_g_idx_size + sh_block_meta_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define __GET_IF( \ + W_TYPE, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + HAS_ACT_ORDER, \ + HAS_ZP, \ + GROUP_BLOCKS, \ + NUM_THREADS, \ + IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && has_act_order == HAS_ACT_ORDER && \ + has_zp == HAS_ZP && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + HAS_ACT_ORDER, \ + HAS_ZP, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, 
N_BLOCKS, K_BLOCKS, true, false, false, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) + +#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) + +#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) + +#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, 
N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) + +// We currently have 4-bit models only with group_blocks == 4 +#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, true) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) + +template +MarlinFuncPtr get_marlin_kernel( + const sglang::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + GPTQ_GET_IF_M1(sglang::kU4B8, 8, 8, 256) + GPTQ_GET_IF_M1(sglang::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M234(sglang::kU4B8, 16, 4, 256) + GPTQ_GET_IF_M234(sglang::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M1(sglang::kU8B128, 8, 8, 256) + GPTQ_GET_IF_M1(sglang::kU8B128, 8, 4, 128) + + GPTQ_GET_IF_M234(sglang::kU8B128, 16, 4, 256) + GPTQ_GET_IF_M234(sglang::kU8B128, 8, 4, 128) + + AWQ_GET_IF_M1(sglang::kU4, 8, 8, 256) + AWQ_GET_IF_M1(sglang::kU4, 8, 4, 128) + + AWQ_GET_IF_M234(sglang::kU4, 16, 4, 256) + AWQ_GET_IF_M234(sglang::kU4, 8, 4, 128) + + return kernel; +} + +template +exec_config_t determine_exec_config( + const sglang::ScalarType& q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + int count = 0; + constexpr int device_max_reg_size = 255 * 1024; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) continue; + + if (thread_m_blocks > 1) { + exec_cfg = {1, th_config}; + break; + } else { + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, max_shared_mem / (cache_size + 1024)); + allow_count = max(min(allow_count, 4), 1); + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; + } + } + + return exec_cfg; +} + +template +void marlin_mm( + const void* A, + const void* B, + void* C, + void* C_tmp, + void* s, + void* zp, + void* g_idx, + void* perm, + void* a_tmp, + void* sorted_token_ids, + void* expert_ids, + void* num_tokens_past_padded, + void* topk_weights, + int moe_block_size, + int top_k, + bool mul_topk_weights, + bool is_ep, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + sglang::ScalarType const& q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k, + int thread_n, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + int thread_m_blocks = div_ceil(moe_block_size, 16); + bool m_block_size_8 = moe_block_size == 8; + + if (has_zp) { + TORCH_CHECK( + q_type == sglang::kU4 || q_type == sglang::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", + q_type.str()); + } else { + TORCH_CHECK( + q_type == sglang::kU4B8 || q_type == sglang::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + kernel = permute_cols_kernel<64>; + else + TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size); + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int 
thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem = ", + max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + TORCH_CHECK( + false, + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_bits = ", + num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + // clang-format on +} + +} // namespace MARLIN_NAMESPACE_NAME + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, + std::optional const& c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, + torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, + torch::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + sglang::ScalarType const b_q_type = sglang::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + TORCH_CHECK(moe_block_size % 16 == 0, "unsupported moe_block_size=", moe_block_size); + TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64, "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k); + + // Verify B + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + (size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", + b_q_weight.size(1), + ", size_k = ", + size_k, + ", tile_size = 
", + MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(2) = ", + b_q_weight.size(2), + " is not divisible by tile_size = ", + MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK( + c.size(0) == size_m * top_k, "Shape mismatch: c.size(0) = ", c.size(0), ", size_m * topk = ", size_m * top_k); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), ", size_n = ", size_n); + } else { + c = torch::empty({size_m * top_k, size_n}, options); + } + + // Alloc C tmp buffer that is going to be used for the global reduce + torch::Tensor c_tmp; + auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (use_fp32_reduce && !use_atomic_add) { + // max num of threadblocks is sms * 4 + long max_c_tmp_size = min( + (long)size_n * sorted_token_ids.size(0), (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + if (moe_block_size == 8) max_c_tmp_size *= 2; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); + } else { + c_tmp = torch::empty({0}, options_fp32); + } + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + torch::Tensor g_idx, perm, a_tmp; + ; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK( + (g_idx.size(-1) == 0 && perm.size(-1) == 0) || (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", + g_idx.size(-1), + " and perm.size(-1) = ", + perm.size(-1), + ", where size_k = ", + size_k); + } else { + g_idx = torch::empty({0}, options); + perm = 
torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + + if (has_act_order) { + a_tmp = torch::empty({size_m * top_k, size_k}, options); + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = torch::empty({0}, options); + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(1) = ", b_scales.size(1)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + + if (has_zp) { + TORCH_CHECK(b_q_type == sglang::kU4, "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == sglang::kU4B8 || b_q_type == sglang::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK( + a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(2) == size_n, "b_zeros dim 2 = ", b_zeros.size(2), " is not size_n = ", size_n); + TORCH_CHECK( + num_groups == b_zeros.size(1), "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK( + b_zeros.size(1) == num_groups, "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); + TORCH_CHECK( + b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", + b_zeros.size(2), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify workspace size + TORCH_CHECK( + size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n; + int min_workspace_size = min(max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4); + TORCH_CHECK( + workspace.numel() >= min_workspace_size, + "workspace.numel = ", + workspace.numel(), + " is below min_workspace_size = ", + min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), + moe_block_size, + top_k, + mul_topk_weights, + is_ep, + size_m, + size_n, + size_k, + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + 
sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), + moe_block_size, + top_k, + mul_topk_weights, + is_ep, + size_m, + size_n, + size_k, + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); + } else { + TORCH_CHECK(false, "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/sgl-kernel/include/scalar_type.hpp b/sgl-kernel/include/scalar_type.hpp new file mode 100644 index 000000000..8a837a2b5 --- /dev/null +++ b/sgl-kernel/include/scalar_type.hpp @@ -0,0 +1,328 @@ +#pragma once + +// For TORCH_CHECK +#include + +namespace sglang { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK( + nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. 
has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace sglang diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index abea81e3f..d836f848a 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -18,12 
+18,15 @@ limitations under the License. #include #include #include +#include #include #include #include #include +#include "scalar_type.hpp" + #define _CONCAT(A, B) A##B #define CONCAT(A, B) _CONCAT(A, B) @@ -323,6 +326,15 @@ void scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); +namespace marlin_moe_wna16 { + +torch::Tensor +gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); + +} // namespace marlin_moe_wna16 + /* * From csrc/speculative */ @@ -495,6 +507,31 @@ void top_p_sampling_from_probs( double top_p_val, bool deterministic, std::optional gen); +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, + std::optional const& c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, + torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, + torch::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float); namespace flash { /* diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 2353004ce..65a037d54 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -29,6 +29,7 @@ from sgl_kernel.elementwise import ( rmsnorm, silu_and_mul, ) +from sgl_kernel.fused_moe import fused_marlin_moe from sgl_kernel.gemm import ( awq_dequantize, bmm_fp8, @@ -55,6 +56,11 @@ from sgl_kernel.kvcacheio import ( transfer_kv_per_layer, transfer_kv_per_layer_mla, ) +from sgl_kernel.marlin import ( + awq_marlin_moe_repack, + awq_marlin_repack, + gptq_marlin_repack, +) from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, diff --git a/sgl-kernel/python/sgl_kernel/fused_moe.py b/sgl-kernel/python/sgl_kernel/fused_moe.py new file mode 100644 index 000000000..f9322e228 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/fused_moe.py @@ -0,0 +1,223 @@ +import functools +from typing import Optional + +import torch +from sgl_kernel.scalar_type import scalar_types + + +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + assert num_bits == 4 + return scalar_types.uint4 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two 
sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. + - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. + - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import ( + moe_align_block_size, + try_get_optimal_moe_config, + ) + + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2 + ), "Hidden size mismatch w2" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + assert num_bits in [4, 8] + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + topk = topk_ids.shape[1] + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + None, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, block_size_m, global_num_experts + ) + + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * ( + sorted_token_ids.size(0) // block_size_m + ) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms * 4) + workspace = torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache13 = torch.empty( + (M * topk_ids.shape[1] * max(2 * N, K),), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N] + intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) + intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K] + intermediate_cache3 = intermediate_cache3.view(-1, K) + + use_atomic_add = ( + 
hidden_states.dtype == torch.half
+        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
+    )
+
+    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
+        hidden_states,
+        intermediate_cache1,
+        w1,
+        w1_scale,
+        w1_zeros,
+        g_idx1,
+        sort_indices1,
+        workspace,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        topk_weights,
+        moe_block_size=block_size_m,
+        top_k=topk,
+        mul_topk_weights=False,
+        is_ep=expert_map is not None,
+        b_q_type_id=scalar_type1.id,
+        size_m=M,
+        size_n=2 * N,
+        size_k=K,
+        is_full_k=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=True,
+        is_zp_float=False,
+    )
+
+    # silu_and_mul is registered under the sgl_kernel namespace (torch.ops._C is vLLM's)
+    torch.ops.sgl_kernel.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+
+    if expert_map is not None:
+        intermediate_cache3.zero_()
+
+    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
+        intermediate_cache2,
+        intermediate_cache3,
+        w2,
+        w2_scale,
+        w2_zeros,
+        g_idx2,
+        sort_indices2,
+        workspace,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        topk_weights,
+        moe_block_size=block_size_m,
+        top_k=1,
+        mul_topk_weights=True,
+        is_ep=expert_map is not None,
+        b_q_type_id=scalar_type2.id,
+        size_m=M * topk,
+        size_n=K,
+        size_k=N,
+        is_full_k=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=True,
+        is_zp_float=False,
+    ).view(-1, topk, K)
+
+    output = hidden_states if inplace else torch.empty_like(hidden_states)
+    return torch.sum(
+        intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
+    )
+
+
+def fused_marlin_moe_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    g_idx1: Optional[torch.Tensor] = None,
+    g_idx2: Optional[torch.Tensor] = None,
+    sort_indices1: Optional[torch.Tensor] = None,
+    sort_indices2: Optional[torch.Tensor] = None,
+    w1_zeros: Optional[torch.Tensor] = None,
+    w2_zeros: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+    is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
diff --git a/sgl-kernel/python/sgl_kernel/marlin.py b/sgl-kernel/python/sgl_kernel/marlin.py
new file mode 100644
index 000000000..407834227
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/marlin.py
@@ -0,0 +1,44 @@
+import torch
+
+
+def gptq_marlin_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_marlin_repack.default(
+        b_q_weight,
+        perm,
+        size_k,
+        size_n,
+        num_bits,
+    )
+
+
+def awq_marlin_repack(
+    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
+
+
+def awq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
+            b_q_weight[e], size_k, size_n, num_bits
+        )
+    return output
diff --git a/sgl-kernel/python/sgl_kernel/scalar_type.py b/sgl-kernel/python/sgl_kernel/scalar_type.py
new file mode 100644
index 000000000..5aeb88651
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/scalar_type.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +_SCALAR_TYPES_ID_MAP = {} + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. +@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE: + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = max_exponent - exponent_bias + exponent_bias_double + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
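+        # worked example: float8_e4m3fn (exponent=4, mantissa=3, finite values
+        # only, EXTD_RANGE_MAX_MIN) gives max_mantissa = 7 - 1 = 6 and
+        # max_exponent = 14 + 1 = 15, so max_exponent_double = 15 - 7 + 1023
+        # = 1031, and the assembled double decodes to 2**8 * (1 + 6/8) = 448.0,
+        # the expected e4m3fn maximum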
+ return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack("!d", struct.pack("!Q", double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert ( + self.size_bits < 64 or self.size_bits == 64 and self.is_signed() + ), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert ( + self.is_signed() + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack("!d", struct.pack("!Q", min_raw))[0] + else: + assert ( + not self.is_signed() or self.size_bits <= 64 + ), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64" + + _SCALAR_TYPES_ID_MAP[val] = self + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        # nan_repr stores a NanRepr enum member, so compare against the member
+        # itself rather than its .value
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent>m<mantissa>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int<size_bits>[b<bias>]`
+          - if bias is not present it means it is zero
+        """
+        if self.is_floating_point():
+            ret = (
+                "float"
+                + str(self.size_bits)
+                + "_e"
+                + str(self.exponent)
+                + "m"
+                + str(self.mantissa)
+            )
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert mantissa > 0 and exponent > 0
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(
+        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
+    ) -> "ScalarType":
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
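+        For example, `scalar_types.float8_e4m3fn` below is created as
+        float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN).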
+ """ + assert mantissa > 0 and exponent > 0 + assert nan_repr != NanRepr.IEEE_754, ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions" + ) + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/sgl-kernel/tests/test_marlin_repack.py b/sgl-kernel/tests/test_marlin_repack.py new file mode 100644 index 000000000..c0f13f46b --- /dev/null +++ b/sgl-kernel/tests/test_marlin_repack.py @@ -0,0 +1,138 @@ +import math + +import numpy as np +import pytest +import torch +from sgl_kernel import awq_marlin_repack +from sgl_kernel.scalar_type import scalar_types + +from sglang.srt.layers.quantization.quant_utils import ( + get_pack_factor, + pack_cols, + quantize_weights, +) + +GPTQ_MARLIN_TILE = 16 + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute 
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)]) +@pytest.mark.parametrize("group_size", [16, 32]) +def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size): + tile_k, tile_n = 16, 64 + size_k = k_tiles * tile_k + size_n = n_tiles * tile_n + pack_factor = 32 // num_bits + + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + + q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + + weight_perm = get_weight_perm(num_bits) + q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + + out_gpu = awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + assert out_gpu.is_cuda and out_gpu.dtype == torch.int32 + + expected_cols = size_n * tile_k // pack_factor + assert list(out_gpu.shape) == [size_k // tile_k, expected_cols] + + torch.cuda.synchronize() + + torch.testing.assert_close(out_gpu, q_w_marlin) + + +if __name__ == "__main__": + import subprocess + + subprocess.call(["pytest", "--tb=short", str(__file__)]) diff --git a/test/srt/test_int4_kernel.py b/test/srt/test_int4_kernel.py new file mode 100644 index 000000000..0665f9b91 --- /dev/null +++ b/test/srt/test_int4_kernel.py @@ -0,0 +1,301 @@ +import itertools +import sys +import unittest + +import torch + +sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm") + +# from sglang.srt.layers.moe.topk import select_experts +from sgl_kernel import fused_marlin_moe +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + +# from vllm.model_executor.layers. 
import select_experts +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize, +) +from vllm.scalar_type import scalar_types + + +def stack_and_dev(tensors: list[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def torch_moe(a, w1, w2, score, topk, expert_map): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): + """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K,) + A = A.reshape(M, N) + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + return C.reshape(origin_C_shape).to(output_dtype) + + +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): + """This function performs fused moe with per-column int8 quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + # Calculate routing + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype + ) + # Activation function + act_out = SiluAndMul().forward_native(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = per_token_quant_int8(act_out) + + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) + # Apply routing weights and sum + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def marlin_fused_moe( + N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size +): + quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + if ep_size > 1: + local_e = E // ep_size + e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e] + 
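+        # e_map maps a global expert index to its local index on this EP rank,
+        # with -1 for experts not hosted locally; it is passed to torch_moe and
+        # fused_marlin_moe as expert_map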
e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + zeros1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + s1_l = [] + for i in range(w1.shape[0]): + test_perm = torch.randperm(n=K) + quant_res = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + zeros2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + for i in range(w2.shape[0]): + test_perm = torch.randperm(n=N) + quant_res = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + # topk_weights, topk_ids = FusedMoE.select_experts( + # hidden_states=a, + # router_logits=score, + # top_k=topk, + # num_expert_group=E, + # use_grouped_topk=False, + # renormalize=False, + # topk_group=None, + # ) + + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=E, + expert_map=e_map, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + num_bits=num_bits, + is_k_full=True, + ) + return marlin_output, torch_output + + +class TestW8A8Int8FusedMoE(unittest.TestCase): + DTYPES = [torch.float16] + M = [1, 16] + N = [128] + K = [256] + E = [4, 10] + TOP_KS = [2, 4] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + NUM_BITS = [4] + EP_SIZE = [1, 4] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w4a8_int8_fused_moe( + self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size + ): + torch.manual_seed(seed) + a = torch.randn((M, K), dtype=dtype) / 10 + + # Generate int8 weights + w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 + w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 + + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + marlin_out, ref_out = marlin_fused_moe( + N=N, + E=E, + K=K, + a=a, + w1=w1_fp16, + w2=w2_fp16, + num_bits=num_bits, + group_size=-1, + act_order=False, + score=score, + topk=topk, + ep_size=ep_size, 
+ ) + # Check results + if ( + torch.mean( + torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)) + ) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + > 0.1 + ): + print(f"marlin_out: {marlin_out}") + print(f"ref_out: {ref_out}") + print( + torch.mean( + torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)) + ) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + ) + torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0) + + def test_w4a8_int8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + self.NUM_BITS, + self.EP_SIZE, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + num_bits=params[8], + ep_size=params[9], + ): + self._w4a8_int8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/test_w4a8.py b/test/srt/test_w4a8.py new file mode 100644 index 000000000..75d41ee5f --- /dev/null +++ b/test/srt/test_w4a8.py @@ -0,0 +1,14 @@ +import sgl_kernel +import torch + +x = torch.randn(10, 10, device="cuda") +qweight = torch.randn(10, 10, device="cuda") +s1_scales = torch.randn(10, device="cuda") +input_scales = torch.randn(10, device="cuda") +s1_szeros = torch.randn(10, device="cuda") +input_sum = torch.randn(10, device="cuda") +output_buffer = torch.randn(10, device="cuda") + +torch.ops.sgl_kernel.gemm_forward_cuda.default( + x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer +)
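
A minimal usage sketch (not part of the diff above): it exercises only the pure-Python pieces added here, namely the ScalarType ids that fused_marlin_moe forwards to the CUDA op as b_q_type_id and the get_scalar_type helper from fused_moe.py. It assumes the compiled sgl_kernel package from this change is importable; no GPU kernel is launched.

from sgl_kernel.fused_moe import get_scalar_type
from sgl_kernel.scalar_type import ScalarType, scalar_types

# standard GPTQ-style 4-bit type: stored = value + 8, so the range is [-8, 7]
qt = scalar_types.uint4b8
assert qt.size_bits == 4 and qt.bias == 8
assert (qt.min(), qt.max()) == (-8, 7)

# get_scalar_type picks the quant type the MoE kernel expects:
# uint4 when zero points are present, uint4b8 / uint8b128 otherwise
assert get_scalar_type(num_bits=4, has_zp=False) is scalar_types.uint4b8
assert get_scalar_type(num_bits=8, has_zp=False) is scalar_types.uint8b128

# the integer id is what fused_marlin_moe passes to moe_wna16_marlin_gemm
# as b_q_type_id; from_id recovers the same Python object
assert ScalarType.from_id(qt.id) is qt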