[Feature] EPLB: Adapt DispatchGmmCombineDecode operator to EPLB tensor lists and expert token numbers (#5552)

#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor lists and
expert token numbers.

The operator now accepts gmm1, gmm2, gmm1Scale, and gmm2Scale weights in list
format (one tensor per local expert).
It also counts how many tokens each local expert receives and reports the
result via expertTokensNum.
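A rough sketch of the two ideas (toy names and shapes, not the operator's real signature): per-expert weights travel as a Python list with one tensor per local expert, and the `expertTokensNum`-style output is a per-local-expert token count.

```python
import torch

# Toy sizes, for illustration only.
num_local_experts, hidden = 4, 8

# Weights kept as a list with one tensor per local expert (the same pattern the
# quantization code below uses: unbind the stacked tensor and clone each slice).
stacked_w2_scale = torch.randn(num_local_experts, hidden, 1)
w2_scale_list = [w.clone() for w in stacked_w2_scale.unbind(dim=0)]

# expertTokensNum-style output: how many tokens each local expert receives.
expert_ids = torch.tensor([0, 2, 2, 3, 0, 1])  # routed expert id per token
expert_tokens = torch.bincount(expert_ids, minlength=num_local_experts)
print(expert_tokens)  # tensor([2, 1, 2, 1])
```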


- vLLM version: v0.13.0
- vLLM main: 7157596103

For more information about this operator, please refer to the RFC issue:
https://github.com/vllm-project/vllm-ascend/issues/5476
Author: wangyibo1005 (committed by GitHub)
Date: 2026-01-07 11:23:42 +08:00
Parent: 086c093347
Commit: 25baf6df09
18 changed files with 425 additions and 195 deletions


@@ -242,15 +242,14 @@ def select_moe_comm_method(num_tokens: int,
     ascend_config = get_ascend_config()
     dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
     # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-    # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
     # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
-    fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
-        not dynamic_eplb)
+    fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
+    dispatch_ffn_combine_enable = get_ep_group().world_size <= 16 and (
+        not is_draft_model) and (not dynamic_eplb)
     if num_tokens <= mc2_tokens_capacity:
         fused_decode_enable = fused_mc2_enable
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
-            fused_decode_enable = fused_mc2_enable and get_ep_group(
-            ).world_size <= 16 and (not is_draft_model)
+            fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             fused_decode_enable = fused_mc2_enable and \
                 speculative_enable_dispatch_gmm_combine_decode(vllm_config)
@@ -258,8 +257,7 @@ def select_moe_comm_method(num_tokens: int,
     else:
         fused_prefill_enable = fused_mc2_enable
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
-            fused_prefill_enable = fused_mc2_enable and get_ep_group(
-            ).world_size <= 16 and (not is_draft_model)
+            fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             fused_prefill_enable = False
     moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL


@@ -131,7 +131,7 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb.
     # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
     # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
-    # with W8A8, non-dynamic-eplb. And MTP layer must be W8A8.
+    # with W8A8. And MTP layer must be W8A8.
     "VLLM_ASCEND_ENABLE_FUSED_MC2":
     lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
     # Whether to anbale balance scheduling
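For reference, a minimal usage sketch (assuming the variable is read once at startup via the getter above): the fused path is selected purely through this environment variable.

```python
import os

# Must be set before the engine reads envs_ascend. Values per the comment block
# above: 0 (default) keeps the plain MC2 path, 1 enables dispatch_ffn_combine,
# 2 enables dispatch_gmm_combine_decode on decode-node W8A8 MoE layers.
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "2"
```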


@@ -54,12 +54,15 @@ class VllmEplbAdaptor(EplbAdaptor):
                 self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
             self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
                 self.model.model.layers[i].mlp.experts.w2_weight_scale_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \
+                self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
         if self.model.quant_config is not None:
             self.expert_weight_names = [
                 "w13_weight_list", "w2_weight_list",
                 "w13_weight_scale_fp32_list", "w13_weight_offset",
-                "w2_weight_scale_list", "w2_weight_offset"
+                "w2_weight_scale_list", "w2_weight_offset",
+                "w2_weight_scale_fp32_list"
             ]
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
@@ -97,7 +100,8 @@ class VllmEplbAdaptor(EplbAdaptor):
                 self.num_dense_layers) + ".mlp.experts." + name
             if name in [
                     "w13_weight_list", "w2_weight_list",
-                    "w13_weight_scale_fp32_list", "w2_weight_scale_list"
+                    "w13_weight_scale_fp32_list", "w2_weight_scale_list",
+                    "w2_weight_scale_fp32_list"
             ]:
                 expert_tensor = self.param_dict[complete_name][0]
                 expert_tensor = expert_tensor.clone()
@@ -118,7 +122,7 @@ class VllmEplbAdaptor(EplbAdaptor):
                 if name in [
                         "w13_weight_list", "w2_weight_list",
                         "w13_weight_scale_fp32_list",
-                        "w2_weight_scale_list"
+                        "w2_weight_scale_list", "w2_weight_scale_fp32_list"
                 ]:
                     per_expert_param.append(
                         self.param_dict["model.layers." + str(layer_idx) +
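For context, the list-form entries touched above hold one tensor per local expert, so a single expert's buffer comes from indexing the list and cloning it. A toy stand-in for `param_dict` (made-up layer index and shapes):

```python
import torch

# Hypothetical stand-in for the adaptor's param_dict: a *_list entry maps to a
# Python list with one tensor per local expert rather than one stacked tensor.
param_dict = {
    "model.layers.3.mlp.experts.w2_weight_scale_fp32_list":
    [torch.randn(8, 1) for _ in range(4)],  # 4 local experts
}

name = "model.layers.3.mlp.experts.w2_weight_scale_fp32_list"
expert_tensor = param_dict[name][0].clone()  # buffer shaped like one expert
print(expert_tensor.shape)  # torch.Size([8, 1])
```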


@@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod):
         assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
             "token_dispatcher must be an instance of TokenDispatcherWithMC2."
+        group_list_type = None
+        expert_tokens = None
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             out = torch.empty_like(hidden_states)
             torch.ops._C_ascend.dispatch_ffn_combine(  # type: ignore
@@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod):
             )
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             assert expert_map is not None, "expert_map cannot be None."
-            out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(  # type: ignore
+            group_list_type = 1
+            out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode(  # type: ignore
                 x=hidden_states,
                 expert_ids=topk_ids,
-                gmm1_permuted_weight=w1[0],
-                gmm1_permuted_weight_scale=w1_scale[0],
-                gmm2_weight=w2[0],
-                gmm2_weight_scale=w2_scale[0],
+                gmm1_permuted_weight=w1,
+                gmm1_permuted_weight_scale=w1_scale,
+                gmm2_weight=w2,
+                gmm2_weight_scale=w2_scale,
                 expert_smooth_scales=None,
                 expert_scales=topk_weights.to(torch.float32),
                 group_ep=self.token_dispatcher.moe_all_to_all_group_name,
@@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod):
         else:
             raise ValueError(
                 f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
-        return FusedExpertsResult(routed_out=out)
+        return FusedExpertsResult(routed_out=out,
+                                  group_list_type=group_list_type,
+                                  expert_tokens=expert_tokens)
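The second return value is a per-local-expert token count, which is why `group_list_type` is set to 1 alongside it. A small sketch of how a count-form group list differs from a cumulative-sum one (assuming the common 0 = cumsum / 1 = count convention; the hunk above only shows that 1 goes with the returned counts):

```python
import torch

# expert_tokens as returned by the fused op: token count per local expert.
expert_tokens = torch.tensor([2, 1, 2, 1])               # count form (type 1)
group_list_cumsum = torch.cumsum(expert_tokens, dim=0)   # cumulative-sum form
print(group_list_cumsum)  # tensor([2, 3, 5, 6])
```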


@@ -254,7 +254,8 @@ class AscendW8A8DynamicFusedMoEMethod:
             w1 = layer.w13_weight_list
             w1_scale = layer.w13_weight_scale_fp32_list
             w2 = layer.w2_weight_list
-            w2_scale = layer.w2_weight_scale_list
+            w2_scale = layer.w2_weight_scale_fp32_list \
+                if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list
         else:
             w1 = [layer.w13_weight]
             w1_scale = [layer.w13_weight_scale_fp32]
@@ -333,11 +334,16 @@ class AscendW8A8DynamicFusedMoEMethod:
            weight.clone()
            for weight in layer.w2_weight_scale.data.unbind(dim=0)
        ]
+       layer.w2_weight_scale_fp32_list = [
+           weight.clone()
+           for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0)
+       ]
        del layer.w13_weight
        del layer.w2_weight
        del layer.w13_weight_scale
        del layer.w13_weight_scale_fp32
        del layer.w2_weight_scale
+       del layer.w2_weight_scale_fp32
        torch.npu.empty_cache()
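One note on the pattern above (a general PyTorch property, not something specific to this file): `unbind` returns views into the stacked tensor, so each slice is cloned before the stacked weight is deleted; otherwise the views would keep the original storage alive and the `del` plus `empty_cache()` could not actually release it.

```python
import torch

stacked = torch.randn(4, 8, 1)

views = list(stacked.unbind(dim=0))   # views share storage with `stacked`
copies = [v.clone() for v in views]   # independent per-expert buffers

del stacked, views                    # with clones, the big tensor can be freed
assert all(c.shape == (8, 1) for c in copies)
```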