[Kernel] add custom op DispatchGmmCombineDecode (#4139)

#### What this PR does / why we need it?
Add the custom op API `DispatchGmmCombineDecode` for A3, including the kernel implementation, the Python API, and pytest coverage.
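
For orientation, a minimal usage sketch of the new Python entry point. This is a hedged illustration: the `_C_ascend` namespace follows the `TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), ...)` registration in this diff, but the shapes, dtypes, weight layouts, and the `vllm_ascend` import side effect are assumptions, not the shipped API surface. Because the inputs live on the meta device, the call is served by the shape-only meta kernel added below, so it runs without an NPU:

```python
import torch
import vllm_ascend  # assumption: importing this loads the _C_ascend extension

# Illustrative sizes only. The meta kernel derives output shapes solely from
# x and expert_ids, so the weight shapes below are placeholders.
bs, h, topk = 32, 7168, 8
num_local_experts = 16  # assumed number of routed experts hosted on this rank
x = torch.empty(bs, h, dtype=torch.bfloat16, device="meta")
expert_ids = torch.empty(bs, topk, dtype=torch.int32, device="meta")
w1 = torch.empty(num_local_experts, 2 * h, h, dtype=torch.int8, device="meta")
w1_scale = torch.empty(num_local_experts, 2 * h, dtype=torch.float32, device="meta")
w2 = torch.empty(num_local_experts, h, h, dtype=torch.int8, device="meta")
w2_scale = torch.empty(num_local_experts, h, dtype=torch.float32, device="meta")

output, ep_recv_count = torch.ops._C_ascend.dispatch_gmm_combine_decode(
    x, expert_ids, w1, w1_scale, w2, w2_scale,
    None,          # expert_smooth_scales (optional)
    None,          # expert_scales (optional)
    "ep_group_0",  # group_ep: EP communication group name (placeholder)
    16,            # ep_rank_size
    3,             # ep_rank_id
    224,           # moe_expert_num
    1,             # shared_expert_num
    2,             # shared_expert_rank_num
    0,             # quant_mode
    32,            # global_bs
)
assert output.shape == (bs, h)
assert ep_recv_count.shape == (16 * 16,)  # num_local_experts * ep_rank_size
```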

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Commit 4bd1030842 (parent cb42564942)
Authored by GuoRen868 on 2025-12-06 17:33:14 +08:00; committed by GitHub.
29 changed files with 7851 additions and 27 deletions

@@ -154,6 +154,37 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
    const at::Tensor &x,
    const at::Tensor &expert_ids,
    const at::Tensor &gmm1_permuted_weight,
    const at::Tensor &gmm1_permuted_weight_scale,
    const at::Tensor &gmm2_weight,
    const at::Tensor &gmm2_weight_scale,
    const c10::optional<at::Tensor> &expert_smooth_scales,
    const c10::optional<at::Tensor> &expert_scales,
    c10::string_view group_ep,
    int64_t ep_rank_size,
    int64_t ep_rank_id,
    int64_t moe_expert_num,
    int64_t shared_expert_num,
    int64_t shared_expert_rank_num,
    int64_t quant_mode,
    int64_t global_bs)
{
    // Meta (shape-only) implementation: no kernel is launched; only the
    // output shapes are derived so fake-tensor tracing can propagate them.
    auto x_shape = x.sizes();
    int64_t bs = x_shape[0];
    int64_t h = x_shape[1];
    // The combined decode output keeps the input's (batch, hidden) shape.
    at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta));
    // Ranks below shared_expert_rank_num host a single shared expert each;
    // the remaining ranks split the routed experts evenly among themselves.
    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
    // One receive counter per (local expert, peer rank) pair.
    at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta));
    return {output, ep_recv_count};
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
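
To make the shape arithmetic in `dispatch_gmm_combine_decode_meta` concrete, here is a small pure-Python mirror of it (parameter names follow the C++ signature above; the numeric values are made-up examples):

```python
def dispatch_gmm_combine_decode_out_shapes(bs, h, ep_rank_size, ep_rank_id,
                                           moe_expert_num, shared_expert_rank_num):
    # Ranks below shared_expert_rank_num host one shared expert each; the
    # remaining ranks split the routed experts evenly, as in the meta kernel.
    is_shared_expert = ep_rank_id < shared_expert_rank_num
    num_local_experts = (1 if is_shared_expert
                         else moe_expert_num // (ep_rank_size - shared_expert_rank_num))
    return (bs, h), (num_local_experts * ep_rank_size,)

# 16 EP ranks, 2 of them shared-expert ranks, 224 routed experts, rank 3:
out, recv = dispatch_gmm_combine_decode_out_shapes(32, 7168, 16, 3, 224, 2)
assert out == (32, 7168)
assert recv == (256,)  # 224 // (16 - 2) = 16 local experts x 16 ranks

# Rank 0 is a shared-expert rank, so it hosts exactly one local expert:
_, recv0 = dispatch_gmm_combine_decode_out_shapes(32, 7168, 16, 0, 224, 2)
assert recv0 == (16,)
```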
@@ -255,6 +286,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
    // Grouped matmul swiglu quant weight nz tensor list
    ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
    // dispatch_gmm_combine_decode meta implementation
    ops.impl("dispatch_gmm_combine_decode", &vllm_ascend::meta::dispatch_gmm_combine_decode_meta);
    // batch_matmul_transpose
    ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
    // Lightning indexer
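
Registering the op under the `Meta` dispatch key is what lets fake-tensor tracing (and anything built on it, such as `torch.compile` graph capture) infer output shapes without touching the device. A hedged sanity check, assuming the extension is importable; note that `_dispatch_has_kernel_for_dispatch_key` is a private PyTorch helper:

```python
import torch
import vllm_ascend  # assumption: importing this loads the _C_ascend extension

# Resolves only if the op schema was registered.
op = torch.ops._C_ascend.dispatch_gmm_combine_decode

# Private helper, but handy for confirming the Meta-key registration above.
assert torch._C._dispatch_has_kernel_for_dispatch_key(
    "_C_ascend::dispatch_gmm_combine_decode", "Meta")
```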