[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)
### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when:
- `enable_expert_parallel=True`
- quantization is `w8a8_dynamic`
- `EP <= 16`
- `dynamic_eplb` is disabled
- `is_mtp_model = False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ class MoECommType(Enum):
|
|||||||
ALLGATHER = 0
|
ALLGATHER = 0
|
||||||
MC2 = 1
|
MC2 = 1
|
||||||
ALLTOALL = 2
|
ALLTOALL = 2
|
||||||
FUSED_ALLTOALL = 3
|
FUSED_MC2 = 3
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -62,11 +62,8 @@ def set_ascend_forward_context(
|
|||||||
|
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import \
|
from vllm_ascend.ops.fused_moe.moe_comm_method import \
|
||||||
get_moe_comm_method
|
get_moe_comm_method
|
||||||
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
|
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
|
||||||
# TODO: remove this after moe_comm_type selection logic is finalized
|
is_mtp_model)
|
||||||
if is_mtp_model:
|
|
||||||
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
|
|
||||||
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
|
|
||||||
forward_context.moe_comm_type = moe_comm_type
|
forward_context.moe_comm_type = moe_comm_type
|
||||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||||
|
|
||||||
@@ -210,28 +207,29 @@ def get_mc2_mask():
|
|||||||
|
|
||||||
|
|
||||||
def select_moe_comm_method(num_tokens: int,
|
def select_moe_comm_method(num_tokens: int,
|
||||||
vllm_config: VllmConfig) -> Optional[MoECommType]:
|
vllm_config: VllmConfig,
|
||||||
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
is_mtp_model=False) -> Optional[MoECommType]:
|
||||||
are designed for expert parallelism.
|
"""Select the MoE communication method according to parallel settings,
|
||||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
device generation, token count, and quantization.
|
||||||
number of tokens. This is based on the observation that all-gather is more
|
|
||||||
efficient than all-to-all when running on A2.
|
|
||||||
|
|
||||||
a. For A2, we choose from MC2 and all-gather.
|
1. Non-MoE models return `None`.
|
||||||
|
2. Without expert parallel, fall back to all-gather.
|
||||||
b. For A3, we choose from MC2 and all-to-all.
|
3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
|
||||||
|
and the DP size is large enough; otherwise use all-gather.
|
||||||
In both cases, we use MC2 when the number of tokens is smaller than
|
4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
|
||||||
a its capacity threshold.
|
quantization with small EP size, no dynamic_eplb, and not in MTP
|
||||||
|
mode; otherwise use MC2 within capacity or all-to-all.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_tokens (int): The number of tokens in the current batch.
|
num_tokens (int): The number of tokens in the current batch.
|
||||||
|
vllm_config (VllmConfig): Runtime configuration for the model.
|
||||||
|
is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the soc version is unsupported.
|
ValueError: If the soc version is unsupported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MoECommType: The selected MoE communication method.
|
MoECommType | None: The selected MoE communication method.
|
||||||
"""
|
"""
|
||||||
if not is_moe_model(vllm_config):
|
if not is_moe_model(vllm_config):
|
||||||
return None
|
return None
|
||||||
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||||
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
|
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
|
||||||
).world_size <= 16 and (not dynamic_eplb)
|
).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
|
||||||
moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
|
if num_tokens <= mc2_tokens_capacity:
|
||||||
else MoECommType.FUSED_ALLTOALL
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
|
||||||
if fused_all2all_enable else MoECommType.ALLTOALL)
|
else:
|
||||||
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
return moe_comm_type
|
return moe_comm_type
|
||||||
|
|||||||
@@ -132,6 +132,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to anbale dynamic EPLB
|
# Whether to anbale dynamic EPLB
|
||||||
"DYNAMIC_EPLB":
|
"DYNAMIC_EPLB":
|
||||||
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
|
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
|
||||||
|
# Whether to anbale fused mc2(dispatch_gmm_combine_decode/dispatch_ffn_combine operator)
|
||||||
|
"VLLM_ASCEND_ENABLE_FUSED_MC2":
|
||||||
|
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||||
and not shared_expert_dp_enabled():
|
and not shared_expert_dp_enabled():
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import torch
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||||
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
|
|||||||
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||||
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||||
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||||
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
|
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
|
||||||
moe_config)
|
|
||||||
|
|
||||||
|
|
||||||
class MoECommMethod(ABC):
|
class MoECommMethod(ABC):
|
||||||
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
|
|||||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||||
|
|
||||||
|
|
||||||
class FusedAlltoAllCommImpl(MoECommMethod):
|
class FusedMC2CommImpl(MoECommMethod):
|
||||||
"""This implementation is for the scenarios listed below:
|
"""This implementation is for the scenarios listed below:
|
||||||
1. `enable_expert_parallel=True`.
|
1. `enable_expert_parallel=True`.
|
||||||
2. `npu_grouped_matmul` is available.
|
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||||
|
3. `enable_expert_parallel=False` is not supported.
|
||||||
|
|
||||||
This implementation uses all-to-all communication to exchange tokens
|
This implementation uses the MC2 communication method, which is optimized for
|
||||||
between data parallel ranks before and after the MLP computation. It should
|
Communication and Computation parallelism on Ascend devices.
|
||||||
have better performance than AllGatherCommImpl when DP size > 1.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_token_dispatcher(self):
|
def _get_token_dispatcher(self):
|
||||||
return TokenDispatcherWithAll2AllV(
|
return TokenDispatcherWithMC2()
|
||||||
top_k=self.moe_config.experts_per_token,
|
|
||||||
num_experts=self.moe_config.num_experts,
|
|
||||||
num_local_experts=self.moe_config.num_local_experts)
|
|
||||||
|
|
||||||
def _get_prepare_finalize(self):
|
def _get_prepare_finalize(self):
|
||||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
return PrepareAndFinalizeWithMC2(self.moe_config)
|
||||||
|
|
||||||
def fused_experts(
|
def fused_experts(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor | list[torch.Tensor],
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor | list[torch.Tensor],
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
|
|||||||
use_int4_w4a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
global_num_experts: Optional[int] = None,
|
global_num_experts: Optional[int] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
w2_scale_bias: torch.Tensor = None,
|
w2_scale_bias: torch.Tensor = None,
|
||||||
w1_offset: Optional[torch.Tensor] = None,
|
w1_offset: Optional[torch.Tensor] = None,
|
||||||
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
|
|||||||
dynamic_eplb: bool = False,
|
dynamic_eplb: bool = False,
|
||||||
mc2_mask: torch.Tensor = None,
|
mc2_mask: torch.Tensor = None,
|
||||||
pertoken_scale: Optional[torch.Tensor] = None):
|
pertoken_scale: Optional[torch.Tensor] = None):
|
||||||
|
assert not (
|
||||||
|
w1_scale is None or w2_scale is None
|
||||||
|
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight1=w1,
|
weight1=w1[0],
|
||||||
weight2=w2,
|
weight2=w2[0],
|
||||||
expert_idx=topk_ids,
|
expert_idx=topk_ids,
|
||||||
scale1=w1_scale,
|
scale1=w1_scale[0],
|
||||||
scale2=w2_scale,
|
scale2=w2_scale[0],
|
||||||
probs=topk_weights.to(torch.float32),
|
probs=topk_weights.to(torch.float32),
|
||||||
group=self.token_dispatcher.moe_all_to_all_group_name,
|
group=self.token_dispatcher.moe_all_to_all_group_name,
|
||||||
max_output_size=65536,
|
max_output_size=65536,
|
||||||
out=out,
|
out=out,
|
||||||
)
|
)
|
||||||
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -130,7 +130,8 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
|||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(prefetch_stream):
|
||||||
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
|
torch_npu.npu_prefetch(
|
||||||
|
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
|
||||||
x_dependency, mlp_gate_up_prefetch_size)
|
x_dependency, mlp_gate_up_prefetch_size)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -185,7 +186,8 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
|||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(prefetch_stream):
|
||||||
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
|
torch_npu.npu_prefetch(
|
||||||
|
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
|
||||||
x_dependency, mlp_down_prefetch_size)
|
x_dependency, mlp_down_prefetch_size)
|
||||||
forward_context.layer_idx += 1
|
forward_context.layer_idx += 1
|
||||||
return
|
return
|
||||||
@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_type in {
|
if moe_comm_type in {
|
||||||
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
|
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
|
||||||
} or forward_context.sp_enabled:
|
} or forward_context.sp_enabled:
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
|
|||||||
from vllm.distributed import get_ep_group
|
from vllm.distributed import get_ep_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
w2 = [layer.w2_weight]
|
w2 = [layer.w2_weight]
|
||||||
w2_scale = [layer.w2_weight_scale]
|
w2_scale = [layer.w2_weight_scale]
|
||||||
|
|
||||||
fused_flag = get_forward_context(
|
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||||
).moe_comm_type == MoECommType.FUSED_ALLTOALL
|
== MoECommType.FUSED_MC2
|
||||||
|
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
|
||||||
return moe_comm_method.fused_experts(
|
return moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
pertoken_scale=pertoken_scale,
|
pertoken_scale=pertoken_scale,
|
||||||
w1=w1[0] if fused_flag else w1,
|
w1=w1,
|
||||||
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
|
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||||
w2=w2[0] if fused_flag else w2,
|
w2=w2,
|
||||||
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
|
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
|
|||||||
@@ -430,7 +430,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
|
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
|
||||||
# nodes. So here we check whether recompute_scheduler_enable is True.
|
# nodes. So here we check whether recompute_scheduler_enable is True.
|
||||||
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
|
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
|
||||||
potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
|
potential_max_num_tokens,
|
||||||
|
self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||||
|
|
||||||
def _sync_metadata_across_dp(
|
def _sync_metadata_across_dp(
|
||||||
self, num_tokens: int,
|
self, num_tokens: int,
|
||||||
@@ -2203,7 +2204,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||||
if self.max_num_tokens > mc2_tokens_capacity and \
|
if self.max_num_tokens > mc2_tokens_capacity and \
|
||||||
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
|
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
|
||||||
self._dummy_run(mc2_tokens_capacity,
|
self._dummy_run(mc2_tokens_capacity,
|
||||||
with_prefill=True,
|
with_prefill=True,
|
||||||
is_profile=True)
|
is_profile=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user