Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t num_batches,
|
||||
int64_t sm_count = 0,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
@@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
const std::optional<at::Tensor>& v_buffer,
|
||||
const std::optional<at::Tensor>& kv_cache_loc);
|
||||
|
||||
void downcast_fp8(
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& k_out,
|
||||
at::Tensor& v_out,
|
||||
at::Tensor& k_scale,
|
||||
at::Tensor& v_scale,
|
||||
at::Tensor& loc,
|
||||
int64_t mult,
|
||||
int64_t offset,
|
||||
int64_t cuda_stream);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -221,7 +235,6 @@ void bmm_fp8(
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
||||
|
||||
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
@@ -258,6 +271,7 @@ torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -374,6 +388,61 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
void get_cutlass_w4a8_moe_mm_data(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_w4a8_moe_mm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -527,35 +596,6 @@ void transfer_kv_direct(
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
void get_cutlass_w4a8_moe_mm_data(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_w4a8_moe_mm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
@@ -597,32 +637,6 @@ void top_p_sampling_from_probs(
|
||||
void top_k_mask_logits(
|
||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
namespace flash {
|
||||
/*
|
||||
* From fa2 sparse
|
||||
|
||||
Reference in New Issue
Block a user