[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
This method replaces the previous all-gather approach for small numbers
of tokens.
The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test case fixed.
- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -43,7 +43,6 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
@@ -58,60 +57,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def unified_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
moe_comm_method: Optional[MoECommMethod] = None,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints
|
||||
assert hidden_states.shape[1] == w1.shape[2], (
|
||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
|
||||
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
|
||||
mlp_output = apply_mlp(
|
||||
permuted_hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type,
|
||||
)
|
||||
torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
max_row_per_ep_rank: int, num_tokens: int,
|
||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user