[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)
### What this PR does / why we need it?
In this PR, `DispatchGmmCombineDecode` gains an optional input
`x_active_mask`: only tokens whose mask entry is True are dispatched and
processed.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
```diff
@@ -640,8 +640,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
     const at::Tensor &gmm1_permuted_weight_scale,
     const at::Tensor &gmm2_weight,
     const at::Tensor &gmm2_weight_scale,
-    const at::Tensor &expert_scales,
     const c10::optional<at::Tensor> &expert_smooth_scales,
+    const c10::optional<at::Tensor> &expert_scales,
+    const c10::optional<at::Tensor> &x_active_mask,
     c10::string_view group_ep,
     int64_t ep_rank_size,
     int64_t ep_rank_id,
```
```diff
@@ -674,8 +675,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
         gmm1_permuted_weight_scale,
         gmm2_weight,
         gmm2_weight_scale,
-        expert_smooth_scales,
         expert_scales,
+        expert_smooth_scales,
+        x_active_mask,
         //input attrs
         group_ep_ptr,
         ep_rank_size,
```
```diff
@@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
         " Tensor gmm1_permuted_weight_scale,"
         " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
-        " Tensor expert_scales, Tensor? expert_smooth_scales=None,"
+        " Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
+        " Tensor? x_active_mask=None,"
         " str group_ep='',"
         " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
         " int shared_expert_num=1, int shared_expert_rank_num=0,"
```
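The `x_active_mask` contract described above can be sketched in plain Python. This is only an illustration of the masking semantics (the real filtering happens inside the NPU kernel), and `dispatch_active_tokens` is a hypothetical helper, not part of this PR:

```python
def dispatch_active_tokens(tokens, x_active_mask=None):
    """Return the (index, token) pairs that would be dispatched.

    When x_active_mask is None, every token is dispatched (the previous
    behaviour); otherwise only positions where the mask is True are kept,
    and masked-out tokens are skipped entirely.
    """
    if x_active_mask is None:
        return list(enumerate(tokens))
    return [(i, t)
            for i, (t, keep) in enumerate(zip(tokens, x_active_mask))
            if keep]


tokens = ["t0", "t1", "t2", "t3"]
# No mask: all four tokens are dispatched.
print(dispatch_active_tokens(tokens))
# Mask supplied: only t0 and t2 are dispatched.
print(dispatch_active_tokens(tokens, [True, False, True, False]))
```

Because the argument defaults to `None` (matching the `Tensor? x_active_mask=None` schema entry), existing callers that do not pass a mask keep the old all-tokens behaviour.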