[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)
#### What this PR does / why we need it?
This PR adapt DispatchGmmCombineDecode operator to eplb tensor list and
expert token numbers.
This operator support gmm1, gmm2, gmm1Scale and gmm2Scale in format of
list.
This operator support couting how many token each local expert recieves
by expertTokensNum .
- vLLM version: v0.13.0
- vLLM main:
7157596103
More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
@@ -242,15 +242,14 @@ def select_moe_comm_method(num_tokens: int,
|
||||
ascend_config = get_ascend_config()
|
||||
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
|
||||
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
|
||||
not dynamic_eplb)
|
||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 16 and (
|
||||
not is_draft_model) and (not dynamic_eplb)
|
||||
if num_tokens <= mc2_tokens_capacity:
|
||||
fused_decode_enable = fused_mc2_enable
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
fused_decode_enable = fused_mc2_enable and get_ep_group(
|
||||
).world_size <= 16 and (not is_draft_model)
|
||||
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
fused_decode_enable = fused_mc2_enable and \
|
||||
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
||||
@@ -258,8 +257,7 @@ def select_moe_comm_method(num_tokens: int,
|
||||
else:
|
||||
fused_prefill_enable = fused_mc2_enable
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
fused_prefill_enable = fused_mc2_enable and get_ep_group(
|
||||
).world_size <= 16 and (not is_draft_model)
|
||||
fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
fused_prefill_enable = False
|
||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
|
||||
|
||||
@@ -131,7 +131,7 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb.
|
||||
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
|
||||
# `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
|
||||
# with W8A8, non-dynamic-eplb. And MTP layer must be W8A8.
|
||||
# with W8A8. And MTP layer must be W8A8.
|
||||
"VLLM_ASCEND_ENABLE_FUSED_MC2":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
|
||||
# Whether to anbale balance scheduling
|
||||
|
||||
@@ -54,12 +54,15 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
|
||||
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \
|
||||
self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list
|
||||
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
||||
if self.model.quant_config is not None:
|
||||
self.expert_weight_names = [
|
||||
"w13_weight_list", "w2_weight_list",
|
||||
"w13_weight_scale_fp32_list", "w13_weight_offset",
|
||||
"w2_weight_scale_list", "w2_weight_offset"
|
||||
"w2_weight_scale_list", "w2_weight_offset",
|
||||
"w2_weight_scale_fp32_list"
|
||||
]
|
||||
else:
|
||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||
@@ -97,7 +100,8 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
self.num_dense_layers) + ".mlp.experts." + name
|
||||
if name in [
|
||||
"w13_weight_list", "w2_weight_list",
|
||||
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
|
||||
"w13_weight_scale_fp32_list", "w2_weight_scale_list",
|
||||
"w2_weight_scale_fp32_list"
|
||||
]:
|
||||
expert_tensor = self.param_dict[complete_name][0]
|
||||
expert_tensor = expert_tensor.clone()
|
||||
@@ -118,7 +122,7 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
if name in [
|
||||
"w13_weight_list", "w2_weight_list",
|
||||
"w13_weight_scale_fp32_list",
|
||||
"w2_weight_scale_list"
|
||||
"w2_weight_scale_list", "w2_weight_scale_fp32_list"
|
||||
]:
|
||||
per_expert_param.append(
|
||||
self.param_dict["model.layers." + str(layer_idx) +
|
||||
|
||||
@@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
|
||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||
group_list_type = None
|
||||
expert_tokens = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
out = torch.empty_like(hidden_states)
|
||||
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
||||
@@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
)
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
assert expert_map is not None, "expert_map cannot be None."
|
||||
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
group_list_type = 1
|
||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
x=hidden_states,
|
||||
expert_ids=topk_ids,
|
||||
gmm1_permuted_weight=w1[0],
|
||||
gmm1_permuted_weight_scale=w1_scale[0],
|
||||
gmm2_weight=w2[0],
|
||||
gmm2_weight_scale=w2_scale[0],
|
||||
gmm1_permuted_weight=w1,
|
||||
gmm1_permuted_weight_scale=w1_scale,
|
||||
gmm2_weight=w2,
|
||||
gmm2_weight_scale=w2_scale,
|
||||
expert_smooth_scales=None,
|
||||
expert_scales=topk_weights.to(torch.float32),
|
||||
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
@@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out)
|
||||
return FusedExpertsResult(routed_out=out,
|
||||
group_list_type=group_list_type,
|
||||
expert_tokens=expert_tokens)
|
||||
|
||||
@@ -254,7 +254,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.w2_weight_scale_list
|
||||
w2_scale = layer.w2_weight_scale_fp32_list \
|
||||
if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.w13_weight_scale_fp32]
|
||||
@@ -333,11 +334,16 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
weight.clone()
|
||||
for weight in layer.w2_weight_scale.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_scale_fp32_list = [
|
||||
weight.clone()
|
||||
for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0)
|
||||
]
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
del layer.w2_weight_scale_fp32
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user