init v0.11.0rc0
This commit is contained in:
289
tests/ut/ops/test_fused_moe_prepare_and_finalize.py
Normal file
289
tests/ut/ops/test_fused_moe_prepare_and_finalize.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock FusedMoEConfig
|
||||
self.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||
self.moe_config.tp_group = MagicMock()
|
||||
self.moe_config.tp_group.device_group = MagicMock()
|
||||
self.moe_config.dp_size = 1
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
mock_context.mc2_mask = torch.tensor([1, 0, 1])
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Check padding and split
|
||||
self.assertEqual(h_out.shape[0], 4)
|
||||
self.assertEqual(r_out.shape[0], 4)
|
||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||
|
||||
# Finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_mc2_tp_split_allgather(self, mock_all_gather,
|
||||
mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
mock_context.mc2_mask = torch.tensor([1, 0, 1, 0])
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# With TP=2, should split into 2 parts
|
||||
self.assertEqual(h_out.shape[0], 2)
|
||||
|
||||
# Mock all_gather behavior
|
||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||
tensor_list[0] = tensor
|
||||
tensor_list[1] = tensor.clone()
|
||||
|
||||
mock_all_gather.side_effect = mock_all_gather_func
|
||||
|
||||
layer.split_hidden_states = [
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back to original size
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Pad to tp_size=1, so no change
|
||||
self.assertEqual(h_out.shape[0], 3)
|
||||
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# Split due to TP=2
|
||||
self.assertEqual(h_out.shape[0], 1)
|
||||
|
||||
# Mock all_gather
|
||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||
tensor_list[0] = tensor
|
||||
tensor_list[1] = tensor.clone()
|
||||
|
||||
mock_all_gather.side_effect = mock_all_gather_func
|
||||
|
||||
layer.split_hidden_states = [
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_allgather_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce, mock_get_dp_group):
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.max_tokens_across_dp = 6
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Create a proper mock for DP group with working all_gather
|
||||
mock_dp_group = MagicMock()
|
||||
|
||||
def mock_all_gather_func(tensor, dim):
|
||||
# Simulate DP=2: repeat the tensor along the specified dimension
|
||||
return torch.cat([tensor, tensor], dim=dim)
|
||||
|
||||
mock_dp_group.all_gather = mock_all_gather_func
|
||||
mock_get_dp_group.return_value = mock_dp_group
|
||||
|
||||
self.moe_config.dp_size = 2
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = mock_dp_group
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock the gate function for rm_router_logits=False case
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
self.assertEqual(r_out.shape[0], 12)
|
||||
|
||||
# Finalize with reduce_scatter
|
||||
def mock_reduce_scatter_func(tensor, dim):
|
||||
# Simulate reduce_scatter: take first half
|
||||
return tensor[:3]
|
||||
|
||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
# Test with TP all-reduce
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape[0], 3)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce,
|
||||
mock_get_dp_group):
|
||||
# Mock forward context with DP metadata
|
||||
mock_context = MagicMock()
|
||||
if vllm_version_is("0.10.2"):
|
||||
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
|
||||
[2, 5, 7])
|
||||
else:
|
||||
mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
|
||||
[2, 5, 7])
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Setup DP group mock
|
||||
mock_dp_group = MagicMock()
|
||||
mock_dp_group.broadcast = MagicMock()
|
||||
mock_dp_group.all_reduce = MagicMock()
|
||||
mock_get_dp_group.return_value = mock_dp_group
|
||||
|
||||
# Mock all_reduce to just return input (simulate sum)
|
||||
def mock_all_reduce(tensor):
|
||||
return tensor * 2
|
||||
|
||||
mock_dp_group.all_reduce.side_effect = mock_all_reduce
|
||||
|
||||
# Setup config
|
||||
self.moe_config.dp_size = 3
|
||||
self.moe_config.dp_rank = 1
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
|
||||
# Local inputs
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock gate for router logits recomputation
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
self.assertEqual(r_out.shape, (7, 2))
|
||||
|
||||
# Run finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should slice back to local: [3, 8]
|
||||
self.assertEqual(result.shape, (3, 8))
|
||||
|
||||
# Test with reduce_results=True and TP/EP > 1
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape, (3, 8))
|
||||
Reference in New Issue
Block a user