[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com> Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: 弋云 <yiyun.wyt@antgroup.com> Co-authored-by: walker-ai <2398833647@qq.com>
This commit is contained in:
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
40
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "gptq_marlin/marlin.cuh"
|
||||
#include "gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <
|
||||
typename scalar_t, // compute dtype, half or nv_float16
|
||||
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user