[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)
### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when all of the following hold:
  - `VLLM_ASCEND_ENABLE_FUSED_MC2` is set
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model=False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
@@ -26,7 +26,7 @@ class MoECommType(Enum):
     ALLGATHER = 0
     MC2 = 1
     ALLTOALL = 2
-    FUSED_ALLTOALL = 3
+    FUSED_MC2 = 3


 @contextmanager
@@ -62,11 +62,8 @@ def set_ascend_forward_context(

         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method
-        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        if is_mtp_model:
-            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
+                                               is_mtp_model)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

@@ -93,7 +90,7 @@ def set_ascend_forward_context(
         forward_context.mmrs_fusion = mmrs_fusion
         forward_context.num_tokens = num_tokens
         forward_context.sp_enabled = sp_enabled
-        #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
         forward_context.flashcomm_v2_enabled = flashcomm2_enable(
         ) and tp_world_size > 1 and num_tokens is not None

@@ -210,29 +207,30 @@ def get_mc2_mask():


 def select_moe_comm_method(num_tokens: int,
-                           vllm_config: VllmConfig) -> Optional[MoECommType]:
-    """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
-    are designed for expert parallelism.
-    2. If expert parallel is enabled, we need to consider the soc version and the
-    number of tokens. This is based on the observation that all-gather is more
-    efficient than all-to-all when running on A2.
-
-    a. For A2, we choose from MC2 and all-gather.
-
-    b. For A3, we choose from MC2 and all-to-all.
-
-    In both cases, we use MC2 when the number of tokens is smaller than
-    a its capacity threshold.
-
-    Args:
-        num_tokens (int): The number of tokens in the current batch.
-
-    Raises:
-        ValueError: If the soc version is unsupported.
-
-    Returns:
-        MoECommType: The selected MoE communication method.
-    """
+                           vllm_config: VllmConfig,
+                           is_mtp_model=False) -> Optional[MoECommType]:
+    """Select the MoE communication method according to parallel settings,
+    device generation, token count, and quantization.
+
+    1. Non-MoE models return `None`.
+    2. Without expert parallel, fall back to all-gather.
+    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
+       and the DP size is large enough; otherwise use all-gather.
+    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
+       quantization with small EP size, no dynamic_eplb, and not in MTP
+       mode; otherwise use MC2 within capacity or all-to-all.
+
+    Args:
+        num_tokens (int): The number of tokens in the current batch.
+        vllm_config (VllmConfig): Runtime configuration for the model.
+        is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).
+
+    Raises:
+        ValueError: If the soc version is unsupported.
+
+    Returns:
+        MoECommType | None: The selected MoE communication method.
+    """
     if not is_moe_model(vllm_config):
         return None
     mc2_tokens_capacity = get_mc2_tokens_capacity()
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
         ascend_config = get_ascend_config()
         dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
-        fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
-        ).world_size <= 16 and (not dynamic_eplb)
-        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
-                         else MoECommType.FUSED_ALLTOALL
-                         if fused_all2all_enable else MoECommType.ALLTOALL)
+        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
+        ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
+        if num_tokens <= mc2_tokens_capacity:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
+        else:
+            moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
+
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
     return moe_comm_type
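For reference, the new A3 selection rule can be summarized in a small, dependency-free sketch. The helper below is illustrative only: the real code reads these values from `vllm_config`, `get_ep_group()`, and `get_mc2_tokens_capacity()`.

```python
from enum import Enum


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


def select_a3(num_tokens: int, mc2_tokens_capacity: int, *,
              enable_fused_mc2: bool, quant_type: str, ep_world_size: int,
              dynamic_eplb: bool, is_mtp_model: bool) -> MoECommType:
    # All gates from the diff must hold for the fused path.
    fused_mc2_enable = (enable_fused_mc2 and quant_type == "w8a8_dynamic"
                        and ep_world_size <= 16 and not dynamic_eplb
                        and not is_mtp_model)
    if fused_mc2_enable:
        # FUSED_MC2 is chosen in both token-count branches of the diff.
        return MoECommType.FUSED_MC2
    return (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
            else MoECommType.ALLTOALL)


# FUSED_MC2 wins regardless of token count once all gates pass:
assert select_a3(1 << 20, 512, enable_fused_mc2=True, quant_type="w8a8_dynamic",
                 ep_world_size=16, dynamic_eplb=False,
                 is_mtp_model=False) is MoECommType.FUSED_MC2
```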
@@ -132,6 +132,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # Whether to enable dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to enable fused MC2 (dispatch_gmm_combine_decode/dispatch_ffn_combine operator)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
 }

 # end-env-vars-definition
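A hedged reading of the flag's accepted values, inferred from the `fused_experts` branch later in this diff:

```python
import os

#   0 (default) -> fused MC2 off; FUSED_MC2 is never selected.
#   1           -> fused path uses the dispatch_ffn_combine operator.
#   2           -> reserved (dispatch_gmm_combine_decode per the comment above);
#                  currently raises NotImplementedError in fused_experts.
VLLM_ASCEND_ENABLE_FUSED_MC2 = int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0"))
```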
@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
            and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         else:
@@ -22,6 +22,7 @@ import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
-    _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
-        moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


 class MoECommMethod(ABC):
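For orientation, a minimal sketch of the contract behind `MoECommMethod` that each registered impl (including the renamed `FusedMC2CommImpl` below) fulfills. The method names come from this diff; the simplified signatures are assumptions.

```python
from abc import ABC, abstractmethod


class MoECommMethod(ABC):

    @abstractmethod
    def _get_token_dispatcher(self):
        """Return the object that routes tokens to expert ranks."""

    @abstractmethod
    def _get_prepare_finalize(self):
        """Return the pre/post-communication strategy for this comm type."""

    @abstractmethod
    def fused_experts(self, hidden_states, w1, w2, topk_weights, topk_ids,
                      **kwargs):
        """Run dispatch -> expert FFN -> combine and return the output."""
```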
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
         return PrepareAndFinalizeWithAll2All(self.moe_config)


-class FusedAlltoAllCommImpl(MoECommMethod):
+class FusedMC2CommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=True`.
-    2. `npu_grouped_matmul` is available.
-
-    This implementation uses all-to-all communication to exchange tokens
-    between data parallel ranks before and after the MLP computation. It should
-    have better performance than AllGatherCommImpl when DP size > 1.
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
+    3. `enable_expert_parallel=False` is not supported.
+
+    This implementation uses the MC2 communication method, which is optimized for
+    communication and computation parallelism on Ascend devices.
     """

     def _get_token_dispatcher(self):
-        return TokenDispatcherWithAll2AllV(
-            top_k=self.moe_config.experts_per_token,
-            num_experts=self.moe_config.num_experts,
-            num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithMC2()

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config)

     def fused_experts(
             self,
             hidden_states: torch.Tensor,
-            w1: torch.Tensor,
-            w2: torch.Tensor,
+            w1: torch.Tensor | list[torch.Tensor],
+            w2: torch.Tensor | list[torch.Tensor],
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             activation: str = "silu",
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             use_int4_w4a16: bool = False,
             global_num_experts: Optional[int] = None,
             expert_map: Optional[torch.Tensor] = None,
-            w1_scale: Optional[torch.Tensor] = None,
-            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale: Optional[list[torch.Tensor]] = None,
+            w2_scale: Optional[list[torch.Tensor]] = None,
             w1_scale_bias: torch.Tensor = None,
             w2_scale_bias: torch.Tensor = None,
             w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             dynamic_eplb: bool = False,
             mc2_mask: torch.Tensor = None,
             pertoken_scale: Optional[torch.Tensor] = None):
+        assert not (
+            w1_scale is None or w2_scale is None
+        ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
         out = torch.empty_like(hidden_states)

-        torch.ops._C_ascend.dispatch_ffn_combine(
-            x=hidden_states,
-            weight1=w1,
-            weight2=w2,
-            expert_idx=topk_ids,
-            scale1=w1_scale,
-            scale2=w2_scale,
-            probs=topk_weights.to(torch.float32),
-            group=self.token_dispatcher.moe_all_to_all_group_name,
-            max_output_size=65536,
-            out=out,
-        )
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            torch.ops._C_ascend.dispatch_ffn_combine(
+                x=hidden_states,
+                weight1=w1[0],
+                weight2=w2[0],
+                expert_idx=topk_ids,
+                scale1=w1_scale[0],
+                scale2=w2_scale[0],
+                probs=topk_weights.to(torch.float32),
+                group=self.token_dispatcher.moe_all_to_all_group_name,
+                max_output_size=65536,
+                out=out,
+            )
+        elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+            raise NotImplementedError()
+        else:
+            raise ValueError(
+                f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
         return out
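To make the fused op easier to review, here is a small, hedged pure-PyTorch reference of what `dispatch_ffn_combine` collapses into a single kernel: dispatch tokens to experts by `expert_idx`, run the two-layer FFN per expert, and combine the results weighted by `probs`. It is single-rank and unquantized (the real op also folds in the w8a8 scales and the cross-rank MC2 communication), and all shapes are illustrative.

```python
import torch


def dispatch_ffn_combine_reference(x, w1, w2, expert_idx, probs):
    """Unfused, single-rank reference: dispatch -> per-expert FFN -> combine."""
    # x: (num_tokens, hidden); w1: (E, hidden, inter); w2: (E, inter, hidden)
    # expert_idx, probs: (num_tokens, top_k)
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        token_ids, slot = (expert_idx == e).nonzero(as_tuple=True)  # dispatch
        if token_ids.numel() == 0:
            continue
        h = torch.nn.functional.silu(x[token_ids] @ w1[e]) @ w2[e]  # ffn
        out.index_add_(0, token_ids, h * probs[token_ids, slot, None])  # combine
    return out


x = torch.randn(4, 8)
w1, w2 = torch.randn(2, 8, 16), torch.randn(2, 16, 8)
expert_idx = torch.randint(0, 2, (4, 2))
probs = torch.softmax(torch.randn(4, 2), dim=-1)
assert dispatch_ffn_combine_reference(x, w1, w2, expert_idx, probs).shape == x.shape
```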
@@ -130,8 +130,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,

     with torch.npu.stream(prefetch_stream):
         mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
-                            x_dependency, mlp_gate_up_prefetch_size)
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
+            x_dependency, mlp_gate_up_prefetch_size)
     return

@@ -185,8 +186,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:

     with torch.npu.stream(prefetch_stream):
         mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
-        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
-                            x_dependency, mlp_down_prefetch_size)
+        torch_npu.npu_prefetch(
+            model_instance.model.layers[layer_idx].mlp.down_proj.weight,
+            x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return
@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:
@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
         w2 = [layer.w2_weight]
         w2_scale = [layer.w2_weight_scale]

-        fused_flag = get_forward_context(
-        ).moe_comm_type == MoECommType.FUSED_ALLTOALL
+        fused_scale_flag = (get_forward_context().moe_comm_type
+                            == MoECommType.FUSED_MC2
+                            and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
         return moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
-            w1=w1[0] if fused_flag else w1,
-            w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
-            w2=w2[0] if fused_flag else w2,
-            w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
+            w1=w1,
+            w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
+            w2=w2,
+            w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
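As a hedged illustration of the new call-site convention: weights and scales are now always passed as single-element lists (`FusedMC2CommImpl` indexes `[0]` itself), and only the scales swap to the pre-fused variants when the fused path is active. The tensors below are dummies standing in for the layer attributes from the hunk above.

```python
import torch

# Stand-ins for layer.w2_weight / layer.w2_weight_scale / layer.fused_w2_scale.
w2_weight = torch.empty(0)
w2_weight_scale = torch.empty(0)
fused_w2_scale = torch.empty(0)

fused_scale_flag = True  # moe_comm_type is FUSED_MC2 and env flag == 1
w2 = [w2_weight]  # always a single-element list now
w2_scale = [fused_w2_scale] if fused_scale_flag else [w2_weight_scale]
```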
@@ -430,7 +430,8 @@ class NPUModelRunner(GPUModelRunner):
         # moe_comm_method of each rank is MC2 and recomputation would never happen in D
         # nodes. So here we check whether recompute_scheduler_enable is True.
         return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
-            potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
+            potential_max_num_tokens,
+            self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}

     def _sync_metadata_across_dp(
             self, num_tokens: int,
@@ -1058,7 +1059,7 @@ class NPUModelRunner(GPUModelRunner):
         # (num_reqs_d + num_reqs_p, max_num_blocks),
         # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
         # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \
+        ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
             self.query_start_loc_pcp_full.cpu[:num_reqs]
         num_prefill_reqs = (ori_query_lens
                             > self.decode_threshold).sum().item()
@@ -2203,7 +2204,7 @@ class NPUModelRunner(GPUModelRunner):
     def profile_run(self) -> None:
         mc2_tokens_capacity = get_mc2_tokens_capacity()
         if self.max_num_tokens > mc2_tokens_capacity and \
-            select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
+            select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
             self._dummy_run(mc2_tokens_capacity,
                             with_prefill=True,
                             is_profile=True)