### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
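For illustration, the stage-level fields the tests below assert (`hidden_states`, `router_logits`, `mc2_mask`, `padded_hidden_states_shape`) imply a typed prepare-stage result roughly like this sketch. The field set is taken from the test assertions; the class name, the use of `dataclass`, and the defaults are assumptions for illustration, not the actual definitions in `vllm_ascend/ops/fused_moe`:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PrepareOutput:
    """Hypothetical typed result of the prepare stage (replaces **kwargs)."""
    # Possibly padded and/or TP-split activations.
    hidden_states: torch.Tensor
    # Router logits, padded alongside hidden_states.
    router_logits: torch.Tensor
    # Token mask for the MC2 path; None for backends that do not use it.
    mc2_mask: Optional[torch.Tensor] = None
    # Pre-split padded shape, needed by finalize to restore/unpad outputs.
    padded_hidden_states_shape: Optional[torch.Size] = None
```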
import unittest
from unittest.mock import MagicMock, patch

import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from vllm_ascend.ops.fused_moe.prepare_finalize import (
    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
    PrepareAndFinalizeWithMC2)


class TestPrepareAndFinalize(unittest.TestCase):

    def setUp(self):
        # Patch torch.npu.Stream and build a mock FusedMoEConfig.
        fake_stream = MagicMock()
        patcher = patch("torch.npu.Stream", return_value=fake_stream)
        patcher.start()
        self.addCleanup(patcher.stop)
        self.moe_config = MagicMock(spec=FusedMoEConfig)
        self.moe_config.tp_group = MagicMock()
        self.moe_config.tp_group.device_group = MagicMock()
        self.moe_config.dp_size = 1
        self.moe_config.tp_size = 1
        self.moe_config.ep_size = 1
        self.moe_config.dp_group = MagicMock()
        self.moe_config.original_num_experts = 8
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
        return_value=1)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
                                  mock_tp_size):
        mock_context = MagicMock()
        mock_context.mc2_mask = torch.tensor([1, 0, 1])
        mock_context.padded_num_tokens = 4
        mock_get_forward_context.return_value = mock_context

        layer = PrepareAndFinalizeWithMC2(self.moe_config)

        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)

        prepare_output = layer.prepare(hidden_states, router_logits)
        h_out = prepare_output.hidden_states
        r_out = prepare_output.router_logits
        mask = prepare_output.mc2_mask
        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

        # 3 tokens are padded up to padded_num_tokens=4; with tp_size=1
        # there is no TP split.
        self.assertEqual(h_out.shape[0], 4)
        self.assertEqual(r_out.shape[0], 4)
        self.assertEqual(mask.tolist(), [1, 0, 1])
        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))

        # Finalize strips the padding back to the original 3 tokens.
        result = layer.finalize(
            h_out,
            reduce_results=False,
            padded_hidden_states_shape=padded_hidden_states_shape)
        self.assertEqual(result.shape[0], 3)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
        return_value=2)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("torch.distributed.all_gather")
    def test_mc2_tp_split_allgather(self, mock_all_gather,
                                    mock_get_forward_context, mock_tp_rank,
                                    mock_tp_size):
        mock_context = MagicMock()
        mock_context.mc2_mask = torch.tensor([1, 0, 1, 0])
        mock_context.padded_num_tokens = 4
        mock_get_forward_context.return_value = mock_context

        layer = PrepareAndFinalizeWithMC2(self.moe_config)
        hidden_states = torch.randn(4, 8)
        router_logits = torch.randn(4, 2)

        prepare_output = layer.prepare(
            hidden_states,
            router_logits,
            enable_shared_expert_dp=False,
            replace_allreduce=False)
        h_out = prepare_output.hidden_states
        r_out = prepare_output.router_logits
        mask = prepare_output.mc2_mask
        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

        # With TP=2, the 4 padded tokens are split so each rank keeps 2 rows.
        self.assertEqual(h_out.shape[0], 2)
        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))

        # Stub torch.distributed.all_gather: fill the output list with
        # copies of this rank's shard.
        def mock_all_gather_func(tensor_list, tensor, group=None):
            tensor_list[0] = tensor
            tensor_list[1] = tensor.clone()

        mock_all_gather.side_effect = mock_all_gather_func

        # Pre-populate the per-rank buffers that finalize gathers into.
        layer.split_hidden_states = [
            torch.zeros_like(h_out),
            torch.zeros_like(h_out)
        ]
        final_result = layer.finalize(
            h_out,
            reduce_results=False,
            padded_hidden_states_shape=padded_hidden_states_shape)

        # The gathered shards are concatenated back to the padded size.
        self.assertEqual(final_result.shape[0], 4)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
        return_value=1)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
    def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
        layer = PrepareAndFinalizeWithAll2All(self.moe_config)
        hidden_states = torch.randn(3, 8)
        router_logits = torch.randn(3, 2)

        prepare_output = layer.prepare(hidden_states, router_logits)
        h_out = prepare_output.hidden_states
        r_out = prepare_output.router_logits
        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

        # With tp_size=1 there is no padding or split, so shapes are unchanged.
        self.assertEqual(h_out.shape[0], 3)
        self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))

        result = layer.finalize(
            h_out,
            reduce_results=False,
            padded_hidden_states_shape=padded_hidden_states_shape)
        self.assertEqual(result.shape[0], 3)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
        return_value=2)
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
    @patch("torch.distributed.all_gather")
    def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
                                        mock_tp_size):
        layer = PrepareAndFinalizeWithAll2All(self.moe_config)
        hidden_states = torch.randn(2, 8)
        router_logits = torch.randn(2, 2)

        prepare_output = layer.prepare(
            hidden_states,
            router_logits,
            enable_shared_expert_dp=False,
            replace_allreduce=False)
        h_out = prepare_output.hidden_states
        r_out = prepare_output.router_logits
        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

        # With TP=2, the 2 tokens are split so each rank keeps 1 row.
        self.assertEqual(h_out.shape[0], 1)
        self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))

        # Stub torch.distributed.all_gather as above.
        def mock_all_gather_func(tensor_list, tensor, group=None):
            tensor_list[0] = tensor
            tensor_list[1] = tensor.clone()

        mock_all_gather.side_effect = mock_all_gather_func

        layer.split_hidden_states = [
            torch.zeros_like(h_out),
            torch.zeros_like(h_out)
        ]
        final_result = layer.finalize(
            h_out,
            reduce_results=False,
            padded_hidden_states_shape=padded_hidden_states_shape)

        # The gathered shards are concatenated back to the full 2 rows.
        self.assertEqual(final_result.shape[0], 2)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
|
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
|
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
|
return_value=False)
|
|
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
|
mock_get_forward_context,
|
|
mock_get_dp_group):
|
|
# Mock forward context
|
|
mock_context = MagicMock()
|
|
mock_context.max_tokens_across_dp = 6
|
|
mock_get_forward_context.return_value = mock_context
|
|
|
|
# Create a proper mock for DP group with working all_gather
|
|
mock_dp_group = MagicMock()
|
|
|
|
def mock_all_gather_func(tensor, dim):
|
|
# Simulate DP=2: repeat the tensor along the specified dimension
|
|
return torch.cat([tensor, tensor], dim=dim)
|
|
|
|
mock_dp_group.all_gather = mock_all_gather_func
|
|
mock_get_dp_group.return_value = mock_dp_group
|
|
|
|
self.moe_config.dp_size = 2
|
|
self.moe_config.tp_size = 1
|
|
self.moe_config.ep_size = 1
|
|
self.moe_config.dp_group = mock_dp_group
|
|
|
|
layer = PrepareAndFinalizeWithAllGather(self.moe_config)
|
|
|
|
hidden_states = torch.randn(3, 8)
|
|
router_logits = torch.randn(3, 2)
|
|
|
|
prepare_output = layer.prepare(hidden_states, router_logits)
|
|
h_out = prepare_output.hidden_states
|
|
r_out = prepare_output.router_logits
|
|
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
|
|
|
# After all-gather with DP=2, should double the batch size
|
|
self.assertEqual(h_out.shape[0], 12)
|
|
self.assertEqual(r_out.shape[0], 12)
|
|
self.assertIsNone(padded_hidden_states_shape)
|
|
|
|
# Finalize with reduce_scatter
|
|
def mock_reduce_scatter_func(tensor, dim):
|
|
# Simulate reduce_scatter: take first half
|
|
return tensor[:3]
|
|
|
|
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
|
result = layer.finalize(h_out,
|
|
reduce_results=False,
|
|
padded_hidden_states_shape=padded_hidden_states_shape)
|
|
|
|
self.assertEqual(result.shape[0], 3)
|
|
|
|
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
|
self.assertEqual(result_with_tp.shape[0], 3)
|