init v0.11.0rc0
This commit is contained in:
@@ -20,10 +20,10 @@ from unittest.mock import MagicMock, PropertyMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.row_idx, expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output["group_list_type"],
|
||||
1) # group_list_type == 1
|
||||
0) # group_list_type == 0
|
||||
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
# Mock NPU functions
|
||||
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
|
||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
||||
self.mock_moe_init_routing.return_value = (
|
||||
self.patcher_npu_moe_init_routing_v2 = patch(
|
||||
'torch_npu.npu_moe_init_routing_v2')
|
||||
self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
|
||||
)
|
||||
self.mock_npu_moe_init_routing_v2.return_value = (
|
||||
torch.randn(6, 128), # sorted_hidden_states
|
||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
||||
)
|
||||
|
||||
self.patcher_moe_compute_expert_tokens = patch(
|
||||
'torch_npu.npu_moe_compute_expert_tokens')
|
||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
||||
)
|
||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
||||
[3, 3]) # expert_tokens
|
||||
|
||||
self.patcher_moe_finalize_routing = patch(
|
||||
'torch_npu.npu_moe_finalize_routing')
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]))
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
self.patcher_npu_moe_token_unpermute = patch(
|
||||
'torch_npu.npu_moe_token_unpermute')
|
||||
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
|
||||
)
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher_moe_init_routing.stop()
|
||||
self.patcher_moe_compute_expert_tokens.stop()
|
||||
self.patcher_moe_finalize_routing.stop()
|
||||
self.patcher_npu_moe_init_routing_v2.stop()
|
||||
self.patcher_npu_moe_token_unpermute.stop()
|
||||
|
||||
def test_token_dispatch_without_expert_map(self):
|
||||
hidden_states = torch.randn(3, 128)
|
||||
@@ -207,12 +200,27 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_moe_init_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_init_routing.call_args
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
def test_token_dispatch_with_expert_map(self):
    """Dispatch tokens after an expert map has been set on the dispatcher.

    Assigns ``expert_map`` on ``self.dispatcher``, calls ``token_dispatch``
    with ``None`` as the fifth positional argument, and checks that the
    patched ``npu_moe_init_routing_v2`` op ran exactly once and that the
    result reports ``group_list_type == 1``.
    """
    self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
    hidden_states = torch.randn(3, 128)
    topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
    topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

    results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                             topk_ids, self.row_idx, None)

    # Verify npu_moe_init_routing_v2 is called (the mock patched in setUp).
    # The call arguments are not inspected, so they are not unpacked here.
    self.mock_npu_moe_init_routing_v2.assert_called_once()

    self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_without_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
@@ -230,7 +238,33 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, None)
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
    """Dispatch through a freshly built dispatcher with ``with_quant=True``.

    Builds a separate ``TokenDispatcherWithAllGather`` (stored on
    ``self.dispatcher_quant``), dispatches a 3x128 batch with top-2 routing,
    and checks that the result carries non-None ``hidden_states``,
    ``group_list`` and ``dynamic_scale`` entries and ``group_list_type == 1``.
    """
    dispatcher_config = {
        "apply_router_weight_on_input": False,
        "top_k": 2,
        "max_num_tokens": 100,
        "ep_size": 2,
        "num_experts": 128,
    }
    self.dispatcher_quant = TokenDispatcherWithAllGather(**dispatcher_config)

    tokens = torch.randn(3, 128)
    router_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
    expert_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

    dispatched = self.dispatcher_quant.token_dispatch(
        tokens,
        router_weights,
        expert_ids,
        self.row_idx,
        None,
        with_quant=True,
    )

    # Same entries, same order as individual assertIsNotNone calls.
    for required_key in ("hidden_states", "group_list", "dynamic_scale"):
        self.assertIsNotNone(dispatched[required_key])
    self.assertEqual(dispatched["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
@@ -242,9 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify index_add_ is applied correctly
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
@@ -260,10 +292,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_moe_finalize_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_finalize_routing.call_args
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
@@ -315,7 +347,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
|
||||
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
@@ -323,7 +355,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
@@ -488,119 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
    """Tests for the module-level token dispatcher registry.

    Covers ``_register_token_dispatcher`` / ``get_token_dispatcher`` and the
    dispatcher selection done by ``setup_token_dispatchers`` for several
    ``ep_size`` values. The registry is emptied before and after each test.
    """

    def setUp(self):
        """Start every test from an empty registry."""
        _Dispatchers.clear()

    def tearDown(self):
        """Leave no registered dispatchers behind for other tests."""
        _Dispatchers.clear()

    def test_register_and_get_token_dispatcher(self):
        """A registered dispatcher is retrievable under its class name."""
        fake_dispatcher = MagicMock()
        # The assertions below look the instance up by this class name.
        fake_dispatcher.__class__.__name__ = "MockDispatcher"

        _register_token_dispatcher(fake_dispatcher)

        self.assertIn("MockDispatcher", _Dispatchers)
        self.assertIs(_Dispatchers["MockDispatcher"], fake_dispatcher)

        self.assertIs(get_token_dispatcher("MockDispatcher"),
                      fake_dispatcher)

        # Unknown names yield None.
        self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_1_creates_allgather(
            self, mock_register, mock_allgather_class):
        """With ``ep_size=1`` an all-gather dispatcher is built and registered."""
        dispatcher_kwargs = {"top_k": 2, "num_experts": 8}
        created = MagicMock()
        mock_allgather_class.return_value = created

        self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)

        setup_token_dispatchers(ep_size=1, **dispatcher_kwargs)

        mock_allgather_class.assert_called_once_with(**dispatcher_kwargs)
        mock_register.assert_called_once_with(created)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
            self, mock_register, mock_all2allv_class):
        """With ``ep_size=2`` an all-to-all-v dispatcher is built and registered."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 16,
            "num_local_experts": 2
        }
        created = MagicMock()
        mock_all2allv_class.return_value = created

        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)

        setup_token_dispatchers(ep_size=2, **dispatcher_kwargs)

        mock_all2allv_class.assert_called_once_with(**dispatcher_kwargs)
        mock_register.assert_called_once_with(created)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        """With ``ep_size=16`` both all-to-all-v and MC2 dispatchers are registered."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        created_all2allv = MagicMock()
        created_mc2 = MagicMock()
        mock_all2allv_class.return_value = created_all2allv
        mock_mc2_class.return_value = created_mc2

        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
        self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)

        setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

        mock_all2allv_class.assert_called_once_with(**dispatcher_kwargs)
        mock_mc2_class.assert_called_once_with(**dispatcher_kwargs)
        # Exactly one registration per dispatcher, in either order.
        self.assertEqual(mock_register.call_count, 2)
        mock_register.assert_any_call(created_all2allv)
        mock_register.assert_any_call(created_mc2)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        """Already-registered dispatchers are neither rebuilt nor replaced."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        preexisting_all2allv = MagicMock()
        preexisting_mc2 = MagicMock()
        # Seed the registry so setup finds both dispatchers already present.
        _Dispatchers["TokenDispatcherWithAll2AllV"] = preexisting_all2allv
        _Dispatchers["TokenDispatcherWithMC2"] = preexisting_mc2

        setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

        mock_all2allv_class.assert_not_called()
        mock_mc2_class.assert_not_called()
        mock_register.assert_not_called()
        self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
                      preexisting_all2allv)
        self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
                      preexisting_mc2)
|
||||
|
||||
Reference in New Issue
Block a user