[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
This commit is contained in:
255
sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu
Normal file
255
sgl-kernel/csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu
Normal file
@@ -0,0 +1,255 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "kernel.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in awq_marlin_repack_kernel
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int tile_n_ints = tile_n_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_ints / 4;
|
||||
constexpr int stage_k_threads = tile_k_size;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
int first_n_packed = first_n / pack_factor;
|
||||
|
||||
int4* sh_ptr = sh + stage_size * pipe;
|
||||
|
||||
if (threadIdx.x < stage_size) {
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(
|
||||
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)])));
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
int cur_n_packed = cur_n / pack_factor;
|
||||
int cur_n_pos = cur_n % pack_factor;
|
||||
|
||||
constexpr int sh_stride = tile_n_ints;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh + stage_size * pipe;
|
||||
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
||||
|
||||
// Undo interleaving
|
||||
int cur_n_pos_unpacked;
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
|
||||
cur_n_pos_unpacked = undo_pack[cur_n_pos];
|
||||
} else {
|
||||
constexpr int undo_pack[4] = {0, 2, 1, 3};
|
||||
cur_n_pos_unpacked = undo_pack[cur_n_pos];
|
||||
}
|
||||
|
||||
uint32_t vals[8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
|
||||
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
|
||||
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem];
|
||||
|
||||
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \
|
||||
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
|
||||
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(b_q_weight.size(0) == size_k, "b_q_weight.size(0) = ", b_q_weight.size(0), " is not size_k = ", size_k);
|
||||
TORCH_CHECK(
|
||||
(size_n / pack_factor) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ",
|
||||
b_q_weight.size(1),
|
||||
", size_n = ",
|
||||
size_n,
|
||||
", pack_factor = ",
|
||||
pack_factor);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4)
|
||||
CALL_IF(8)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
25
sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h
Normal file
25
sgl-kernel/csrc/moe/marlin_moe_wna16/core/registration.h
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
#define _STRINGIFY(A) #A
|
||||
#define STRINGIFY(A) _STRINGIFY(A)
|
||||
|
||||
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
||||
|
||||
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||
// via python's import statement.
|
||||
#define REGISTER_EXTENSION(NAME) \
|
||||
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
106
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
106
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{'true' if has_act_order else 'false'}}, "
|
||||
"{{'true' if has_zp else 'false'}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (sglang::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||
):
|
||||
|
||||
has_act_order = group_blocks == 0
|
||||
if has_zp and has_act_order:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
has_act_order=has_act_order,
|
||||
has_zp=has_zp,
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
||||
96
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh
Normal file
96
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin/marlin.cuh
Normal file
@@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Marlin params
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||
// we want relatively few warps to have many registers per warp and small tiles.
|
||||
static constexpr int default_threads = 256;
|
||||
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
static constexpr int max_thread_n = 256;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
|
||||
// Repack params
|
||||
static constexpr int repack_stages = 8;
|
||||
|
||||
static constexpr int repack_threads = 256;
|
||||
|
||||
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
|
||||
// Helpers
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) {
|
||||
return elems[i];
|
||||
}
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem),
|
||||
"l"(glob_ptr),
|
||||
"n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr),
|
||||
"n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@@ -0,0 +1,83 @@
|
||||
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "marlin.cuh"
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
public:
|
||||
using scalar_t = half;
|
||||
using scalar_t2 = half2;
|
||||
|
||||
// Matrix fragments for tensor core instructions; their precise layout is
|
||||
// documented here:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
||||
using FragA = Vec<half2, 4>;
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>;
|
||||
using FragZP = Vec<half2, 4>;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
||||
return __halves2half2(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ScalarType<nv_bfloat16> {
|
||||
public:
|
||||
using scalar_t = nv_bfloat16;
|
||||
using scalar_t2 = nv_bfloat162;
|
||||
|
||||
using FragA = Vec<nv_bfloat162, 4>;
|
||||
using FragB = Vec<nv_bfloat162, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
return __bfloat162bfloat162(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) {
|
||||
return __halves2bfloat162(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
333
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu
Normal file
333
sgl-kernel/csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu
Normal file
@@ -0,0 +1,333 @@
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async in gptq_marlin_repack_kernel
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr,
|
||||
uint32_t* __restrict__ out_ptr,
|
||||
int size_k,
|
||||
int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
|
||||
int4* sh_perm_ptr = sh;
|
||||
int4* sh_pipe_ptr = sh_perm_ptr;
|
||||
if constexpr (has_perm) {
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int tile_ints = tile_k_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
|
||||
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
|
||||
|
||||
if (threadIdx.x < perm_size) {
|
||||
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
|
||||
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
|
||||
int src_k = sh_perm_int_ptr[k_id];
|
||||
int src_k_packed = src_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
||||
|
||||
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
|
||||
|
||||
uint32_t vals[8];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
||||
uint32_t src_k_pos = src_k % pack_factor;
|
||||
|
||||
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
||||
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
||||
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
vals[i] = b1_cur_val;
|
||||
vals[4 + i] = b2_cur_val;
|
||||
}
|
||||
|
||||
} else {
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
int cur_pos = cur_elem % pack_factor;
|
||||
|
||||
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
load_perm_to_shared(k_tile_id);
|
||||
}
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM> \
|
||||
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
|
||||
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
(size_k / pack_factor) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = ",
|
||||
b_q_weight.size(0),
|
||||
", size_k = ",
|
||||
size_k,
|
||||
", pack_factor = ",
|
||||
pack_factor);
|
||||
TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.size(0) != 0;
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
|
||||
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4, false)
|
||||
CALL_IF(4, true)
|
||||
CALL_IF(8, false)
|
||||
CALL_IF(8, true)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(
|
||||
torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
|
||||
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
// }
|
||||
|
||||
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
// }
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <
|
||||
typename scalar_t, // compute dtype, half or nv_float16
|
||||
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
Normal file
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
Normal file
@@ -0,0 +1,89 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
Normal file
89
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
Normal file
@@ -0,0 +1,89 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
Normal file
109
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
||||
1804
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1804
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
1112
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
1112
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user