[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)

### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.

Plan / Roadmap

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes

* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main:
f7ad6a1eb3

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-12 21:10:20 +08:00
committed by GitHub
parent 1a70564e7c
commit 992271b027
7 changed files with 764 additions and 26 deletions

View File

@@ -43,6 +43,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
@@ -57,6 +58,62 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
def unified_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
moe_comm_method: Optional[MoECommMethod] = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert moe_comm_method is not None, "Missing communication context"
num_experts = w1.shape[0]
permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
mlp_output = apply_mlp(
permuted_hidden_states,
w1,
w2,
expert_tokens,
group_list_type=group_list_type,
)
torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)
return hidden_states
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
max_row_per_ep_rank: int, num_tokens: int,
top_k: int) -> tuple[torch.Tensor, torch.Tensor]: