[Kernel] add custom op DispatchGmmCombineDecode (#4139)
#### What this PR does / why we need it? Add custom op API DispatchGmmCombineDecode for A3, including the kernel implementation, Python API, and pytest. vLLM version: v0.11.0, vLLM main: 24d6314718; vLLM version: v0.12.0, vLLM main: ad32e3e19c. Signed-off-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -595,6 +595,64 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
||||
const at::Tensor &x,
|
||||
const at::Tensor &expert_ids,
|
||||
const at::Tensor &gmm1_permuted_weight,
|
||||
const at::Tensor &gmm1_permuted_weight_scale,
|
||||
const at::Tensor &gmm2_weight,
|
||||
const at::Tensor &gmm2_weight_scale,
|
||||
const c10::optional<at::Tensor> &expert_smooth_scales,
|
||||
const c10::optional<at::Tensor> &expert_scales,
|
||||
c10::string_view group_ep,
|
||||
int64_t ep_rank_size,
|
||||
int64_t ep_rank_id,
|
||||
int64_t moe_expert_num,
|
||||
int64_t shared_expert_num,
|
||||
int64_t shared_expert_rank_num,
|
||||
int64_t quant_mode,
|
||||
int64_t global_bs)
|
||||
{
|
||||
auto x_shape = x.sizes();
|
||||
int bs = x_shape[0];
|
||||
int h = x_shape[1];
|
||||
|
||||
at::Tensor output = at::empty({bs, h}, x.options());
|
||||
|
||||
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
|
||||
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
|
||||
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
|
||||
|
||||
vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
char *group_ep_ptr = &group_ep_chrs[0];
|
||||
EXEC_NPU_CMD(
|
||||
// op api
|
||||
aclnnDispatchGmmCombineDecode,
|
||||
// input tensors
|
||||
x,
|
||||
expert_ids,
|
||||
gmm1_permuted_weight,
|
||||
gmm1_permuted_weight_scale,
|
||||
gmm2_weight,
|
||||
gmm2_weight_scale,
|
||||
expert_smooth_scales,
|
||||
expert_scales,
|
||||
//input attrs
|
||||
group_ep_ptr,
|
||||
ep_rank_size,
|
||||
ep_rank_id,
|
||||
moe_expert_num,
|
||||
shared_expert_num,
|
||||
shared_expert_rank_num,
|
||||
quant_mode,
|
||||
global_bs,
|
||||
// output tensors
|
||||
output,
|
||||
ep_recv_count);
|
||||
return {output, ep_recv_count};
|
||||
}
|
||||
|
||||
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
|
||||
c10::optional<c10::string_view> format_mode,
|
||||
c10::optional<c10::string_view> quant_mode)
|
||||
@@ -818,6 +876,19 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
|
||||
// Schema registration for the fused MoE decode op. All optional tensors and
// int attrs carry defaults so the op stays backward-compatible at the Python
// call site; returns (output, ep_recv_count).
ops.def(
|
||||
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
|
||||
" Tensor gmm1_permuted_weight_scale,"
|
||||
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
|
||||
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
|
||||
" str group_ep='',"
|
||||
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
|
||||
" int shared_expert_num=1, int shared_expert_rank_num=0,"
|
||||
" int quant_mode=0,"
|
||||
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
|
||||
);
|
||||
// Bind the schema to the PrivateUse1 (Ascend NPU) backend implementation.
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);
|
||||
|
||||
ops.def(
|
||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
|
||||
" Tensor group_list, *,"
|
||||
|
||||
Reference in New Issue
Block a user