[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)
### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 (gated by `VLLM_ASCEND_ENABLE_FUSED_MC2`; see the sketch below)
when:
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - EP world size is `<= 16`
  - `dynamic_eplb` is disabled
  - the model is not an MTP model (`is_mtp_model=False`)
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
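
For quick reference, here is a minimal sketch of the new A3 selection rule. It paraphrases the diff below rather than quoting the shipped code; `ep_world_size`, `mc2_tokens_capacity`, and `fused_mc2_env_enabled` are stand-ins for values the real selector reads from the EP group, the platform config, and `envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2`.

```python
from enum import Enum


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


def pick_a3_comm(num_tokens: int, mc2_tokens_capacity: int, quant_type: str,
                 ep_world_size: int, dynamic_eplb: bool, is_mtp_model: bool,
                 fused_mc2_env_enabled: bool) -> MoECommType:
    """Paraphrase of the A3 branch of select_moe_comm_method in this PR."""
    fused_mc2_enable = (fused_mc2_env_enabled and quant_type == "w8a8_dynamic"
                        and ep_world_size <= 16 and not dynamic_eplb
                        and not is_mtp_model)
    if fused_mc2_enable:
        # Fused MC2 (the dispatch_ffn_combine path) is used regardless of
        # where the token count sits relative to the MC2 capacity threshold.
        return MoECommType.FUSED_MC2
    # Otherwise keep the pre-existing split: MC2 within capacity,
    # all-to-all beyond it.
    return (MoECommType.MC2
            if num_tokens <= mc2_tokens_capacity else MoECommType.ALLTOALL)
```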
```diff
@@ -26,7 +26,7 @@ class MoECommType(Enum):
     ALLGATHER = 0
     MC2 = 1
     ALLTOALL = 2
-    FUSED_ALLTOALL = 3
+    FUSED_MC2 = 3
 
 
 @contextmanager
```
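
Worth noting: `FUSED_MC2` reuses the old member's value 3, so the rename is name-only and anything that compares or logs the numeric value is unaffected. A quick check against the `MoECommType` sketch above:

```python
# The renamed member keeps value 3, so numeric round-trips are unchanged.
assert MoECommType.FUSED_MC2.value == 3
assert MoECommType(3) is MoECommType.FUSED_MC2
```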
```diff
@@ -62,11 +62,8 @@ def set_ascend_forward_context(
 
         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method
-        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        if is_mtp_model:
-            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
+                                               is_mtp_model)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
 
```
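
Threading `is_mtp_model` into the selector removes the post-hoc downgrade without changing MTP behavior: within the MC2 capacity the old path already chose `MC2` (so the downgrade never fired), and beyond it the old `FUSED_ALLTOALL` pick was downgraded to `ALLTOALL`, which is exactly what the selector now returns directly. A sanity check using the `pick_a3_comm` sketch above:

```python
# MTP disables fused MC2 inside the selector; outcomes match the old
# select-then-downgrade behavior in both capacity regimes.
kwargs = dict(quant_type="w8a8_dynamic", ep_world_size=8, dynamic_eplb=False,
              is_mtp_model=True, fused_mc2_env_enabled=True)
assert pick_a3_comm(256, 512, **kwargs) is MoECommType.MC2        # within capacity
assert pick_a3_comm(1024, 512, **kwargs) is MoECommType.ALLTOALL  # beyond capacity
```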
```diff
@@ -93,7 +90,7 @@ def set_ascend_forward_context(
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
     forward_context.sp_enabled = sp_enabled
-    #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+    # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
     forward_context.flashcomm_v2_enabled = flashcomm2_enable(
     ) and tp_world_size > 1 and num_tokens is not None
 
```
```diff
@@ -210,29 +207,30 @@ def get_mc2_mask():
 
 
 def select_moe_comm_method(num_tokens: int,
-                           vllm_config: VllmConfig) -> Optional[MoECommType]:
-    """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
-    are designed for expert parallelism.
-    2. If expert parallel is enabled, we need to consider the soc version and the
-    number of tokens. This is based on the observation that all-gather is more
-    efficient than all-to-all when running on A2.
-
-    a. For A2, we choose from MC2 and all-gather.
-
-    b. For A3, we choose from MC2 and all-to-all.
-
-    In both cases, we use MC2 when the number of tokens is smaller than
-    a its capacity threshold.
-
-    Args:
-        num_tokens (int): The number of tokens in the current batch.
-
-    Raises:
-        ValueError: If the soc version is unsupported.
-
-    Returns:
-        MoECommType: The selected MoE communication method.
-    """
+                           vllm_config: VllmConfig,
+                           is_mtp_model=False) -> Optional[MoECommType]:
+    """Select the MoE communication method according to parallel settings,
+    device generation, token count, and quantization.
+
+    1. Non-MoE models return `None`.
+    2. Without expert parallel, fall back to all-gather.
+    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
+       and the DP size is large enough; otherwise use all-gather.
+    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
+       quantization with small EP size, no dynamic_eplb, and not in MTP
+       mode; otherwise use MC2 within capacity or all-to-all.
+
+    Args:
+        num_tokens (int): The number of tokens in the current batch.
+        vllm_config (VllmConfig): Runtime configuration for the model.
+        is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).
+
+    Raises:
+        ValueError: If the soc version is unsupported.
+
+    Returns:
+        MoECommType | None: The selected MoE communication method.
+    """
     if not is_moe_model(vllm_config):
         return None
     mc2_tokens_capacity = get_mc2_tokens_capacity()
```
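
A hypothetical call matching the updated signature (`vllm_config` comes from the engine and is elided here):

```python
# Hypothetical usage of the new three-argument signature.
moe_comm_type = select_moe_comm_method(num_tokens=512,
                                       vllm_config=vllm_config,
                                       is_mtp_model=False)
if moe_comm_type is None:
    pass  # dense (non-MoE) model: no MoE comm method is installed
```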
```diff
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
         ascend_config = get_ascend_config()
         dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-        fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
-        ).world_size <= 16 and (not dynamic_eplb)
-        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
-                         else MoECommType.FUSED_ALLTOALL
-                         if fused_all2all_enable else MoECommType.ALLTOALL)
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
+        ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
+        if num_tokens <= mc2_tokens_capacity:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
+        else:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
+
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
     return moe_comm_type
```
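
The A3 branch now reduces to the following decision table, with the env gate, `w8a8_dynamic` quantization, `EP <= 16`, `dynamic_eplb`, and MTP checks folded into "fused MC2 eligible". Note the behavior change: previously `FUSED_ALLTOALL` was only chosen beyond the capacity threshold, while `FUSED_MC2`, when eligible, is now used in both regimes.

| fused MC2 eligible | `num_tokens <= capacity` | `num_tokens > capacity` |
|---|---|---|
| yes | `FUSED_MC2` | `FUSED_MC2` |
| no | `MC2` | `ALLTOALL` |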