[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
This commit is contained in:
166
python/sglang/srt/layers/quantization/quant_utils.py
Normal file
166
python/sglang/srt/layers/quantization/quant_utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import ScalarType
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
    """Number of `num_bits`-wide values that fit in a single 32-bit word."""
    quotient, remainder = divmod(32, num_bits)
    assert remainder == 0, f"Unsupported num_bits = {num_bits}"
    return quotient
|
||||
|
||||
|
||||
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack each group of `pack_factor` adjacent columns of `q_w` into one
    int32 column, little-endian within every 32-bit word.

    Returns an int32 tensor of shape (size_k, size_n // pack_factor) on the
    same device as the input.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    device = q_w.device
    w_u32 = q_w.cpu().numpy().astype(numpy.uint32)

    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
    # Column i of each group lands in bits [num_bits*i, num_bits*(i+1)).
    for pos in range(pack_factor):
        packed |= w_u32[:, pos::pack_factor] << (num_bits * pos)

    return torch.from_numpy(packed.astype(numpy.int32)).to(device).contiguous()
|
||||
|
||||
|
||||
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Inverse of `pack_cols`: expand int32-packed columns back into one
    column per `num_bits`-wide value.

    Args:
        packed_q_w: Integer tensor of shape (size_k, size_n // pack_factor).
        num_bits: Bit width of each packed value (must divide 32).
        size_k: Number of rows.
        size_n: Number of unpacked output columns.

    Returns:
        An int32 tensor of shape (size_k, size_n) on the input's device.
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    # Fixed assertion-message typo: "pack_Factor" -> "pack_factor".
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    # Work on a uint32 copy so the in-place shifts below never mutate the input.
    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        # Values were packed little-endian within each word, so the i-th
        # value of every group sits in the current lowest num_bits bits.
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
|
||||
|
||||
|
||||
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    """Quantize floating-point weights `w` to the integer type `quant_type`.

    Args:
        w: Floating-point weight matrix of shape (size_k, size_n).
        quant_type: Target integer scalar type (must satisfy `is_integer()`).
        group_size: Scale-group length along k. -1 means channelwise (the
            whole k dimension forms one group); None disables scaling.
        zero_points: Compute per-group integer zero points (asymmetric
            quantization). Requires an unsigned `quant_type` and a group_size.
        ref_zero_points_after_scales: Build the dequantized reference by
            applying zero points after scales (matches kernels such as
            Machete, so unit tests can use tighter error tolerances).

    Returns:
        Tuple (w_ref, w_q, w_s, maybe_w_zp):
        w_ref      - dequantized reference (original dtype/device),
        w_q        - quantized integer weights,
        w_s        - per-group scales, or None when group_size is None,
        maybe_w_zp - int zero points, or None when zero_points is False.
    """
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1] so each column holds exactly one
    # quantization group and min/max reduce over dim 0.
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            # Asymmetric: scale spans the full [min, max] range; zero point
            # maps min_val onto 0 (clamped to the representable range).
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        # Shift into the biased storage representation (quant_type.bias)
        # when the scalar type defines one.
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            # Inverse of the grouping permutation applied above.
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
|
||||
@@ -250,6 +250,15 @@ set(SOURCES
|
||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||
"csrc/kvcacheio/transfer.cu"
|
||||
"csrc/common_extension.cc"
|
||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
|
||||
"csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
|
||||
@@ -17,7 +17,6 @@ limitations under the License.
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
@@ -209,6 +208,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
|
||||
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
|
||||
|
||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -303,6 +308,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
||||
|
||||
/*
|
||||
* From Sparse Flash Attention
|
||||
|
||||
255
sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu
Normal file
255
sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu
Normal file
@@ -0,0 +1,255 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "kernel.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in awq_marlin_repack_kernel
|
||||
#else
|
||||
|
||||
// Repack AWQ-packed quantized weights into the tile layout consumed by the
// Marlin WNA16 MoE GEMM. Each thread block handles a contiguous range of
// k-tiles and streams n-tiles through a `repack_stages`-deep shared-memory
// pipeline fed with cp.async copies.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;  // values per 32-bit word

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  auto start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  // Each stage copies one k x n tile with 16-byte (4-int) transfers.
  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  // Issue the async copy of one tile into pipeline slot `pipe`. Out-of-range
  // n-tiles still fence so the commit-group count stays balanced.
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      auto k_id = threadIdx.x / stage_n_threads;
      auto n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(
          &sh_ptr[k_id * stage_n_threads + n_id],
          reinterpret_cast<int4 const*>(
              &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  // Repack one tile from pipeline slot `pipe` into the output layout.
  // Only the first 4 warps participate.
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    auto warp_id = threadIdx.x / 32;
    auto th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    // AWQ stores values interleaved inside each 32-bit word; map the logical
    // column position back to its stored position before extracting bits.
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  // Prime the pipeline with the first `repack_stages - 1` fetches.
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        // Fetch one tile ahead while repacking the current one.
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}
|
||||
|
||||
// Dispatch helper: instantiate, configure (dynamic shared memory limit) and
// launch the repack kernel for the given NUM_BITS. Designed to chain after an
// empty `if (false) {}` so that each config reads as an else-if branch.
#define CALL_IF(NUM_BITS)                                                                            \
  else if (num_bits == NUM_BITS) {                                                                   \
    cudaFuncSetAttribute(                                                                            \
        awq_marlin_repack_kernel<repack_threads, NUM_BITS>,                                          \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                 \
        max_shared_mem);                                                                             \
    awq_marlin_repack_kernel<repack_threads, NUM_BITS>                                               \
        <<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
  }
|
||||
|
||||
// Host entry point: validate the AWQ-packed weight tensor, allocate the
// output in Marlin tile layout, and launch awq_marlin_repack_kernel on the
// tensor's CUDA device/stream. Returns the repacked int32 tensor of shape
// (size_k / tile_size, size_n * tile_size / pack_factor).
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
  TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B: AWQ packs along n, so the stored width is size_n / pack_factor.
  TORCH_CHECK(b_q_weight.size(0) == size_k, "b_q_weight.size(0) = ", b_q_weight.size(0), " is not size_k = ", size_k);
  TORCH_CHECK(
      (size_n / pack_factor) == b_q_weight.size(1),
      "Shape mismatch: b_q_weight.size(1) = ",
      b_q_weight.size(1),
      ", size_n = ",
      size_n,
      ", pack_factor = ",
      pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info: one block per SM, and the opt-in shared memory maximum.
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // CALL_IF expands to else-if branches, hence the empty leading if.
  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}
|
||||
|
||||
// Meta (shape-inference) implementation: computes the output shape without
// touching a device, for use with symbolic sizes (e.g. torch.compile tracing).
torch::Tensor
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
25
sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h
Normal file
25
sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once

#include <Python.h>

// Logical implication: SGLANG_IMPLIES(p, q) is false only when p holds and q
// does not.
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
// Two-level token pasting so that arguments which are themselves macros are
// expanded before concatenation.
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

// Two-level stringification, for the same macro-expansion reason.
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }
|
||||
106
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
106
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{'true' if has_act_order else 'false'}}, "
|
||||
"{{'true' if has_zp else 'false'}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (sglang::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
    """Delete all previously generated kernel_*.cu files next to this script.

    Uses os.remove instead of spawning one `rm -f` subprocess per file: it is
    portable (works on Windows) and avoids the process-spawn overhead.
    os.path.join also fixes the corner case where os.path.dirname(__file__)
    is empty (relative invocation), which the old string concatenation turned
    into the absolute pattern "/kernel_*.cu".
    """
    pattern = os.path.join(os.path.dirname(__file__), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Mirror `rm -f`: a file vanishing between glob and remove is fine.
            pass
|
||||
|
||||
|
||||
def generate_new_kernels():
    """Render one kernel_<dtype>_<scalar_type>.cu file per (scalar type,
    dtype) pair, instantiating the Marlin kernel template for every supported
    (group_blocks, thread_m_blocks, thread config) combination."""
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        # Types without a "B..." suffix in their name use zero points
        # (e.g. kU4); the biased types (kU4B8, kU8B128) do not.
        has_zp = "B" not in scalar_type
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
        ):

            # group_blocks == 0 marks the act-order configuration, which is
            # not generated together with zero points.
            has_act_order = group_blocks == 0
            if has_zp and has_act_order:
                continue
            # For 256-thread configs: small m uses the 128x128 tile shape,
            # larger m uses 64x256; skip the other pairings.
            if thread_configs[2] == 256:
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # thread_configs is (thread_k, thread_n, num_threads) in units of
            # elements; the template takes 16-element block counts.
            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            template_str = jinja2.Template(TEMPLATE).render(
                scalar_t=c_dtype,
                w_type_id=scalar_type + ".id()",
                threads=threads,
                # m_blocks == 0.5 becomes thread_m_blocks=1 with the
                # m_block_size_8 flag set instead.
                thread_m_blocks=max(m_blocks, 1),
                thread_n_blocks=n_blocks,
                thread_k_blocks=k_blocks,
                m_block_size_8=m_blocks == 0.5,
                stages="pipe_stages",
                has_act_order=has_act_order,
                has_zp=has_zp,
                group_blocks=group_blocks,
                is_zp_float=False,
            )

            all_template_str_list.append(template_str)

        # Close the namespace opened in FILE_HEAD with the trailing "}".
        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # scalar_type[8:] strips the leading "sglang::" prefix.
        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)
|
||||
|
||||
|
||||
# Script entry point: wipe stale generated kernels, then regenerate them all.
if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()
|
||||
96
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh
Normal file
96
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh
Normal file
@@ -0,0 +1,96 @@
|
||||
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>

#include <iostream>

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

namespace MARLIN_NAMESPACE_NAME {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages = 4;  // 4 pipeline stages fit into shared memory

// Allowed thread-tile extents for the GEMM kernels.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack params
// Depth of the shared-memory pipeline used by the repack kernels.
static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

// Repack tiles are 16 (k) x 64 (n) elements.
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// Helpers
// Fixed-size vector of `n` elements of type `T` with array-style access.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) {
    return elems[i];
  }
};

using I4 = Vec<int, 4>;

// Integer ceiling division.
constexpr int div_ceil(int a, int b) {
  return (a + b - 1) / b;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

// Predicated 16-byte asynchronous copy from global to shared memory
// (cp.async.cg); the copy is skipped when `pred` is false.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem),
      "l"(glob_ptr),
      "n"(BYTES));
}

// Unconditional 16-byte asynchronous copy from global to shared memory.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr),
      "n"(BYTES));
}

// Commit all previously issued cp.async operations as one group.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Block until at most `n` committed cp.async groups are still in flight.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif

}  // namespace MARLIN_NAMESPACE_NAME
|
||||
@@ -0,0 +1,83 @@
|
||||
|
||||
#ifndef _data_types_cuh
#define _data_types_cuh
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "marlin.cuh"

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

namespace MARLIN_NAMESPACE_NAME {

// Type-trait shim mapping a compute scalar type (half / nv_bfloat16) to its
// paired 2-element vector type, tensor-core fragment layouts and conversion
// helpers. The primary template is intentionally empty; only the
// specializations below are usable.
template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  // Widen a single half to float.
  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  // Broadcast one half into both lanes of a half2.
  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  // Pack two halves into one half2.
  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

  // The bf16 conversion intrinsics are only compiled for sm80+ (or host).
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace MARLIN_NAMESPACE_NAME

#endif
|
||||
333
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu
Normal file
333
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu
Normal file
@@ -0,0 +1,333 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in gptq_marlin_repack_kernel
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr,
|
||||
uint32_t* __restrict__ out_ptr,
|
||||
int size_k,
|
||||
int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
|
||||
int4* sh_perm_ptr = sh;
|
||||
int4* sh_pipe_ptr = sh_perm_ptr;
|
||||
if constexpr (has_perm) {
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int tile_ints = tile_k_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
|
||||
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
|
||||
|
||||
if (threadIdx.x < perm_size) {
|
||||
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
|
||||
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
|
||||
int src_k = sh_perm_int_ptr[k_id];
|
||||
int src_k_packed = src_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
||||
|
||||
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
|
||||
|
||||
uint32_t vals[8];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
||||
uint32_t src_k_pos = src_k % pack_factor;
|
||||
|
||||
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
||||
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
||||
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
vals[i] = b1_cur_val;
|
||||
vals[4 + i] = b2_cur_val;
|
||||
}
|
||||
|
||||
} else {
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
int cur_pos = cur_elem % pack_factor;
|
||||
|
||||
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
load_perm_to_shared(k_tile_id);
|
||||
}
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch helper: expands to an `else if` branch (the chain is seeded by the
// empty `if (false) {}` in gptq_marlin_repack below). When the runtime
// (num_bits, has_perm) pair matches, it first raises the kernel's dynamic
// shared-memory limit to the device opt-in maximum, then launches the matching
// gptq_marlin_repack_kernel instantiation on the current stream, using the
// b_q_weight_ptr/perm_ptr/out_ptr/size_k/size_n locals of the caller.
#define CALL_IF(NUM_BITS, HAS_PERM)                                                                              \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                                                       \
    cudaFuncSetAttribute(                                                                                        \
        gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>,                                           \
        cudaFuncAttributeMaxDynamicSharedMemorySize,                                                             \
        max_shared_mem);                                                                                         \
    gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>                                                \
        <<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
  }
|
||||
|
||||
// Repack a GPTQ-quantized weight matrix into the Marlin tile layout.
//
// b_q_weight: int32 CUDA tensor of shape (size_k / pack_factor, size_n); each
//             32-bit word packs `pack_factor` quantized values along k.
// perm:       int32 CUDA tensor with the act_order k-permutation; an empty
//             tensor (size(0) == 0) disables the permutation path.
// size_k/size_n: logical (unpacked) weight dimensions; must be multiples of
//             tile_k_size / tile_n_size respectively.
// num_bits:   quantization width, 4 or 8 (pack_factor = 32 / num_bits).
//
// Returns a new int32 tensor of shape
// (size_k / tile_size, size_n * tile_size / pack_factor) on b_q_weight's
// device, filled by gptq_marlin_repack_kernel.
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
  TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(
      (size_k / pack_factor) == b_q_weight.size(0),
      "Shape mismatch: b_q_weight.size(0) = ",
      b_q_weight.size(0),
      ", size_k = ",
      size_k,
      ", pack_factor = ",
      pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt")
;
  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
  torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  // Launch one thread block per SM (grid size = multiprocessor count).
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Empty `if` seeds the else-if chain that each CALL_IF expands into; the
  // final `else` catches any (num_bits, has_perm) combination not listed.
  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm);
  }

  return out;
}
|
||||
|
||||
// Shape-propagation variant of gptq_marlin_repack: allocates an output tensor
// with the same shape arithmetic as the real op but launches no kernel.
// Uses c10::SymInt so symbolic sizes trace through (see the commented-out
// Meta-dispatch registration at the bottom of this file).
torch::Tensor gptq_marlin_repack_meta(
    torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
  // Quantized values packed per 32-bit word.
  int64_t const packed_per_word = 32 / num_bits;
  c10::SymInt const out_rows = size_k / tile_size;
  c10::SymInt const out_cols = size_n * tile_size / packed_per_word;
  return torch::empty_symint(
      {out_rows, out_cols}, torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()));
}
|
||||
|
||||
#endif
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
// }
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
// }
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
// Declaration of the MoE WNA16 Marlin GEMM kernel. Definitions are produced
// by explicitly instantiating the template in the auto-generated
// kernel_*.cu files (which include marlin_template.h alongside this header).
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"

// Shared parameter list so the kernel signature is written in exactly one
// place, used both by the declaration below and by the per-configuration
// explicit instantiations.
#define MARLIN_KERNEL_PARAMS                                                                                          \
  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp,             \
      const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,            \
      const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr,                   \
      const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k,  \
      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks,              \
      bool use_atomic_add, bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
template <
    typename scalar_t,                     // compute dtype, half or nv_float16
    const sglang::ScalarTypeId w_type_id,  // weight ScalarType id
    const int threads,                     // number of threads in a threadblock
    const int thread_m_blocks,             // number of 16x16 blocks in the m
                                           // dimension (batchsize) of the
                                           // threadblock
    const int thread_n_blocks,             // same for n dimension (output)
    const int thread_k_blocks,             // same for k dimension (reduction)
    const bool m_block_size_8,             // whether m_block_size == 8
                                           // only works when thread_m_blocks == 1
    const int stages,                      // number of stages for the async global->shared
                                           // fetch pipeline
    const bool has_act_order,              // whether act_order is enabled
    const bool has_zp,                     // whether zero-points are enabled
    const int group_blocks,                // number of consecutive 16x16 blocks
                                           // with a separate quantization scale
    const bool is_zp_float                 // is zero point of float16 type?
    >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}  // namespace MARLIN_NAMESPACE_NAME
|
||||
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
Normal file
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
Normal file
@@ -0,0 +1,89 @@
|
||||
// auto generated by generate.py
// clang-format off
// NOTE(review): explicit instantiations of the Marlin kernel for nv_bfloat16
// compute with unsigned 4-bit weights (kU4; has_zp is true in every row, so
// these carry runtime zero-points). Template arguments after the type id are:
// threads, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
// stages, has_act_order, has_zp, group_blocks, is_zp_float (see kernel.h).
// Rows are grouped by group_blocks in {-1, 2, 4, 8}. Edit generate.py, not
// this file.

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
// clang-format off
// NOTE(review): explicit instantiations of the Marlin kernel for nv_bfloat16
// compute with kU4B8 weights (has_zp is false in every row — no runtime
// zero-point; presumably a fixed bias-8 representation, per the type name —
// confirm against scalar_type.hpp). Template arguments after the type id are:
// threads, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
// stages, has_act_order, has_zp, group_blocks, is_zp_float (see kernel.h).
// The first group has has_act_order=true with group_blocks=0; the remaining
// groups use group_blocks in {-1, 2, 4, 8}. Edit generate.py, not this file.

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
Normal file
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
Normal file
@@ -0,0 +1,89 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
1804
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1804
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
1112
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
1112
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
File diff suppressed because it is too large
Load Diff
328
sgl-kernel/include/scalar_type.hpp
Normal file
328
sgl-kernel/include/scalar_type.hpp
Normal file
@@ -0,0 +1,328 @@
|
||||
#pragma once
|
||||
|
||||
// For TORCH_CHECK
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace sglang {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/scalar_type.py
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(
|
||||
uint8_t exponent,
|
||||
uint8_t mantissa,
|
||||
bool signed_,
|
||||
int32_t bias,
|
||||
bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) {
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) {
|
||||
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
TORCH_CHECK(
|
||||
nan_repr != NAN_IEEE_754,
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int { return acc + member_id_field_width<decltype(member)>(); }, 0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result, auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
// Inverse of `id()`: slices the packed bit-fields back out (same order and
// widths) and forwards them to the ScalarType constructor.
static constexpr ScalarType from_id(Id id) {
  // Fold step: extract this member's bits at `bit_offset`, cast to the
  // member's type, append to the growing constructor-argument tuple.
  auto extract_and_advance = [id](auto result, auto member) {
    using T = decltype(member);
    auto [tuple, bit_offset] = result;
    auto constexpr bits = member_id_field_width<T>();
    auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1));
    auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
    return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
  };

  auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair<std::tuple<>, int>{});
  return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args);
}
|
||||
|
||||
// Total storage bits for a value of this type (sign bit included for signed
// types; `exponent` is zero for integer types).
constexpr int64_t size_bits() const {
  return mantissa + exponent + is_signed();
}
// Whether the type has a sign bit.
constexpr bool is_signed() const {
  return signed_;
}
// Integer types are encoded with a zero-width exponent.
constexpr bool is_integer() const {
  return exponent == 0;
}
constexpr bool is_floating_point() const {
  return exponent > 0;
}
// IEEE-754 means: infinities supported and standard NaN encoding.
constexpr bool is_ieee_754() const {
  return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
  return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
  return is_floating_point() && finite_values_only == false;
}
// Non-zero bias means stored_value = value + bias (e.g. GPTQ uint4b8).
constexpr bool has_bias() const {
  return bias != 0;
}
|
||||
|
||||
private:
|
||||
// Largest finite value of this floating-point type, computed by assembling
// the equivalent IEEE-754 double bit pattern.
double _floating_point_max() const {
  TORCH_CHECK(mantissa <= 52 && exponent <= 11, "Cannot represent max/min as a double for type ", str());

  uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
  if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
    // The all-ones mantissa is reserved for NaN in this encoding.
    max_mantissa -= 1;
  }

  uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
  if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
    // No IEEE inf/NaN encoding, so the all-ones exponent is a normal value.
    TORCH_CHECK(exponent < 11, "Cannot represent max/min as a double for type ", str());
    max_exponent += 1;
  }

  // adjust the exponent to match that of a double
  // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
  // is the exponent bits), there is some precedent for non-standard biases,
  // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
  // but to avoid premature over complication we are just assuming the
  // standard exponent bias until there is a need to support non-standard
  // biases
  uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
  uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

  uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double;

  // shift the mantissa into the position for a double and
  // the exponent
  uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

  // NOTE(review): type punning via reinterpret_cast violates strict
  // aliasing; std::bit_cast (C++20) or memcpy would be well-defined.
  return *reinterpret_cast<double*>(&double_raw);
}
|
||||
|
||||
// Max *stored* (pre-bias) value: double for floats, int64 for integers.
constexpr std::variant<int64_t, double> _raw_max() const {
  if (is_floating_point()) {
    return {_floating_point_max()};
  } else {
    // (&& binds tighter than ||: a full 64 bits is only allowed if signed)
    TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), "Cannot represent max as a int64_t");
    // For integers `mantissa` holds the value bits (sign bit excluded).
    return {(int64_t(1) << mantissa) - 1};
  }
}
|
||||
|
||||
// Min *stored* (pre-bias) value: double for floats, int64 for integers.
constexpr std::variant<int64_t, double> _raw_min() const {
  if (is_floating_point()) {
    TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed");
    constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

    // Floats are symmetric: min is max with the sign bit set.
    // NOTE(review): type punning via reinterpret_cast violates strict
    // aliasing; std::bit_cast (C++20) or memcpy would be well-defined.
    double max = _floating_point_max();
    uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
    uint64_t min_raw = max_raw | sign_bit_double;
    return {*reinterpret_cast<double*>(&min_raw)};
  } else {
    TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t");
    if (is_signed()) {
      // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
      // then perform an arithmetic shift right to set all the bits above
      // (size_bits() - 1) to 1
      return {INT64_MIN >> (64 - size_bits())};
    } else {
      return {int64_t(0)};
    }
  }
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
  // std::visit preserves the stored alternative (int64_t vs double) while
  // subtracting the bias from the raw stored maximum.
  return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_max());
}
|
||||
|
||||
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
  // Same shape as max(): bias-adjust the raw stored minimum in-place.
  return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_min());
}
|
||||
|
||||
// Human-readable name of this type; scheme documented in the block comment.
std::string str() const {
  /* naming generally follows: https://github.com/jax-ml/ml_dtypes
   * for floating point types (leading f) the scheme is:
   * `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
   * flags:
   * - no-flags: means it follows IEEE 754 conventions
   * - f: means finite values only (no infinities)
   * - n: means nans are supported (non-standard encoding)
   * for integer types the scheme is:
   * `[u]int<size_bits>[b<bias>]`
   * - if bias is not present it means its zero
   */
  if (is_floating_point()) {
    auto ret =
        "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa);
    // Flags only appear for non-IEEE encodings.
    if (!is_ieee_754()) {
      if (finite_values_only) {
        ret += "f";
      }
      if (nan_repr != NAN_NONE) {
        ret += "n";
      }
    }
    return ret;
  } else {
    auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
    if (has_bias()) {
      ret += "b" + std::to_string(bias);
    }
    return ret;
  }
}
|
||||
|
||||
// Memberwise equality over every field that feeds into id(), so two types
// compare equal exactly when their ids match.
constexpr bool operator==(ScalarType const& other) const {
  return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ &&
         finite_values_only == other.finite_values_only && nan_repr == other.nan_repr;
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// Fixed width style names, generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
static inline constexpr auto kInt4 = kS4;
|
||||
static inline constexpr auto kUint4 = kU4;
|
||||
static inline constexpr auto kUint4b8 = kU4B8;
|
||||
static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
static inline constexpr auto kHalf = kFE5M10;
|
||||
static inline constexpr auto kFloat16 = kHalf;
|
||||
static inline constexpr auto kBFloat16 = kFE8M7;
|
||||
|
||||
static inline constexpr auto kFloat16Id = kFloat16.id();
|
||||
}; // namespace sglang
|
||||
@@ -18,12 +18,15 @@ limitations under the License.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
@@ -323,6 +326,15 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
namespace marlin_moe_wna16 {
|
||||
|
||||
torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
} // namespace marlin_moe_wna16
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -495,6 +507,31 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
namespace flash {
|
||||
/*
|
||||
|
||||
@@ -29,6 +29,7 @@ from sgl_kernel.elementwise import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
@@ -55,6 +56,11 @@ from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
from sgl_kernel.marlin import (
|
||||
awq_marlin_moe_repack,
|
||||
awq_marlin_repack,
|
||||
gptq_marlin_repack,
|
||||
)
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
|
||||
223
sgl-kernel/python/sgl_kernel/fused_moe.py
Normal file
223
sgl-kernel/python/sgl_kernel/fused_moe.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
|
||||
def get_scalar_type(num_bits: int, has_zp: bool):
    """Map a weight bit-width and zero-point flag to its marlin scalar type.

    With a zero point only 4-bit (plain uint4) is supported; without one, the
    GPTQ-style biased types uint4b8 / uint8b128 are used.
    """
    if not has_zp:
        if num_bits == 4:
            return scalar_types.uint4b8
        return scalar_types.uint8b128
    assert num_bits == 4
    return scalar_types.uint4
|
||||
|
||||
|
||||
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of topk-k elements.
    - global_num_experts (int): Total expert count across EP ranks; -1 means
        "use the local expert count" (w1.shape[0]).
    - expert_map (Optional[torch.Tensor]): Local-expert mapping; its presence
        enables expert-parallel (is_ep) handling in the kernels.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
        permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
        permutation.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - workspace (Optional[torch.Tensor]): Pre-allocated int32 kernel workspace;
        allocated here when None.
    - num_bits (int): The number of bits in expert weights quantization.
    - is_k_full (bool): Whether the full K dimension is available (act_order).
    - inplace (bool): Write the result into hidden_states instead of a new
        tensor.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer,
        same shape as hidden_states.
    """
    # Delay the import to avoid circular dependency
    from sglang.srt.layers.moe.fused_moe_triton import (
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    # w1 is marlin-packed: 16 K-elements per packed row.
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    # w2's packed last dim stores num_bits // 2 values per element
    # (presumably the marlin packed-N layout — TODO confirm).
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w1.shape[0]
    # Intermediate (per-expert) hidden size, recovered from w2's packed K dim.
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    if global_num_experts == -1:
        global_num_experts = E
    # Group tokens by expert, padded to multiples of block_size_m.
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, block_size_m, global_num_experts
    )

    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * (
            sorted_token_ids.size(0) // block_size_m
        )
        device = hidden_states.device
        # Cap the workspace by SM count; the kernel only needs one slot group
        # per resident block.
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(
            max_workspace_size, dtype=torch.int, device=device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Single flat buffer shared by caches 1 and 3: cache1 (gemm1 output) is
    # fully consumed by silu_and_mul before cache3 (gemm2 output) is written,
    # so the aliased prefixes never hold live data at the same time.
    intermediate_cache13 = torch.empty(
        (M * topk_ids.shape[1] * max(2 * N, K),),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
    intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    # GEMM 1: hidden_states @ w1 -> gate/up projections (2N wide).
    # NOTE(review): the keyword here is `is_full_k`, while the C++ declaration
    # of moe_wna16_marlin_gemm names the parameter `is_k_full` — confirm the
    # registered op schema actually uses `is_full_k`.
    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        hidden_states,
        intermediate_cache1,
        w1,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type1.id,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_full_k=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    # SiLU(gate) * up, halving the width from 2N to N.
    torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    if expert_map is not None:
        # With EP, rows for non-local experts are skipped by the kernel, so
        # clear the output to avoid summing stale values below.
        intermediate_cache3.zero_()

    # GEMM 2: down projection back to K; topk weights are applied here.
    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type2.id,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_full_k=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

    output = hidden_states if inplace else torch.empty_like(hidden_states)
    # Reduce over the topk dimension into the output buffer.
    # NOTE(review): `.view(*intermediate_cache3.shape)` is a no-op view.
    return torch.sum(
        intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
    )
|
||||
|
||||
|
||||
def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    """Shape-only fake implementation of ``fused_marlin_moe``.

    Allocates an uninitialized output matching ``hidden_states`` without
    launching any kernels; all other arguments are accepted for signature
    parity only.
    """
    fake_output = torch.empty_like(hidden_states)
    return fake_output
|
||||
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
|
||||
|
||||
def gptq_marlin_repack(
    b_q_weight,
    perm,
    size_k,
    size_n,
    num_bits,
):
    """Repack GPTQ-format quantized weights into the marlin layout.

    Thin wrapper over the ``sgl_kernel.gptq_marlin_repack`` custom op.

    Args:
        b_q_weight: Packed GPTQ quantized weight tensor.
        perm: Activation-order permutation tensor.
        size_k: Input-feature (K) dimension.
        size_n: Output-feature (N) dimension.
        num_bits: Quantization bit-width.

    Returns:
        The repacked weight tensor produced by the custom op.
    """
    # Bug fix: the op result was previously computed and discarded, so this
    # wrapper always returned None. The op is declared to return a Tensor,
    # and the sibling awq_marlin_repack wrapper returns its op's result.
    return torch.ops.sgl_kernel.gptq_marlin_repack.default(
        b_q_weight,
        perm,
        size_k,
        size_n,
        num_bits,
    )
|
||||
|
||||
|
||||
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack AWQ-format quantized weights into the marlin layout via the
    ``sgl_kernel.awq_marlin_repack`` custom op."""
    repack_op = torch.ops.sgl_kernel.awq_marlin_repack
    return repack_op(b_q_weight, size_k, size_n, num_bits)
|
||||
|
||||
|
||||
def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack per-expert AWQ weights into the marlin layout.

    Applies the ``awq_marlin_repack`` op expert by expert and collects the
    results into one (num_experts, size_k // 16, size_n * num_bits // 2)
    tensor. ``perm`` is accepted for signature parity but is not used.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    repacked = torch.empty(
        (expert_count, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    repack_op = torch.ops.sgl_kernel.awq_marlin_repack
    for expert_idx in range(expert_count):
        repacked[expert_idx] = repack_op(b_q_weight[expert_idx], size_k, size_n, num_bits)
    return repacked
|
||||
352
sgl-kernel/python/sgl_kernel/scalar_type.py
Normal file
352
sgl-kernel/python/sgl_kernel/scalar_type.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
_SCALAR_TYPES_ID_MAP = {}
|
||||
|
||||
|
||||
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """How (and whether) NaN values are encoded by a scalar type."""

    NONE = 0  # NaNs are not supported
    IEEE_754 = 1  # NaNs are: exponent all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # NaNs are: exponent all 1s, mantissa all 1s
|
||||
|
||||
|
||||
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
||||
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
||||
# in sync until the inductor fully supports custom C++ classes.
|
||||
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Raw bit pattern of the largest finite value, expressed as the
        equivalent IEEE-754 double (mirrors C++ `_floating_point_max`)."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # All-ones mantissa is reserved for NaN in this encoding.
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
            # No IEEE inf/NaN encoding: the all-ones exponent is usable.
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Largest finite value, reinterpreting the raw bits as a double."""
        double_raw = self._floating_point_max_int()
        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Max *stored* (pre-bias) value."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (
                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
            ), "Cannot represent max as an int"
            # `mantissa` holds the value bits (sign bit excluded) for ints.
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Min *stored* (pre-bias) value."""
        if self.is_floating_point():
            assert (
                self.is_signed()
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # Floats are symmetric: min is max with the sign bit set.
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
        else:
            assert (
                not self.is_signed() or self.size_bits <= 64
            ), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        # Field order and widths must match the C++ `id()`/`from_id()`.
        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"

        # Register so `from_id` can map the packed id back to this instance.
        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
    def size_bits(self) -> int:
        "Total storage bits (sign bit included for signed types)."
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        # NOTE(review): the C++ counterpart additionally requires
        # is_floating_point(); confirm whether integer types should return
        # False here.
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If this type can represent NaN values"
        # Bug fix: compare against the NanRepr member itself. `nan_repr`
        # stores an enum member, and enum members never compare equal to
        # their `.value` ints, so the previous `!= NanRepr.NONE.value`
        # comparison was always True.
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        # Bug fix: same enum-vs-int issue as has_nans — the previous
        # `== NanRepr.IEEE_754.value` comparison was always False, which made
        # every float render with non-IEEE flags in __str__.
        return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = (
                "float"
                + str(self.size_bits)
                + "_e"
                + str(self.exponent)
                + "m"
                + str(self.mantissa)
            )

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(
        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
    ) -> "ScalarType":
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        assert nan_repr != NanRepr.IEEE_754, (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions"
        )
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int) -> "ScalarType":
        """Look up a previously-constructed ScalarType by its packed id."""
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
||||
|
||||
|
||||
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
# for floating point types (leading f) the scheme is:
|
||||
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
# flags:
|
||||
# - no-flags: means it follows IEEE 754 conventions
|
||||
# - f: means finite values only (no infinities)
|
||||
# - n: means nans are supported (non-standard encoding)
|
||||
# for integer types the scheme is:
|
||||
# `[u]int<size_bits>[b<bias>]`
|
||||
# - if bias is not present it means it is zero
|
||||
|
||||
|
||||
class scalar_types:
    # Pre-built ScalarType constants shared by the quantization code.
    # Naming follows the scheme documented above (ml_dtypes style).
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types: unsigned ints with bias 2**(bits-1), so stored values are
    # centered around zero after subtracting the bias.
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
|
||||
138
sgl-kernel/tests/test_marlin_repack.py
Normal file
138
sgl-kernel/tests/test_marlin_repack.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import awq_marlin_repack
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.quantization.quant_utils import (
|
||||
get_pack_factor,
|
||||
pack_cols,
|
||||
quantize_weights,
|
||||
)
|
||||
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
|
||||
|
||||
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack an unpacked (size_k, size_n) quantized weight into AWQ int32 layout.

    Columns are interleaved to match the AWQ dequantize code, then packed
    ``32 // num_bits`` values per int32 via ``pack_cols``.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        order = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        order = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    interleaved = q_w.reshape((-1, len(order)))[:, order].ravel()
    interleaved = interleaved.reshape((-1, size_n)).contiguous()

    return pack_cols(interleaved, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange a (size_k, size_n) quantized weight into marlin tile order.

    The weight is viewed as a grid of ``tile x tile`` sub-tiles, the two tile
    dimensions are transposed, and ``perm`` is then applied over consecutive
    groups of ``perm.numel()`` elements.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # Fixed: this assert message previously labeled size_n as "size_k".
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
||||
|
||||
|
||||
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Reference packing: permute to marlin tile order, then pack into int32."""
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack `pack_factor` values per uint32, lowest-order bits first.
    pack_factor = get_pack_factor(num_bits)
    device = q_w.device

    w_np = q_w.cpu().numpy().astype(np.uint32)

    packed = np.zeros((w_np.shape[0], w_np.shape[1] // pack_factor), dtype=np.uint32)
    for shift_idx in range(pack_factor):
        packed |= w_np[:, shift_idx::pack_factor] << num_bits * shift_idx

    return torch.from_numpy(packed.astype(np.int32)).to(device)
|
||||
|
||||
|
||||
def get_weight_perm(num_bits: int):
    """Build the 1024-entry intra-tile-group weight permutation used by marlin.

    Returns a torch.LongTensor that is a permutation of range(1024); the final
    column interleave depends on ``num_bits`` (4 or 8).
    """
    perm_list: list[int] = []
    for tid in range(32):
        col = tid // 4
        base: list[int] = []
        for block in (0, 1):
            for row in (
                2 * (tid % 4),
                2 * (tid % 4) + 1,
                2 * (tid % 4 + 4),
                2 * (tid % 4 + 4) + 1,
            ):
                base.append(16 * row + col + 8 * block)
        # Replicate the 8-entry pattern across the four 256-element quarters.
        for quarter in range(4):
            perm_list.extend(p + 256 * quarter for p in base)

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32])
def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
    """awq_marlin_repack on GPU must match the reference awq->marlin repack."""
    tile_k, tile_n = 16, 64
    size_k = k_tiles * tile_k
    size_n = n_tiles * tile_n
    pack_factor = 32 // num_bits

    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    # Fixed: quantize with the type matching `num_bits`. Previously this was
    # hard-coded to uint4, so the num_bits=8 parametrization packed 4-bit
    # values as if they were 8-bit.
    quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    w_ref, q_w, s, zp = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Reference: AWQ packing of the raw quantized weight...
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

    # ...and the expected marlin layout of the same weight.
    weight_perm = get_weight_perm(num_bits)
    q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

    out_gpu = awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)
    assert out_gpu.is_cuda and out_gpu.dtype == torch.int32

    expected_cols = size_n * tile_k // pack_factor
    assert list(out_gpu.shape) == [size_k // tile_k, expected_cols]

    torch.cuda.synchronize()

    torch.testing.assert_close(out_gpu, q_w_marlin)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
301
test/srt/test_int4_kernel.py
Normal file
301
test/srt/test_int4_kernel.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import itertools
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
|
||||
|
||||
# from sglang.srt.layers.moe.topk import select_experts
|
||||
from sgl_kernel import fused_marlin_moe
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
|
||||
# from vllm.model_executor.layers. import select_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def stack_and_dev(tensors: list[torch.Tensor]):
    """Stack `tensors` along a new leading dim, on the first tensor's device."""
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
||||
|
||||
|
||||
def torch_moe(a, w1, w2, score, topk, expert_map):
    # Reference (unquantized) MoE forward used to validate the fused kernel.
    # assumes a: (B, D); w1: (E, 2N, D) gate+up proj; w2: (E, K, N) down proj;
    # score: (B, E) router logits -- TODO confirm against callers.
    # expert_map maps global expert id -> local id; unmapped entries (-1)
    # effectively drop those assignments.
    B, D = a.shape
    # Duplicate each token `topk` times so every (token, chosen-expert) pair
    # becomes one row.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    # Process one expert at a time over the rows routed to it.
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    # Weight each expert output by its routing probability and sum over topk.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
    A_f = A.to(torch.float32)
    B_f = B.to(torch.float32)

    assert A_f.shape[-1] == B_f.shape[-1], "Dimension mismatch"
    assert B_f.ndim == 2 and B_f.is_contiguous(), "B must be a 2D contiguous tensor"

    # Flatten A to (num_tokens, inner_dim); B is (out_features, inner_dim).
    inner_dim = A_f.shape[-1]
    num_tokens = A_f.numel() // inner_dim
    weight_t = B_f.t()  # (inner_dim, out_features)
    out_features = weight_t.shape[1]
    result_shape = A_f.shape[:-1] + (out_features,)
    A_2d = A_f.reshape(num_tokens, inner_dim)

    # As is per-token [M, 1], Bs is per-column, broadcast across rows.
    prod = torch.matmul(A_2d, weight_t)
    prod = As * prod * Bs.view(1, -1)

    return prod.reshape(result_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization using native torch."""
    # NOTE(review): `per_token_quant_int8` is neither imported nor defined in
    # this file; calling this helper would raise NameError. Confirm the
    # intended import (it appears unused by the tests below).

    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
def marlin_fused_moe(
    N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
    """Marlin-quantize w1/w2 per expert, run fused_marlin_moe, and run the
    torch reference (torch_moe) on the dequantized weights.

    Returns:
        (marlin_output, torch_output) for the caller to compare.
    """
    # gptq-style biased uint types: 4-bit -> uint4b8, 8-bit -> uint8b128.
    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    # Expert-parallel simulation: keep a random subset of E // ep_size experts
    # and build a global->local expert-id map (-1 marks experts not kept).
    if ep_size > 1:
        local_e = E // ep_size
        e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None
    # Per-expert marlin quantization of the first projection (w1).
    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    zeros1_l = []  # never appended to -> zeros1 stays None (no zero-points)
    g_idx1_l = []
    sort_indices1_l = []
    s1_l = []  # NOTE(review): unused
    for i in range(w1.shape[0]):
        test_perm = torch.randperm(n=K)
        quant_res = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
        w_ref1_l.append(w_ref1.T)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)
    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None

    # Per-expert marlin quantization of the second projection (w2).
    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    zeros2_l = []  # never appended to -> zeros2 stays None
    g_idx2_l = []
    sort_indices2_l = []
    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n=N)
        quant_res = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res

        w_ref2_l.append(w_ref2.T)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None

    # Routing weights/ids shared by both the fused kernel and the reference.
    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    # topk_weights, topk_ids = FusedMoE.select_experts(
    #     hidden_states=a,
    #     router_logits=score,
    #     top_k=topk,
    #     num_expert_group=E,
    #     use_grouped_topk=False,
    #     renormalize=False,
    #     topk_group=None,
    # )

    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=E,
        expert_map=e_map,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=num_bits,
        is_k_full=True,
    )
    return marlin_output, torch_output
|
||||
|
||||
|
||||
class TestW8A8Int8FusedMoE(unittest.TestCase):
    """Compares fused_marlin_moe against the torch reference over a small grid."""

    # Parameter grids swept by itertools.product in the test below.
    DTYPES = [torch.float16]
    M = [1, 16]
    N = [128]
    K = [256]
    E = [4, 10]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]
    NUM_BITS = [4]
    EP_SIZE = [1, 4]

    @classmethod
    def setUpClass(cls):
        # The marlin kernels are CUDA-only; skip the whole class otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w4a8_int8_fused_moe(
        self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
    ):
        # One grid point: build random activations/weights and compare outputs.
        torch.manual_seed(seed)
        a = torch.randn((M, K), dtype=dtype) / 10

        # Generate int8 weights
        w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
        w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2

        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            marlin_out, ref_out = marlin_fused_moe(
                N=N,
                E=E,
                K=K,
                a=a,
                w1=w1_fp16,
                w2=w2_fp16,
                num_bits=num_bits,
                group_size=-1,
                act_order=False,
                score=score,
                topk=topk,
                ep_size=ep_size,
            )
        # Check results: dump both outputs when relative error exceeds 10%,
        # to aid debugging before the strict assert below fires.
        if (
            torch.mean(
                torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
            )
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            > 0.1
        ):
            print(f"marlin_out: {marlin_out}")
            print(f"ref_out: {ref_out}")
            print(
                torch.mean(
                    torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
                )
                / torch.mean(torch.abs(ref_out.to(torch.float32)))
            )
        torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)

    def test_w4a8_int8_fused_moe(self):
        # Sweep the full cartesian product of the class-level parameter grids.
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
            self.NUM_BITS,
            self.EP_SIZE,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
                num_bits=params[8],
                ep_size=params[9],
            ):
                self._w4a8_int8_fused_moe(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the unittest suite directly when this file is executed as a script.
    unittest.main(verbosity=2)
|
||||
14
test/srt/test_w4a8.py
Normal file
14
test/srt/test_w4a8.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import sgl_kernel
import torch

# Smoke test: invoke the w4a8 gemm custom op directly with dummy tensors.
# NOTE(review): every tensor here is a float32 randn placeholder; a real
# w4a8 gemm presumably expects packed quantized weights and matching
# scale/zero dtypes -- this only checks the op is registered and callable.
# TODO confirm expected dtypes/shapes against the kernel signature.
x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")

torch.ops.sgl_kernel.gemm_forward_cuda.default(
    x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)
|
||||
Reference in New Issue
Block a user