[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)
#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor list and
expert token numbers.
This operator supports gmm1, gmm2, gmm1Scale and gmm2Scale in the format of a
list.
This operator supports counting how many tokens each local expert receives
via expertTokensNum.
- vLLM version: v0.13.0
- vLLM main:
7157596103
More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
@@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
|
||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||
group_list_type = None
|
||||
expert_tokens = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
out = torch.empty_like(hidden_states)
|
||||
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
||||
@@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
)
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
assert expert_map is not None, "expert_map cannot be None."
|
||||
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
group_list_type = 1
|
||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
x=hidden_states,
|
||||
expert_ids=topk_ids,
|
||||
gmm1_permuted_weight=w1[0],
|
||||
gmm1_permuted_weight_scale=w1_scale[0],
|
||||
gmm2_weight=w2[0],
|
||||
gmm2_weight_scale=w2_scale[0],
|
||||
gmm1_permuted_weight=w1,
|
||||
gmm1_permuted_weight_scale=w1_scale,
|
||||
gmm2_weight=w2,
|
||||
gmm2_weight_scale=w2_scale,
|
||||
expert_smooth_scales=None,
|
||||
expert_scales=topk_weights.to(torch.float32),
|
||||
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
@@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out)
|
||||
return FusedExpertsResult(routed_out=out,
|
||||
group_list_type=group_list_type,
|
||||
expert_tokens=expert_tokens)
|
||||
|
||||
Reference in New Issue
Block a user