[1/n] apply wna16 marlin kernel in moe weight-only quantization (#7683)

Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: HandH1998 <1335248067@qq.com>
Co-authored-by: 弋云 <yiyun.wyt@antgroup.com>
Co-authored-by: walker-ai <2398833647@qq.com>
AniZpZ committed 2025-07-02 14:21:25 +08:00 (committed by GitHub)
parent b3fa5dc3c8
commit 8e03b641ba
27 changed files with 6104 additions and 1 deletion

@@ -18,12 +18,15 @@ limitations under the License.
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/all.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <tuple>
#include <vector>
#include "scalar_type.hpp"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
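// Editor's note (standard token-pasting idiom, not specific to this diff):
// the two-level indirection makes CONCAT expand its arguments before pasting.
// E.g. with #define SUFFIX v2, CONCAT(init_, SUFFIX) yields init_v2, whereas
// pasting directly via _CONCAT(init_, SUFFIX) would produce init_SUFFIX.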
@@ -323,6 +326,15 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
namespace marlin_moe_wna16 {
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
} // namespace marlin_moe_wna16
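// Hedged usage sketch (not part of this diff; the shapes and 4-bit packing
// below are assumptions): repack GPTQ/AWQ-quantized weights into the Marlin
// tile layout once at load time, then feed the result to the MoE GEMM.
//
//   // qweight: int32-packed 4-bit weights, e.g. [size_k / 8, size_n] for GPTQ
//   // perm: GPTQ act-order permutation, or an empty tensor if unused
//   torch::Tensor marlin_w = marlin_moe_wna16::gptq_marlin_repack(
//       qweight, perm, size_k, size_n, /*num_bits=*/4);
//
//   // AWQ packs along n instead (e.g. [size_k, size_n / 8]) and takes no perm
//   torch::Tensor marlin_w_awq = marlin_moe_wna16::awq_marlin_repack(
//       qweight_awq, size_k, size_n, /*num_bits=*/4);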
/*
* From csrc/speculative
*/
@@ -495,6 +507,31 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
std::optional<at::Generator> gen);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
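// Hedged reading of the new MoE entry point (roles inferred from parameter
// names; not documented in this diff): `a` holds activations for the routed
// tokens; `b_q_weight`/`b_scales` (and optional `b_zeros_or_none`) carry the
// packed per-expert weights, dequant scales, and zero points;
// `g_idx_or_none`/`perm_or_none` carry GPTQ act-order metadata; and
// `sorted_token_ids`, `expert_ids`, `num_tokens_past_padded` are the routing
// buffers produced by aligning token blocks to `moe_block_size`. With
// `mul_topk_weights` set, each token's output is scaled by its router weight
// from `topk_weights` (top_k entries per token); `is_ep` flags
// expert-parallel sharding, and `b_q_type_id` selects the quantized weight
// type via sglang::ScalarType.
//
// Minimal call sketch (all tensor shapes, dtypes, and flag values below are
// assumptions for illustration):
//   torch::Tensor out = moe_wna16_marlin_gemm(
//       a, /*c_or_none=*/std::nullopt, b_q_weight, b_scales,
//       /*b_zeros_or_none=*/std::nullopt, /*g_idx_or_none=*/std::nullopt,
//       /*perm_or_none=*/std::nullopt, workspace, sorted_token_ids,
//       expert_ids, num_tokens_past_padded, topk_weights,
//       /*moe_block_size=*/64, /*top_k=*/2, /*mul_topk_weights=*/true,
//       /*is_ep=*/false, b_q_type_id, size_m, size_n, size_k,
//       /*is_k_full=*/true, /*use_atomic_add=*/false,
//       /*use_fp32_reduce=*/true, /*is_zp_float=*/false);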
namespace flash {
/*