#### Overview
This PR fixes a shape mismatch bug between `expert_placement_map` and
`log2phy_expert_map` when **redundant experts** are enabled on the
vLLM-Ascend platform. The mismatch could arise both when the expert maps
are initialized and when they are updated by an EPLB (Expert Load
Balancer) adjustment, leading to tensor shape errors and incorrect
expert routing in distributed MoE deployments.
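As a rough illustration of the failure mode (the numbers and names below
are assumptions for the example, not the actual vllm_ascend code paths):
a map sized by the logical expert count alone cannot absorb placements
that span the full physical expert count once redundant copies exist.

```python
import torch

# Hypothetical illustration: 8 logical experts plus 2 redundant copies.
num_logical, num_redundant = 8, 2
total_physical = num_logical + num_redundant      # 10 physical slots

# A map sized by the logical count alone (the buggy shape) ...
log2phy = torch.zeros(num_logical, dtype=torch.int64)
# ... while EPLB produces a placement indexed over every physical slot.
placement = torch.arange(total_physical)

try:
    log2phy.copy_(placement)                      # 8 vs. 10 elements
except RuntimeError as err:
    print(f"shape mismatch: {err}")
```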
#### Key Changes
1. **Unify expert map shape calculation logic**
- Ensure the shapes of `expert_placement_map` and `log2phy_expert_map`
strictly align with the total number of experts (including redundant
experts) during initialization.
- Update the shape adjustment logic in the EPLB dynamic update process to
match the initial expert map dimensions (see the sketch after this list).
2. **Add shape consistency checks**
- Add assertion statements to verify the shape consistency of the two
maps after initialization and EPLB adjustment, preventing silent shape
mismatches in subsequent operations.
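Below is a minimal sketch of the sizing rule and the consistency check,
using the map names from this PR but otherwise assumed, simplified tensor
layouts (the real vllm_ascend data structures may differ):

```python
import torch


def init_expert_maps(num_logical_experts: int, num_redundant_experts: int):
    # Size everything against the *total* physical expert count
    # (logical + redundant), which is the rule this PR unifies.
    total_physical_experts = num_logical_experts + num_redundant_experts

    # Physical slot -> logical expert id; redundant slots repeat an id.
    expert_placement_map = torch.arange(
        total_physical_experts, dtype=torch.int32) % num_logical_experts

    # Logical expert id -> physical slots serving it, padded with -1.
    max_copies = -(-total_physical_experts // num_logical_experts)  # ceil div
    log2phy_expert_map = torch.full(
        (num_logical_experts, max_copies), -1, dtype=torch.int32)
    for phy, log in enumerate(expert_placement_map.tolist()):
        col = int((log2phy_expert_map[log] >= 0).sum())
        log2phy_expert_map[log, col] = phy

    # Consistency checks in the spirit of the assertions added here:
    # both maps must describe the same number of physical experts.
    assert expert_placement_map.numel() == total_physical_experts
    assert int((log2phy_expert_map >= 0).sum()) == total_physical_experts
    return expert_placement_map, log2phy_expert_map
```

With 8 logical experts and 2 redundant copies, for example, both maps end
up describing 10 physical slots, and the same assertions would catch an
EPLB update that drifts from that size.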
#### Impact
- Resolves tensor shape errors when using redundant experts with EPLB on
the Ascend platform.
- Ensures correct expert routing and load balancing for MoE models with
redundant expert configurations.
- No breaking changes to existing functionality; compatible with
non-redundant expert deployments.
- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
228 lines · 9.6 KiB · Python
from unittest.mock import MagicMock, patch

import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        AlltoAllCommImpl,
                                                        MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
                                                         TokenDispatchResult)


class TestMoECommMethod(TestBase):

    def setUp(self):
        # Mock FusedMoEConfig
        self.moe_config = MagicMock(spec=FusedMoEConfig)
        self.moe_config.num_experts = 8
        self.moe_config.num_local_experts = 2
        self.moe_config.experts_per_token = 2
        self.moe_config.tp_group = MagicMock()
        self.moe_config.tp_group.device_group = MagicMock()
        self.moe_config.dp_size = 1
        self.moe_config.tp_size = 1
        self.moe_config.ep_size = 1
        self.moe_config.dp_group = MagicMock()
        self.moe_config.global_redundant_expert_num = 0

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                  mock_prepare_finalize,
                                  mock_get_forward_context):
        # Mock forward context
        mock_context = MagicMock()
        mock_context.moe_comm_method = "all_gather"
        mock_get_forward_context.return_value = mock_context

        # Mock prepare finalize
        mock_pf_instance = MagicMock()
        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
                                                 torch.randn(4, 2), None, None)
        mock_pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = mock_pf_instance

        # Mock token dispatcher
        mock_td_instance = MagicMock()
        mock_token_dispatcher.return_value = mock_td_instance

        # Create instance
        comm_impl = AllGatherCommImpl(self.moe_config)

        # Test prepare method
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)
        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
            hidden_states, router_logits)

        # Verify prepare was called with correct arguments
        mock_pf_instance.prepare.assert_called_once_with(
            hidden_states, router_logits, False, False, QuantType.NONE)

        # Test finalize method
        comm_impl.finalize(h_out,
                           reduce_results=True,
                           context_metadata=context_metadata)
        mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
    def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
                           mock_get_forward_context):
        # Mock forward context
        mock_context = MagicMock()
        mock_context.moe_comm_method = "mc2"
        mock_get_forward_context.return_value = mock_context

        # Mock prepare finalize
        mock_pf_instance = MagicMock()
        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
                                                 torch.randn(4, 2),
                                                 torch.tensor([1, 0, 1, 0]),
                                                 None)
        mock_pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = mock_pf_instance

        # Mock token dispatcher
        mock_td_instance = MagicMock()
        mock_token_dispatcher.return_value = mock_td_instance

        # Create instance
        comm_impl = MC2CommImpl(self.moe_config)

        # Test prepare method
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)
        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
            hidden_states, router_logits)

        # Verify prepare was called with correct arguments
        mock_pf_instance.prepare.assert_called_once_with(
            hidden_states, router_logits, False, False, QuantType.NONE)

        # Test finalize method
        comm_impl.finalize(h_out,
                           reduce_results=True,
                           context_metadata=context_metadata)
        mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
    )
    def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                mock_prepare_finalize,
                                mock_get_forward_context):
        # Mock forward context
        mock_context = MagicMock()
        mock_context.moe_comm_method = "alltoall"
        mock_get_forward_context.return_value = mock_context

        # Mock prepare finalize
        mock_pf_instance = MagicMock()
        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
                                                 torch.randn(4, 2), None, None)
        mock_pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = mock_pf_instance

        # Mock token dispatcher
        mock_td_instance = MagicMock()
        mock_token_dispatcher.return_value = mock_td_instance

        # Create instance
        comm_impl = AlltoAllCommImpl(self.moe_config)

        # Test prepare method
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)
        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
            hidden_states, router_logits)

        # Verify prepare was called with correct arguments
        mock_pf_instance.prepare.assert_called_once_with(
            hidden_states, router_logits, False, False, QuantType.NONE)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
    )
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
    def test_fused_experts_method(self, mock_unified_apply_mlp,
                                  mock_token_dispatcher, mock_prepare_finalize,
                                  mock_get_forward_context):
        # Mock forward context
        mock_context = MagicMock()
        mock_context.moe_comm_method = "all_gather"
        mock_get_forward_context.return_value = mock_context

        # Mock prepare finalize
        mock_pf_instance = MagicMock()
        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
                                                 torch.randn(4, 2), None)
        mock_pf_instance.finalize.return_value = torch.randn(4, 8)
        mock_prepare_finalize.return_value = mock_pf_instance

        # Mock token dispatcher
        mock_td_instance = MagicMock()
        mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
            hidden_states=torch.randn(6, 8),
            group_list=torch.tensor([2, 2, 2]),
            group_list_type=1)
        mock_td_instance.token_combine.return_value = TokenCombineResult(
            routed_out=torch.randn(4, 8))
        mock_token_dispatcher.return_value = mock_td_instance

        # Mock unified_apply_mlp
        mock_unified_apply_mlp.return_value = torch.randn(6, 8)

        # Create instance
        comm_impl = AllGatherCommImpl(self.moe_config)

        # Test fused_experts method
        hidden_states = torch.randn(4, 8).contiguous()
        w1 = torch.randn(16, 8).contiguous()
        w2 = torch.randn(16, 8).contiguous()
        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                     [0.6, 0.4]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])

        # Make sure tensors are contiguous and have correct strides
        hidden_states = hidden_states.contiguous()
        w1 = w1.contiguous()
        w2 = w2.contiguous()

        result = comm_impl.fused_experts(hidden_states=hidden_states,
                                         w1=[w1],
                                         w2=[w2],
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         activation="silu")

        # Verify result shape
        self.assertEqual(result.routed_out.shape, (4, 8))

        # Verify token_dispatch was called
        mock_td_instance.token_dispatch.assert_called_once()

        # Verify unified_apply_mlp was called
        mock_unified_apply_mlp.assert_called_once()

        # Verify token_combine was called
        mock_td_instance.token_combine.assert_called_once()