[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)
#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor list and
expert token numbers.
This operator supports gmm1, gmm2, gmm1Scale and gmm2Scale passed in list
format.
This operator supports counting how many tokens each local expert receives
via expertTokensNum.
- vLLM version: v0.13.0
- vLLM main:
7157596103
More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
@@ -636,10 +636,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
|
||||
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
||||
const at::Tensor &x,
|
||||
const at::Tensor &expert_ids,
|
||||
const at::Tensor &gmm1_permuted_weight,
|
||||
const at::Tensor &gmm1_permuted_weight_scale,
|
||||
const at::Tensor &gmm2_weight,
|
||||
const at::Tensor &gmm2_weight_scale,
|
||||
const at::TensorList &gmm1_permuted_weight,
|
||||
const at::TensorList &gmm1_permuted_weight_scale,
|
||||
const at::TensorList &gmm2_weight,
|
||||
const at::TensorList &gmm2_weight_scale,
|
||||
const at::Tensor &expert_scales,
|
||||
const c10::optional<at::Tensor> &expert_smooth_scales,
|
||||
const c10::optional<at::Tensor> &x_active_mask,
|
||||
@@ -660,7 +660,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
||||
|
||||
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
|
||||
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
|
||||
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
|
||||
auto opts = expert_ids.options().dtype(at::kLong);
|
||||
at::Tensor expert_token_nums = at::empty({num_local_experts}, opts);
|
||||
|
||||
vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
@@ -689,8 +690,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
||||
global_bs,
|
||||
// output tensors
|
||||
output,
|
||||
ep_recv_count);
|
||||
return {output, ep_recv_count};
|
||||
expert_token_nums);
|
||||
return {output, expert_token_nums};
|
||||
}
|
||||
|
||||
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
|
||||
@@ -1287,16 +1288,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
|
||||
ops.def(
|
||||
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
|
||||
" Tensor gmm1_permuted_weight_scale,"
|
||||
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
|
||||
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor[] gmm1_permuted_weight,"
|
||||
" Tensor[] gmm1_permuted_weight_scale,"
|
||||
" Tensor[] gmm2_weight, Tensor[] gmm2_weight_scale,"
|
||||
" Tensor expert_scales, Tensor? expert_smooth_scales=None,"
|
||||
" Tensor? x_active_mask=None,"
|
||||
" str group_ep='',"
|
||||
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
|
||||
" int shared_expert_num=1, int shared_expert_rank_num=0,"
|
||||
" int quant_mode=0,"
|
||||
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
|
||||
" int global_bs=0) -> (Tensor output, Tensor expert_token_nums)"
|
||||
);
|
||||
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user