Files
xc-llm-ascend/tests/ut/ops/test_moe_comm_method.py
Mercykid-bash 29e2f9a43e Bugfix: Align expert map shapes with redundant experts in EPLB adjustment (#5285)
#### Overview
This PR fixes a shape mismatch bug between `expert_placement_map` and
`log2phy_expert_map` when **redundant experts** are enabled in the
vLLM-Ascend platform. The issue occurred during the initialization of
expert maps and their updates via EPLB (Expert Parallelism Load Balancer)
adjustment, leading to potential tensor shape errors and incorrect
expert routing in distributed MoE deployments.

#### Key Changes
1. **Unify expert map shape calculation logic**
- Ensure the shape of `expert_placement_map` and `log2phy_expert_map`
strictly aligns with the total number of experts (including redundant
experts) during initialization.
- Update the shape adjustment logic in EPLB dynamic update process to
match the initial expert map dimensions.

2. **Add shape consistency checks**
- Add assertion statements to verify the shape consistency of the two
maps after initialization and EPLB adjustment, preventing silent shape
mismatches in subsequent operations.

#### Impact
- Resolves tensor shape errors when using redundant experts with EPLB on
the Ascend platform.
- Ensures correct expert routing and load balancing for MoE models with
redundant expert configurations.
- No breaking changes to existing functionality; compatible with
non-redundant expert deployments.

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
2026-01-06 17:22:36 +08:00

228 lines
9.6 KiB
Python

from unittest.mock import MagicMock, patch
import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)
class TestMoECommMethod(TestBase):
    """Unit tests for the MoE communication method implementations.

    Each test patches, inside ``moe_comm_method``, the prepare/finalize
    helper class and token-dispatcher class that the implementation under
    test instantiates, then checks that ``AllGatherCommImpl``,
    ``MC2CommImpl`` and ``AlltoAllCommImpl`` delegate to those mocks.
    """

    def setUp(self):
        super().setUp()
        # Spec'd mock of FusedMoEConfig describing a small single-rank setup:
        # 8 experts, 2 local experts, top-2 routing, dp/tp/ep sizes of 1 and
        # no redundant experts.
        self.moe_config = MagicMock(spec=FusedMoEConfig)
        self.moe_config.num_experts = 8
        self.moe_config.num_local_experts = 2
        self.moe_config.experts_per_token = 2
        self.moe_config.tp_group = MagicMock()
        self.moe_config.tp_group.device_group = MagicMock()
        self.moe_config.dp_size = 1
        self.moe_config.tp_size = 1
        self.moe_config.ep_size = 1
        self.moe_config.dp_group = MagicMock()
        self.moe_config.global_redundant_expert_num = 0

    # ------------------------------------------------------------------ #
    # Helpers (underscore-prefixed so unittest does not collect them)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _set_forward_context(mock_get_forward_context, comm_method):
        """Make the patched ``get_forward_context`` report ``comm_method``.

        Returns the mock context object that the patch now yields.
        """
        mock_context = MagicMock()
        mock_context.moe_comm_method = comm_method
        mock_get_forward_context.return_value = mock_context
        return mock_context

    @staticmethod
    def _install_prepare_finalize(mock_prepare_finalize, mc2_mask=None):
        """Install a prepare/finalize instance mock with fixed outputs.

        ``prepare`` returns the same 4-tuple that the tests unpack from
        ``comm_impl.prepare()``: ``(hidden_states, router_logits, mc2_mask,
        context_metadata)``. ``finalize`` returns a ``(4, 8)`` tensor.
        Returns the instance mock so callers can assert on its calls.
        """
        mock_pf_instance = MagicMock()
        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
                                                 torch.randn(4, 2), mc2_mask,
                                                 None)
        mock_pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = mock_pf_instance
        return mock_pf_instance

    @staticmethod
    def _install_token_dispatcher(mock_token_dispatcher):
        """Install a bare token-dispatcher instance mock and return it."""
        mock_td_instance = MagicMock()
        mock_token_dispatcher.return_value = mock_td_instance
        return mock_td_instance

    def _check_prepare_and_finalize(self, comm_impl, mock_pf_instance):
        """Assert ``prepare``/``finalize`` forward to the prepare/finalize mock.

        ``prepare`` is expected to pass its inputs plus the default flags
        ``(False, False, QuantType.NONE)``. ``finalize`` is expected to pass
        ``(hidden_states, reduce_results, context_metadata)`` positionally.
        """
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)
        h_out, _, _, context_metadata = comm_impl.prepare(
            hidden_states, router_logits)
        mock_pf_instance.prepare.assert_called_once_with(
            hidden_states, router_logits, False, False, QuantType.NONE)

        comm_impl.finalize(h_out,
                           reduce_results=True,
                           context_metadata=context_metadata)
        mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                  mock_prepare_finalize,
                                  mock_get_forward_context):
        """AllGatherCommImpl delegates prepare/finalize to its helper."""
        self._set_forward_context(mock_get_forward_context, "all_gather")
        mock_pf_instance = self._install_prepare_finalize(
            mock_prepare_finalize)
        self._install_token_dispatcher(mock_token_dispatcher)

        comm_impl = AllGatherCommImpl(self.moe_config)
        self._check_prepare_and_finalize(comm_impl, mock_pf_instance)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
    def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
                           mock_get_forward_context):
        """MC2CommImpl delegates prepare/finalize to its helper."""
        self._set_forward_context(mock_get_forward_context, "mc2")
        # MC2 is the only path whose prepare mock yields a non-None mask.
        mock_pf_instance = self._install_prepare_finalize(
            mock_prepare_finalize, mc2_mask=torch.tensor([1, 0, 1, 0]))
        self._install_token_dispatcher(mock_token_dispatcher)

        comm_impl = MC2CommImpl(self.moe_config)
        self._check_prepare_and_finalize(comm_impl, mock_pf_instance)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
    )
    def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                mock_prepare_finalize,
                                mock_get_forward_context):
        """AlltoAllCommImpl delegates prepare/finalize to its helper."""
        self._set_forward_context(mock_get_forward_context, "alltoall")
        mock_pf_instance = self._install_prepare_finalize(
            mock_prepare_finalize)
        self._install_token_dispatcher(mock_token_dispatcher)

        comm_impl = AlltoAllCommImpl(self.moe_config)
        # Also covers finalize, matching the all-gather and MC2 tests.
        self._check_prepare_and_finalize(comm_impl, mock_pf_instance)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
    def test_fused_experts_method(self, mock_unified_apply_mlp,
                                  mock_token_dispatcher, mock_prepare_finalize,
                                  mock_get_forward_context):
        """fused_experts runs dispatch -> MLP -> combine exactly once each."""
        self._set_forward_context(mock_get_forward_context, "all_gather")
        # Same 4-tuple prepare contract as the other tests.
        self._install_prepare_finalize(mock_prepare_finalize)

        mock_td_instance = self._install_token_dispatcher(
            mock_token_dispatcher)
        mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
            hidden_states=torch.randn(6, 8),
            group_list=torch.tensor([2, 2, 2]),
            group_list_type=1)
        mock_td_instance.token_combine.return_value = TokenCombineResult(
            routed_out=torch.randn(4, 8))
        mock_unified_apply_mlp.return_value = torch.randn(6, 8)

        comm_impl = AllGatherCommImpl(self.moe_config)

        # Inputs are made contiguous once; no second pass is needed.
        hidden_states = torch.randn(4, 8).contiguous()
        w1 = torch.randn(16, 8).contiguous()
        w2 = torch.randn(16, 8).contiguous()
        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                     [0.6, 0.4]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])

        result = comm_impl.fused_experts(hidden_states=hidden_states,
                                         w1=[w1],
                                         w2=[w2],
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         activation="silu")

        self.assertEqual(result.routed_out.shape, (4, 8))
        mock_td_instance.token_dispatch.assert_called_once()
        mock_unified_apply_mlp.assert_called_once()
        mock_td_instance.token_combine.assert_called_once()