[2/N][Feat] Add MC2 communication method for MoE layers (#2469)

### What this PR does / why we need it?
This method replaces the previous all-gather approach for small numbers
of tokens.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test case fixed.


- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-26 19:05:23 +08:00
committed by GitHub
parent 5d8ec28009
commit a6bb502e70
11 changed files with 506 additions and 410 deletions

View File

@@ -15,22 +15,84 @@
# limitations under the License.
#
from typing import Callable, Optional
from typing import Any, Callable, Optional
import torch
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MC2CommImpl,
MoECommMethod)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
moe_comm_method: Optional[MoECommMethod] = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert moe_comm_method is not None, "Missing communication context"
num_experts = w1.shape[0]
permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
mlp_output = apply_mlp(
permuted_hidden_states,
w1,
w2,
expert_tokens,
group_list_type=group_list_type,
)
moe_comm_method.unpermute(mlp_output, hidden_states)
return hidden_states
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
vllm_config = get_current_vllm_config()
@@ -97,7 +159,7 @@ def forward_oot(
moe_comm_method = get_forward_context().moe_comm_method
return unified_fused_experts(
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -109,5 +171,112 @@ def forward_oot(
)
class AscendFusedMoE(FusedMoE):
def __init__(
self,
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype=None,
reduce_results=False,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=None,
tp_size=None,
ep_size=None,
dp_size=None,
prefix="",
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
moe_comm_method_name = "allgathercommimpl"
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
reduce_results=self.reduce_results)
return final_hidden_states
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.forward_oot = forward_oot