forked from EngineX-Ascend/enginex-ascend-910-vllm
v0.10.1rc1
This commit is contained in:
17
tests/ut/ops/expert_map.json
Normal file
17
tests/ut/ops/expert_map.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
    "moe_layer_count":
    1,
    "layer_list": [{
        "layer_id":
        0,
        "device_count":
        2,
        "device_list": [{
            "device_id": 0,
            "device_expert": [7, 2, 0, 3, 5]
        }, {
            "device_id": 1,
            "device_expert": [6, 1, 4, 7, 2]
        }]
    }]
}
|
||||
61
tests/ut/ops/test_activation.py
Normal file
61
tests/ut/ops/test_activation.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_tensor():
    """Return a fresh random float16 tensor of shape (4, 8)."""
    return torch.randn(4, 8, dtype=torch.float16)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1)
def test_QuickGELU_forward(mock_gelu, dummy_tensor):
    """Check QuickGELU.forward against a stubbed ``torch_npu.npu_fast_gelu``.

    The kernel is replaced by a stub that adds 1 to its input, so the layer
    output is expected to equal ``dummy_tensor + 1`` and the stub must have
    been called exactly once.
    """
    layer = QuickGELU()
    out = layer.forward(dummy_tensor)

    expected_out = dummy_tensor + 1
    assert torch.allclose(out, expected_out)

    mock_gelu.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
    """Check SiluAndMul.forward against a stubbed ``torch_npu.npu_swiglu``.

    ``vllm_ascend.utils.is_310p`` is patched to return ``is_310p_return``.
    When it returns True the stub is expected to receive the input cast to
    float32; otherwise it is expected to receive the input unchanged. In
    both cases the layer output must equal ``dummy_tensor + 1`` (the stub
    adds 1 to its argument).
    """
    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = SiluAndMul()
        out = layer.forward(dummy_tensor)

        if is_310p_return:
            expected_arg = dummy_tensor.to(torch.float32)
        else:
            expected_arg = dummy_tensor

        mock_swiglu.assert_called_once()

        # First positional argument of the single recorded call.
        actual_arg = mock_swiglu.call_args[0][0]
        assert torch.allclose(
            actual_arg,
            expected_arg), "npu_swiglu called with unexpected input"

        expected_out = dummy_tensor + 1
        assert torch.allclose(out, expected_out)
|
||||
69
tests/ut/ops/test_common_fused_moe.py
Normal file
69
tests/ut/ops/test_common_fused_moe.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
    """Shape check for ``fused_experts_moge`` with the NPU kernels stubbed."""

    def test_fused_experts_moge(self):
        """Run fused_experts_moge on 4 tokens / 4 experts and check the output shape."""
        with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
             patch('torch_npu.npu_swiglu') as mock_swiglu, \
             patch('vllm_ascend.utils.is_310p') as mock_is_310p:

            mock_is_310p.return_value = False

            # Stub: return a one-element list holding a random tensor sized
            # from the first input's rows and the first weight's dim 1.
            mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
                torch.randn(x[0].shape[0], weight[0].shape[1])
            ]

            # Stub: identity, so the activation leaves shapes untouched.
            mock_swiglu.side_effect = lambda x: x

            hidden_states = torch.randn(4, 128)
            w1 = torch.randn(4, 256, 128)
            w2 = torch.randn(4, 128, 128)
            topk_weights = torch.rand(4, 1)
            topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
            top_k = 1
            global_num_experts = 4

            # Ad-hoc object exposing the parallel-config attributes as
            # plain class attributes.
            moe_parallel_config = type(
                'MockConfig', (), {
                    'ep_size': 1,
                    'tp_size': 1,
                    'dp_size': 1,
                    'tp_rank': 0,
                    'dp_rank': 0,
                    'ep_rank': 0,
                    'use_ep': True
                })()

            output = fused_experts_moge(
                hidden_states=hidden_states,
                w1=w1,
                w2=w2,
                moe_parallel_config=moe_parallel_config,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=True,
            )

            self.assertEqual(output.shape, (4, 128))
|
||||
141
tests/ut/ops/test_expert_load_balancer.py
Normal file
141
tests/ut/ops/test_expert_load_balancer.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List, TypedDict
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
|
||||
|
||||
class Device(TypedDict):
    """Typed view of one ``device_list`` entry in expert_map.json."""

    # Value of the "device_id" key.
    device_id: int
    # Value of the "device_expert" key (a list of ints in the fixture).
    device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
    """Typed view of one ``layer_list`` entry in expert_map.json."""

    # Value of the "layer_id" key.
    layer_id: int
    # Value of the "device_count" key.
    device_count: int
    # Value of the "device_list" key: one Device mapping per entry.
    device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
    """Typed view of the top-level object stored in expert_map.json."""

    # Value of the "moe_layer_count" key.
    moe_layer_count: int
    # Value of the "layer_list" key: one Layer mapping per entry.
    layer_list: List[Layer]
|
||||
|
||||
|
||||
class TestExpertLoadBalancer(TestBase):
    """Tests for ExpertLoadBalancer built from the sibling expert_map.json."""

    def setUp(self):
        """Load expert_map.json and build a balancer with 8 global experts."""
        # os.path.join keeps the path relative when dirname(__file__) is
        # empty, instead of producing "/expert_map.json".
        json_file = os.path.join(os.path.dirname(__file__), "expert_map.json")
        with open(json_file, 'r', encoding="utf-8") as f:
            self.expert_map: MockData = json.load(f)

        self.expert_load_balancer = ExpertLoadBalancer(json_file,
                                                       global_expert_num=8)

    def test_init(self):
        """The balancer exposes a tensor map and the counts from the JSON."""
        self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
                              torch.Tensor)
        self.assertEqual(self.expert_load_balancer.layers_num,
                         self.expert_map["moe_layer_count"])
        self.assertEqual(self.expert_load_balancer.ranks_num,
                         self.expert_map["layer_list"][0]["device_count"])

    def test_generate_index_dicts(self):
        """Each row maps its expert ids to a running index across rows."""
        tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
        result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
        expected_result = [{
            7: 0,
            2: 1,
            0: 2,
            3: 3,
            5: 4
        }, {
            6: 5,
            1: 6,
            4: 7,
            7: 8,
            2: 9
        }]
        self.assertEqual(result, expected_result)

    def test_generate_expert_placement_map(self):
        """The placement map is (layers, ranks, 8) with entries >= -1."""
        expert_placement_map = (
            self.expert_load_balancer.generate_expert_placement_map())
        self.assertEqual(expert_placement_map.shape,
                         (self.expert_load_balancer.layers_num,
                          self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(expert_placement_map >= -1))

    def test_generate_log2phy_expert_map(self):
        """The log2phy map for layer 0 is (ranks, 8) with entries >= -1."""
        layer_id = 0
        log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
            layer_id)
        self.assertEqual(log2phy_map.shape,
                         (self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(log2phy_map >= -1))

    @mock.patch("torch_npu.npu._lazy_init")
    @mock.patch("torch.npu.current_device", return_value="cpu")
    def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
        """Check the per-rank placement map and local expert count for both ranks."""
        layer_id = 0
        rank_id = 0
        rank_local_expert_num, rank_expert_map = (
            self.expert_load_balancer.get_rank_placement_map(
                layer_id, rank_id))
        self.assertEqual(rank_local_expert_num, 5)
        expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
                                       dtype=torch.int32).to(
                                           rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

        rank_id = 1
        rank_local_expert_num, rank_expert_map = (
            self.expert_load_balancer.get_rank_placement_map(
                layer_id, rank_id))
        # Device 1 also lists five experts in expert_map.json.
        self.assertEqual(rank_local_expert_num, 5)
        expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
                                       dtype=torch.int32).to(
                                           rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

    def test_get_rank_log2phy_map(self):
        """Check the per-rank log2phy map for both ranks of layer 0."""
        layer_id = 0
        rank_id = 0
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
                                       dtype=torch.int32).to(
                                           log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

        rank_id = 1
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
                                       dtype=torch.int32).to(
                                           log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

    def test_get_global_redundant_expert_num(self):
        """Redundant count equals experts-per-device * device_count - 8."""
        redundant_expert_num = (
            self.expert_load_balancer.get_global_redundant_expert_num())
        first_layer = self.expert_map["layer_list"][0]
        expected_redundant_expert_num = (
            len(first_layer["device_list"][0]["device_expert"]) *
            first_layer["device_count"] - 8)
        self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
|
||||
741
tests/ut/ops/test_fused_ops.py
Normal file
741
tests/ut/ops/test_fused_ops.py
Normal file
@@ -0,0 +1,741 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, TypedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
# NOTE(review): runs at import time; presumably applies the vllm-ascend
# patches before the tests below execute — confirm against
# vllm_ascend.utils.adapt_patch.
adapt_patch(True)
|
||||
|
||||
|
||||
def mock_ep_and_mc2_group(mocker):
    """Build a stand-in group object used for the EP and MC2 groups.

    Args:
        mocker: object exposing a ``MagicMock`` factory (the pytest-mock
            fixture in this file's tests).

    Returns:
        The mock group with rank 0, world size 4, a string device group and
        an ``all_to_all`` stub returning a random (8, 8) tensor.
    """
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.rank = 0
    mock_group.world_size = 4
    mock_group.device_group = "mock_group_ep"
    mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
    return mock_group
|
||||
|
||||
|
||||
def mock_dp_and_tp_group(mocker):
    """Build a stand-in group object used for the DP and TP groups.

    Args:
        mocker: object exposing a ``MagicMock`` factory (the pytest-mock
            fixture in this file's tests).

    Returns:
        The mock group with rank 0, world size 2, a string device group and
        an ``all_gather`` stub returning a random (10, 32) tensor.
    """
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.world_size = 2
    mock_group.device_group = "mock_group"
    mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
    return mock_group
|
||||
|
||||
|
||||
def mock_npu_format_cast(weight_data, format):
    """Stand-in for ``torch_npu.npu_format_cast``: return the input unchanged.

    ``format`` is accepted and ignored; the parameter name is kept so the
    stub can be called exactly as this file's tests patch it in.
    """
    return weight_data
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    """Patch the distributed / config entry points used by the fused-MoE tests.

    Builds three mocked token dispatchers (all-gather, all2allv, MC2) with
    canned ``token_dispatch`` / ``token_combine`` results and a mocked
    forward context, then patches the group getters, collectives, config
    getters and forward-context getters for the duration of the test.

    Yields:
        dict with the forward-context mock and the three dispatcher mocks.
    """
    mock_setup_token_dispatchers = MagicMock()
    mock_token_dispatcher_with_allgather = MagicMock()
    mock_token_dispatcher_with_all2allv = MagicMock()
    mock_token_dispatcher_with_mc2 = MagicMock()

    mock_dispatch_result_allgather = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([8, 16], dtype=torch.int64),
        "group_list_type": 0,
    }
    mock_combine_result_allgather = torch.randn(16, 2)

    mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
    mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather

    mock_dispatch_result_all2allv = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
        "group_list_type": 1,
        "dynamic_scale": None,
    }
    mock_combine_result_all2allv = torch.randn(16, 2)
    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
    mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv

    mock_dispatch_result_mc2 = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
        "group_list_type": 1,
        "dynamic_scale": None,
        "assist_info_for_combine": torch.randn(16, 2),
        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
    }
    mock_combine_result_mc2 = torch.randn(16, 2)
    mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
    mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2

    captured_dispatchers = {}

    def capture_register(dispatcher_instance):
        """Record a dispatcher by class name, swapping in the matching mock."""
        key = dispatcher_instance.__class__.__name__
        captured_dispatchers[key] = dispatcher_instance
        if key == 'TokenDispatcherWithAllGather':
            captured_dispatchers[key] = mock_token_dispatcher_with_allgather
        elif key == 'TokenDispatcherWithAll2AllV':
            captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
        elif key == 'TokenDispatcherWithMC2':
            captured_dispatchers[key] = mock_token_dispatcher_with_mc2

    # NOTE(review): the two patchers below are created but never started in
    # this fixture (no .start() call and no `with`), so the stop() calls at
    # the end do not undo any active patch — confirm whether they were meant
    # to be started.
    mock_register_token_dispatcher_patcher = patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
        side_effect=capture_register)

    mock_get_token_dispatcher_patcher = patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
        side_effect=lambda name: captured_dispatchers.get(name))

    default_mock_token_dispatcher = mock_token_dispatcher_with_allgather

    mock_forward_context_obj = MagicMock(
        fused_moe_state=FusedMoEState.AllGather,
        token_dispatcher=default_mock_token_dispatcher,
        max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
        mc2_mask=torch.zeros(16, dtype=torch.bool),
        padded_num_tokens=16,
        with_quant=False)

    with patch('torch.distributed.get_rank', return_value=0), \
        patch('torch.distributed.get_world_size', return_value=4), \
        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('torch.distributed.all_gather'), \
        patch('torch.distributed.all_to_all_single'), \
        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
            return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
            return_value=MagicMock(
                torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
                expert_map_path=None
            )), \
        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
            return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
        patch('vllm_ascend.ops.fused_moe.get_forward_context',
            return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
            return_value=MagicMock(
                parallel_config=MagicMock(tensor_parallel_size=2),
                scheduler_config=MagicMock(max_num_seqs=4),
                model_config=MagicMock(max_model_len=2048)
            )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
        patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
            return_value=mock_forward_context_obj):

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
            'mock_token_dispatcher_with_allgather':
            mock_token_dispatcher_with_allgather,
            'mock_token_dispatcher_with_all2allv':
            mock_token_dispatcher_with_all2allv,
            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
        }

    mock_register_token_dispatcher_patcher.stop()
    mock_get_token_dispatcher_patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
    """Patch the ``torch_npu`` MoE kernels with canned return values.

    Each kernel is replaced for the duration of the test by a mock that
    returns fixed-shape random tensors. The ``*_v2`` dispatch/combine
    kernels are patched only when ``torch_npu`` exposes
    ``npu_moe_distribute_dispatch_v2``. ``mocker`` is not used in the body.
    """
    with patch('torch_npu.npu_moe_gating_top_k', return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            None
        )), \
        patch('torch_npu.npu_moe_init_routing', return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
        )), \
        patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
            torch.randn(8, 2)
        )), \
        patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
            torch.randn(16, 2)
        )), \
        patch("torch_npu.npu_moe_distribute_combine", return_value=(
            torch.randn(16, 2)
        )), \
        patch("torch_npu.npu_grouped_matmul", return_value=(
            [torch.randn(16, 2)]
        )), \
        patch("torch_npu.npu_swiglu", return_value=(
            torch.randn(16, 2)
        )), \
        patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
        )), \
        patch("torch_npu.npu_moe_finalize_routing", return_value=(
            torch.randn(16, 2)
        )):
        # The v2 kernels may be absent from the installed torch_npu, so they
        # are patched only when present.
        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
            with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
                torch.randn(16, 2))), \
                patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
                torch.randn(16, 2))):
                yield
        else:
            yield
|
||||
|
||||
|
||||
@pytest.fixture
def default_moe_config():
    """Return the keyword arguments the tests pass to ``AscendFusedMoE``."""
    return {
        'num_experts': 8,
        'top_k': 2,
        'hidden_size': 512,
        'intermediate_size': 1024
    }
|
||||
|
||||
|
||||
@pytest.fixture
def moe_method(mock_dist_env):
    """Return an AscendUnquantizedFusedMoEMethod built on a mocked MoE config.

    Depends on ``mock_dist_env`` so the method is constructed while that
    fixture's patches are active.
    """
    moe = MagicMock()
    # NOTE(review): this sets what *calling* moe.moe_parallel_config returns;
    # plain attribute access (moe.moe_parallel_config.ep_size) still yields
    # an auto-created MagicMock — confirm this is intended.
    moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
    return AscendUnquantizedFusedMoEMethod(moe)
|
||||
|
||||
|
||||
class Device(TypedDict):
    """Typed mapping with a device id and the expert ids listed for it."""

    # Value of the "device_id" key.
    device_id: int
    # Value of the "device_expert" key (a list of ints).
    device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
    """Typed mapping with a layer id, a device count and its Device entries."""

    # Value of the "layer_id" key.
    layer_id: int
    # Value of the "device_count" key.
    device_count: int
    # Value of the "device_list" key: one Device mapping per entry.
    device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
    """Typed mapping with a MoE layer count and its Layer entries."""

    # Value of the "moe_layer_count" key.
    moe_layer_count: int
    # Value of the "layer_list" key: one Layer mapping per entry.
    layer_list: List[Layer]
|
||||
|
||||
|
||||
class MockQuantMethod(nn.Module):
    """Module whose ``apply`` is a MagicMock returning canned tensors.

    Args:
        shared_experts: when truthy, ``apply`` returns a 2-tuple of random
            tensors shaped (num_tokens, 32) and (num_tokens, 10); otherwise
            it returns a single (num_tokens, 32) tensor.
        num_tokens: first dimension of the returned tensors.
    """

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        if shared_experts:
            self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
                                                 torch.randn(num_tokens, 10)))
        else:
            # Parentheses only group here: the return value is one tensor,
            # not a 1-tuple.
            self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
    """Minimal FusedMoEMethodBase subclass with no-op methods."""

    # Shared mocked MoE config handed to the base-class constructor.
    moe = MagicMock()

    def __init__(self):
        super().__init__(self.moe)

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """No-op: this stub creates no weights."""
        pass

    def apply(self, hidden_states: torch.Tensor,
              expert_weights: torch.Tensor) -> torch.Tensor:
        """No-op: returns None despite the annotated return type."""
        pass
|
||||
|
||||
|
||||
class TestAscendFusedMoe:
    """Construction and forward-pass tests for ``AscendFusedMoE``."""

    def test_init_no_quant(self, mock_dist_env, default_moe_config):
        """Build the layer without quantization and check rejected configs."""
        layer = AscendFusedMoE(**default_moe_config)

        layer.w13_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['intermediate_size'] * 2,
                        default_moe_config['hidden_size']))
        layer.w2_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['hidden_size'],
                        default_moe_config['intermediate_size']))

        assert layer.num_experts == default_moe_config['num_experts']
        assert layer.top_k == default_moe_config['top_k']
        assert hasattr(layer, 'w13_weight')
        assert hasattr(layer, 'w2_weight')

        # Only the constructor call sits inside pytest.raises so that the
        # expected exception can come from nothing else.
        error_config = default_moe_config.copy()
        error_config['use_grouped_topk'] = True
        with pytest.raises(AssertionError):
            AscendFusedMoE(**error_config)

        error_config = default_moe_config.copy()
        error_config['scoring_func'] = "random"
        with pytest.raises(ValueError):
            AscendFusedMoE(**error_config)

    def test_init_with_quant(self, mock_dist_env, default_moe_config):
        """The layer adopts the quant method supplied by the quant config."""
        mock_quant_config = MagicMock()
        mock_quant_method = MockFusedMoEMethod()
        mock_quant_config.get_quant_method.return_value = mock_quant_method

        moe = AscendFusedMoE(**default_moe_config,
                             quant_config=mock_quant_config)

        assert moe.quant_method is not None
        assert moe.quant_method == mock_quant_method

    @pytest.mark.parametrize(
        "others_param",
        [[None,
          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
         [2, None, False, 5, None], [None, None, True, 5, None],
         [None, None, False, 1, None], [None, None, True, 5, 1],
         [None, None, False, 5, 1]])
    def test_forward(self, mock_dist_env, default_moe_config, others_param):
        """Run ``forward`` with a mocked quant method and check output shapes.

        ``others_param`` unpacks to
        ``(top_k, shared_experts, is_prefill, num_tokens, ep_size)``.
        """
        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
        inputs = torch.randn(num_tokens, 32)
        router_logits = torch.randn(num_tokens, 8)
        moe = AscendFusedMoE(**default_moe_config)

        if ep_size == 1:
            moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
                                                         dtype=torch.bool),
                                    padded_num_tokens=num_tokens)
        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            output = moe.forward(inputs,
                                 router_logits,
                                 is_prefill=is_prefill,
                                 top_k=top_k,
                                 shared_experts=shared_experts)

        moe.quant_method.apply.assert_called_once()

        if shared_experts:
            assert output[0].shape == (num_tokens, 32)
            assert output[1].shape == (num_tokens, 10)
        else:
            assert output.shape == (num_tokens, 32)

    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
                                       default_moe_config):
        """``_forward_ms_fused_moe_comp`` calls the quant method once and keeps the shape."""
        inputs = torch.randn(5, 32)
        router_logits = torch.randn(5, 8)
        moe = AscendFusedMoE(**default_moe_config)

        moe.quant_method = MockQuantMethod(None, 5)
        output = moe._forward_ms_fused_moe_comp(inputs,
                                                router_logits,
                                                is_prefill=False,
                                                real_top_k=1)

        moe.quant_method.apply.assert_called_once()

        assert output.shape == (5, 32)
|
||||
|
||||
|
||||
class TestAscendUnquantizedFusedMoEMethod:
    """Tests for ``AscendUnquantizedFusedMoEMethod`` with mocked NPU kernels."""

    @staticmethod
    def _select_token_dispatcher(mock_dist_env, ep_size):
        """Return the mocked dispatcher these tests pair with ``ep_size``."""
        if ep_size == 1:
            return mock_dist_env['mock_token_dispatcher_with_allgather']
        if ep_size < 16:
            return mock_dist_env['mock_token_dispatcher_with_all2allv']
        return mock_dist_env['mock_token_dispatcher_with_mc2']

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        """Weights become non-trainable Parameters after post-load processing."""
        layer = MagicMock()
        layer.w13_weight.data = torch.randn(16, 32)
        layer.w2_weight.data = torch.randn(16, 32)

        with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
             patch('vllm_ascend.utils.is_310p', return_value=False):
            moe_method.process_weights_after_loading(layer)

        assert isinstance(layer.w13_weight, torch.nn.Parameter)
        assert isinstance(layer.w2_weight, torch.nn.Parameter)
        assert not layer.w13_weight.requires_grad
        assert not layer.w2_weight.requires_grad

    # NOTE(review): [128, 1] appears twice in the case list — confirm whether
    # a different case was intended.
    @pytest.mark.parametrize("others_param",
                             [[256, 4], [128, 1], [128, 1], [128, 4]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
        """Run ``apply`` without an expert map and check the output shape.

        ``others_param`` unpacks to ``(global_num_experts, ep_size)``.
        """
        global_num_experts, ep_size = others_param
        is_prefill = False
        is_deepseek_v3_r1 = global_num_experts == 256

        selected_token_dispatcher = self._select_token_dispatcher(
            mock_dist_env, ep_size)

        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
            ep_size, is_prefill, is_deepseek_v3_r1),
                                    with_quant=False,
                                    token_dispatcher=selected_token_dispatcher)

        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()
            local_num_experts = 2
            hidden_size = 2
            intermediate_size_per_partition = 4

            layer.w13_weight = torch.randn(local_num_experts,
                                           intermediate_size_per_partition * 2,
                                           hidden_size)
            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
                                          intermediate_size_per_partition)

            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=global_num_experts,
                                      is_prefill=is_prefill)

            expected_shape = (16, 2)

            assert result.shape == expected_shape

    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
        """Run ``apply`` with an expert map and check the output shape.

        ``others_param`` is the ``ep_size`` under test.
        """
        ep_size = others_param
        is_prefill = False

        selected_token_dispatcher = self._select_token_dispatcher(
            mock_dist_env, ep_size)

        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
            ep_size, is_prefill, True),
                                    with_quant=False,
                                    token_dispatcher=selected_token_dispatcher)

        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):

            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            if ep_size == 1:
                # Flatten to 2-D for the ep_size == 1 case.
                x = x.view(-1, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()

            local_num_experts = 2
            hidden_size = 2
            intermediate_size_per_partition = 4
            layer.w13_weight = torch.randn(local_num_experts,
                                           intermediate_size_per_partition * 2,
                                           hidden_size)
            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
                                          intermediate_size_per_partition)

            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=128,
                                      expert_map=expert_map,
                                      is_prefill=is_prefill)

            expected_shape = (16, 2)

            assert result.shape == expected_shape
|
||||
|
||||
|
||||
class TestExpertsSelector:
    """Shape checks for ``select_experts`` with the NPU kernels mocked."""

    # Plain ints: with a single argname each list element is handed to the
    # test as one value, so wrapping the counts in lists would pass a list
    # (never equal to 256) as ``global_num_experts``.
    @pytest.mark.parametrize("global_num_experts", [256, 128])
    def test_select_experts(self, mock_dist_env, mock_moe_env,
                            global_num_experts):
        """``select_experts`` returns (8, 2) weights and ids for 8 tokens, top-2."""
        x = torch.randn(8, 2)
        router_logits = torch.randn(8, 2)
        topk_weights, topk_ids, _ = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=True,
            topk_group=None,
            num_expert_group=None,
            custom_routing_function=None,
            scoring_func="softmax",
            e_score_correction_bias=None,
            global_num_experts=global_num_experts)

        assert topk_weights.shape == (8, 2)
        assert topk_ids.shape == (8, 2)
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
    """Tests for ``unified_apply_mlp`` with every ``torch_npu`` kernel mocked.

    Each test programs the mocked kernels with tensors of fixed shape/dtype,
    runs ``unified_apply_mlp`` and checks which kernels were invoked plus the
    shape/dtype of the returned tensor.
    """

    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_dynamic_quant')
    @patch('torch_npu.npu_dequant_swiglu_quant')
    def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                     mock_npu_dynamic_quant,
                                                     mock_npu_grouped_matmul,
                                                     mock_is_310p,
                                                     mock_get_forward_context):
        """Quantized run while the forward context reports the MC2 state."""
        ctx = MagicMock()
        ctx.fused_moe_state = FusedMoEState.MC2
        mock_get_forward_context.return_value = ctx
        mock_is_310p.return_value = False

        # Canned kernel outputs.
        quant_out = torch.randint(-128, 127, (10, 20), dtype=torch.int8)
        quant_scale = torch.rand(10, 1, dtype=torch.float32)
        mock_npu_dynamic_quant.return_value = (quant_out, quant_scale)

        gmm1_out = torch.randint(-2147483648,
                                 2147483647, (10, 40),
                                 dtype=torch.int32)
        gmm2_out = torch.randn(10, 20, dtype=torch.bfloat16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]

        dequant_out = torch.randn(10, 40, dtype=torch.bfloat16)
        dequant_scale = torch.randn(10, 1, dtype=torch.float32)
        mock_npu_dequant.return_value = (dequant_out, dequant_scale)

        # Inputs handed to the function under test.
        tokens = torch.randn(10, 20, dtype=torch.bfloat16)
        gate_up_w = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
        gate_up_scale = torch.randn(5, 40, dtype=torch.float32)
        down_w = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
        down_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        groups = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)

        out = unified_apply_mlp(hidden_states=tokens,
                                w1=gate_up_w,
                                w1_scale=gate_up_scale,
                                w2=down_w,
                                w2_scale=down_scale,
                                group_list=groups,
                                dynamic_scale=None,
                                group_list_type=1,
                                w1_scale_bias=None,
                                w2_scale_bias=None,
                                topk_scales=None,
                                with_quant=True)

        mock_get_forward_context.assert_called()
        self.assertEqual(ctx.fused_moe_state, FusedMoEState.MC2)
        mock_npu_dynamic_quant.assert_called()
        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_dequant.assert_called_once()
        self.assertEqual(out.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization(self,
                                                    mock_npu_dynamic_quant,
                                                    mock_npu_swiglu,
                                                    mock_npu_grouped_matmul,
                                                    mock_is_310p):
        """Unquantized float16 run with ``is_310p`` reporting False."""
        mock_is_310p.return_value = False

        gmm1_out = torch.randn(10, 40, dtype=torch.float16)
        gmm2_out = torch.randn(10, 20, dtype=torch.float16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]
        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        tokens = torch.randn(10, 20, dtype=torch.float16)
        gate_up_w = torch.randn(5, 20, 40, dtype=torch.float16)
        down_w = torch.randn(5, 40, 20, dtype=torch.float16)
        groups = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        scales = torch.randn(10, 1, dtype=torch.float16)

        out = unified_apply_mlp(hidden_states=tokens,
                                w1=gate_up_w,
                                w1_scale=None,
                                w2=down_w,
                                w2_scale=None,
                                group_list=groups,
                                dynamic_scale=None,
                                group_list_type=1,
                                w1_scale_bias=None,
                                w2_scale_bias=None,
                                topk_scales=scales,
                                with_quant=False)

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()
        self.assertEqual(out.shape, tokens.shape)
        self.assertEqual(out.dtype, torch.float16)

    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_get_forward_context):
        """Quantized run with a caller-supplied ``dynamic_scale`` and a
        forward-context state other than MC2."""
        ctx = MagicMock()
        ctx.with_quant = True
        ctx.fused_moe_state = "NOT_MC2"
        mock_get_forward_context.return_value = ctx

        gmm1_out = torch.randn(10, 40, dtype=torch.bfloat16)
        gmm2_out = torch.randn(10, 20, dtype=torch.bfloat16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]

        mock_npu_swiglu.return_value = torch.randn(10,
                                                   40,
                                                   dtype=torch.bfloat16)

        quant_out = torch.randint(-128, 127, (10, 40), dtype=torch.int8)
        quant_scale = torch.rand(10, 1, dtype=torch.float32)
        mock_npu_dynamic_quant.return_value = (quant_out, quant_scale)

        tokens = torch.randn(10, 20, dtype=torch.bfloat16)
        gate_up_w = torch.randn(5, 20, 40, dtype=torch.bfloat16)
        gate_up_scale = torch.randn(5, 40, dtype=torch.bfloat16)
        down_w = torch.randn(5, 40, 20, dtype=torch.bfloat16)
        down_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        gate_up_bias = torch.randn(5, 40, dtype=torch.bfloat16)
        down_bias = torch.randn(5, 20, dtype=torch.bfloat16)
        groups = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        given_scale = torch.rand(10, 1, dtype=torch.float32)

        out = unified_apply_mlp(hidden_states=tokens,
                                w1=gate_up_w,
                                w1_scale=gate_up_scale,
                                w2=down_w,
                                w2_scale=down_scale,
                                group_list=groups,
                                dynamic_scale=given_scale,
                                group_list_type=1,
                                w1_scale_bias=gate_up_bias,
                                w2_scale_bias=down_bias,
                                topk_scales=None,
                                with_quant=True)

        mock_get_forward_context.assert_called()
        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()
        mock_npu_dynamic_quant.assert_called_once()
        self.assertEqual(out.shape, tokens.shape)
        self.assertEqual(out.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization_310p(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_is_310p):
        """Unquantized float16 run with ``is_310p`` reporting True."""
        mock_is_310p.return_value = True

        gmm1_out = torch.randn(10, 40, dtype=torch.float16)
        gmm2_out = torch.randn(10, 20, dtype=torch.float16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]
        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        tokens = torch.randn(10, 20, dtype=torch.float16)
        gate_up_w = torch.randn(5, 20, 40, dtype=torch.float16)
        down_w = torch.randn(5, 40, 20, dtype=torch.float16)
        groups = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        scales = torch.randn(10, 1, dtype=torch.float16)

        out = unified_apply_mlp(hidden_states=tokens,
                                w1=gate_up_w,
                                w1_scale=None,
                                w2=down_w,
                                w2_scale=None,
                                group_list=groups,
                                dynamic_scale=None,
                                group_list_type=1,
                                w1_scale_bias=None,
                                w2_scale_bias=None,
                                topk_scales=scales,
                                with_quant=False)

        mock_is_310p.assert_called_once()
        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()
        self.assertEqual(out.shape, tokens.shape)
        self.assertEqual(out.dtype, torch.float16)
|
||||
53
tests/ut/ops/test_layernorm.py
Normal file
53
tests/ut/ops/test_layernorm.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_tensor():
    """Provide a random float16 tensor of shape (4, 8) to each test."""
    shape = (4, 8)
    return torch.randn(*shape, dtype=torch.float16)
|
||||
|
||||
|
||||
def mock_rms_norm(x, weight, eps):
    """Stand-in for ``torch_npu.npu_rms_norm``.

    Ignores ``weight`` and ``eps`` and returns the pair ``(x + 1, None)`` so
    callers can recognise the output by value.
    """
    shifted = x + 1
    return shifted, None
|
||||
|
||||
|
||||
def mock_add_rms_norm(x, residual, weight, eps):
    """Stand-in for ``torch_npu.npu_add_rms_norm``.

    Ignores ``weight`` and ``eps`` and returns the triple
    ``(2 * x, None, 2 * residual)``.
    """
    doubled_x = 2 * x
    doubled_residual = 2 * residual
    return doubled_x, None, doubled_residual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):
    """Check which mocked NPU kernel RMSNorm uses and what it returns.

    Expectations are derived from the mock kernels: ``mock_rms_norm`` yields
    ``x + 1`` and ``mock_add_rms_norm`` yields ``2 * x`` / ``2 * residual``.
    """
    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        norm = RMSNorm(hidden_size=32, eps=1e-05)

        # No residual: the plain rms-norm kernel is expected.
        if residual is None:
            result = norm.forward(dummy_tensor, residual)
            want = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(result, want)
            return

        result, new_residual = norm.forward_oot(dummy_tensor, residual)

        if is_310p_return:
            # On 310P the test expects the residual to be added first and
            # the plain rms-norm kernel to run on the sum.
            summed = dummy_tensor + residual.to(dummy_tensor.dtype)
            want = summed + 1
            want_residual = summed.to(residual.dtype)
            kernel_mock = mock_rmsnorm
        else:
            want = 2 * dummy_tensor
            want_residual = 2 * residual
            kernel_mock = mock_add_rmsnorm

        kernel_mock.assert_called_once()
        assert torch.allclose(result, want)
        assert torch.allclose(new_residual, want_residual)
|
||||
363
tests/ut/ops/test_linear.py
Normal file
363
tests/ut/ops/test_linear.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
|
||||
AscendMlpMergedColumnParallelLinear,
|
||||
AscendMlpRowParallelLinear, LinearBase,
|
||||
QuantizationConfig)
|
||||
|
||||
|
||||
class TestAscendMlpRowParallelLinear(unittest.TestCase):
    """Tests for ``AscendMlpRowParallelLinear`` with the parallel-state
    helpers looked up in ``vllm_ascend.ops.linear`` replaced by mocks."""

    def setUp(self):
        """Set the MLP-optimize env flag and start all patchers."""
        # NOTE(review): this environment variable is not restored in
        # tearDown, so it stays set for whatever runs afterwards.
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
        self.tensor_parallel_world_size = 2
        self.tensor_parallel_rank = 0
        self.mlp_tensor_parallel_world_size = 2
        self.mlp_tensor_parallel_rank = 1

        # World size / rank of the regular TP group.
        self.get_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
            return_value=self.tensor_parallel_world_size)
        self.get_tensor_model_parallel_world_size_mock = (
            self.get_tensor_model_parallel_world_size_patch.start())

        self.get_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
            return_value=self.tensor_parallel_rank)
        self.get_tensor_model_parallel_rank_mock = (
            self.get_tensor_model_parallel_rank_patch.start())

        # World size / rank of the MLP TP group.
        self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
            return_value=self.mlp_tensor_parallel_world_size)
        self.get_mlp_tensor_model_parallel_world_size_mock = (
            self.get_mlp_tensor_model_parallel_world_size_patch.start())

        self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
            return_value=self.mlp_tensor_parallel_rank)
        self.get_mlp_tensor_model_parallel_rank_mock = (
            self.get_mlp_tensor_model_parallel_rank_patch.start())

        # Tensor split and all-reduce return canned random tensors.
        self.split_tensor_along_last_dim_patch = mock.patch(
            'vllm_ascend.ops.linear.split_tensor_along_last_dim',
            return_value=(torch.randn(10, 8), torch.randn(10, 8)))
        self.tensor_model_parallel_all_reduce_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_reduce_mock = (
            self.tensor_model_parallel_all_reduce_patch.start())
        self.split_tensor_along_last_dim_mock = (
            self.split_tensor_along_last_dim_patch.start())

        # MLP TP group object exposing a mocked ``reduce_scatter``.
        self.get_mlp_tp_group_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
        tp_group = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value = tp_group
        tp_group.reduce_scatter = mock.MagicMock()

    def tearDown(self):
        """Stop every patcher started in ``setUp``."""
        started = (
            self.get_tensor_model_parallel_world_size_patch,
            self.get_tensor_model_parallel_rank_patch,
            self.get_mlp_tensor_model_parallel_world_size_patch,
            self.get_mlp_tensor_model_parallel_rank_patch,
            self.split_tensor_along_last_dim_patch,
            self.tensor_model_parallel_all_reduce_patch,
            self.get_mlp_tp_group_patch,
        )
        for patcher in started:
            patcher.stop()

    @staticmethod
    def _build_layer(**options):
        """Create a 16 -> 8 layer; ``options`` are extra constructor kwargs."""
        return AscendMlpRowParallelLinear(input_size=16,
                                          output_size=8,
                                          **options)

    def test_init_with_down_proj_prefix(self):
        """A "down_proj" prefix picks the MLP TP size/rank and sets the flag."""
        layer = self._build_layer(prefix="down_proj")

        self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
        self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
        self.assertTrue(layer.enable_mlp_optimze)

    def test_forward_with_mlp_optimize(self):
        """With a "down_proj" prefix, a non-parallel input is split once."""
        layer = self._build_layer(prefix="down_proj", input_is_parallel=False)
        # Shape (16, 8); the split helper is mocked, so this tensor is only
        # checked as the argument it receives.
        inputs = torch.randn(16, 8)
        layer(inputs)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            inputs, num_partitions=layer.tp_size)

    def test_forward_without_mlp_optimize(self):
        """With another prefix, the input is split and all-reduce runs once."""
        layer = self._build_layer(prefix="other", input_is_parallel=False)
        inputs = torch.randn(16, 8)
        layer(inputs)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            inputs, num_partitions=layer.tp_size)
        self.tensor_model_parallel_all_reduce_mock.assert_called_once()

    def test_skip_bias_add(self):
        """``skip_bias_add=True`` makes the call return a non-None bias."""
        layer = self._build_layer(skip_bias_add=True)
        _, returned_bias = layer(torch.randn(16, 8))

        self.assertIsNotNone(returned_bias)

    def test_no_reduce_results(self):
        """``reduce_results=False`` must not trigger an all-reduce."""
        layer = self._build_layer(reduce_results=False, bias=False)
        layer(torch.randn(16, 8))

        self.tensor_model_parallel_all_reduce_mock.assert_not_called()

    def test_input_not_parallel(self):
        """``input_is_parallel=False`` makes the layer split its input once."""
        layer = self._build_layer(input_is_parallel=False)
        layer(torch.randn(16, 8))

        self.split_tensor_along_last_dim_mock.assert_called_once()

    def test_exception_when_reduce_false_and_bias(self):
        """No reduction + bias + no skip_bias_add is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self._build_layer(reduce_results=False,
                              bias=True,
                              skip_bias_add=False)
|
||||
|
||||
|
||||
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
    """Constructor tests for ``AscendMlpColumnParallelLinear``.

    ``LinearBase.__init__`` is replaced by ``mock_linear_base_init`` so that
    only the column-parallel setup logic runs.
    """

    def setUp(self):
        """Set the MLP-optimize env flag and patch the distributed helpers."""
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"

        # MLP TP group: world size 2, this process is rank 0.
        self.mlp_tp_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
        self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
        self.mlp_tp_size_mock.return_value = 2

        self.mlp_tp_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
        self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
        self.mlp_tp_rank_mock.return_value = 0

        # Regular TP group: world size 4, this process is rank 1.
        self.tp_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
        self.tp_size_mock = self.tp_size_patch.start()
        self.tp_size_mock.return_value = 4

        self.tp_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
        self.tp_rank_mock = self.tp_rank_patch.start()
        self.tp_rank_mock.return_value = 1

        # ``divide`` is replaced by plain floor division.
        self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
        self.divide_mock = self.divide_patch.start()
        self.divide_mock.side_effect = lambda x, y: x // y

        # Quantization config handed to the layer in some tests.
        self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)

        # Route LinearBase.__init__ through the stub defined below.
        self.linear_base_init_patch = mock.patch.object(
            LinearBase, "__init__", side_effect=self.mock_linear_base_init)
        self.linear_base_init_patch.start()

        self.quant_method_mock = mock.MagicMock()

    def mock_linear_base_init(self, instance, *args, **kwargs):
        """Replacement for ``LinearBase.__init__``: set fixed attributes."""
        instance.quant_method = self.quant_method_mock
        # This placeholder is overwritten with torch.float32 a few lines down.
        instance.params_dtype = mock.MagicMock()

        instance.input_size = 16
        instance.output_size = 8
        instance.output_size_per_partition = 4
        instance.params_dtype = torch.float32

    def tearDown(self):
        """Stop every patcher started in ``setUp``."""
        for patcher in (self.mlp_tp_size_patch, self.mlp_tp_rank_patch,
                        self.tp_size_patch, self.tp_rank_patch,
                        self.divide_patch, self.linear_base_init_patch):
            patcher.stop()

    @staticmethod
    def _build_layer(**options):
        """Build a bias-free 16 -> 8 layer with ``register_parameter`` mocked."""
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            return AscendMlpColumnParallelLinear(input_size=16,
                                                 output_size=8,
                                                 bias=False,
                                                 **options)

    def test_mlp_optimize_initialization(self):
        """A prefix containing "gate_up_proj" uses the MLP TP group values."""
        layer = self._build_layer(prefix="model.layers.0.gate_up_proj")

        self.assertTrue(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 2)
        self.assertEqual(layer.tp_rank, 0)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)

        # Weights must have been created through the quant method exactly once.
        self.quant_method_mock.create_weights.assert_called_once()

    def test_regular_parallel_initialization(self):
        """A prefix without "gate_up_proj" uses the regular TP group values."""
        layer = self._build_layer(prefix="model.layers.0.q_proj",
                                  quant_config=self.quant_config_mock)

        self.assertFalse(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 4)
        self.assertEqual(layer.tp_rank, 1)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)

        # Weights must have been created through the quant method exactly once.
        self.quant_method_mock.create_weights.assert_called_once()

    def test_output_sizes_handling(self):
        """Passing ``output_sizes`` yields ``output_partition_sizes == [2]``."""
        layer = self._build_layer(output_sizes=[4, 4],
                                  prefix="model.layers.0.qkv_proj",
                                  quant_config=self.quant_config_mock)

        self.assertEqual(layer.output_partition_sizes, [2])
|
||||
|
||||
|
||||
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
    """Forward-pass tests for ``AscendMlpMergedColumnParallelLinear``.

    The parent ``AscendMlpColumnParallelLinear.__init__`` is replaced by
    ``mock_linear_init`` and all distributed helpers looked up in
    ``vllm_ascend.ops.linear`` are mocked.
    """

    def setUp(self):
        """Set the MLP-optimize env flag and start all patchers."""
        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"

        # World sizes of the MLP and the regular TP group (both 2).
        self.mlp_world_size_patch = mock.patch(
            "vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size",
            return_value=2)
        self.tensor_world_size_patch = mock.patch(
            "vllm_ascend.ops.linear.get_tensor_model_parallel_world_size",
            return_value=2)
        self.mlp_world_size_patch.start()
        self.tensor_world_size_patch.start()

        # Ranks within the MLP and the regular TP group (both 0).
        self.mlp_rank_patch = mock.patch(
            "vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank",
            return_value=0)
        self.tensor_rank_patch = mock.patch(
            "vllm_ascend.ops.linear.get_tensor_model_parallel_rank",
            return_value=0)
        self.mlp_rank_patch.start()
        self.tensor_rank_patch.start()

        # all_gather entry points: the MLP TP group object and the regular
        # tensor-parallel all-gather function.
        self.get_mlp_tp_group_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
        self.tensor_model_parallel_all_gather_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_gather_mock = (
            self.tensor_model_parallel_all_gather_patch.start())

        # Replace the parent class constructor with the stub below.
        self.linear_init_patch = mock.patch.object(
            AscendMlpColumnParallelLinear,
            "__init__",
            side_effect=self.mock_linear_init)
        self.linear_init_patch.start()

        # Quant method whose ``apply`` returns a fixed tensor.
        self.quant_method_mock = mock.MagicMock()
        self.apply_output = torch.randn(2, 8)
        self.quant_method_mock.apply.return_value = self.apply_output

    def mock_linear_init(self, instance, *args, **kwargs):
        """Replacement for the parent ``__init__``: initialise the module and
        set the attributes that the forward pass reads."""
        torch.nn.Module.__init__(instance)
        instance.quant_method = self.quant_method_mock
        instance.bias = torch.nn.Parameter(torch.randn(8))  # Example bias
        instance.input_size = 16
        instance.output_size = 8
        instance.gather_output = False
        instance.skip_bias_add = False
        instance.return_bias = True

    def test_forward_with_enable_mlp_optimze(self):
        """Forward output has the shape of the mocked ``apply`` result."""
        input_tensor = torch.randn(1, 16)

        # NOTE(review): despite the test name, the prefix passed here is
        # "other_proj", not one containing "gate_up_proj" -- confirm whether
        # the MLP-optimize path was meant to be exercised.
        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="other_proj")

        output, bias = layer(input_tensor)

        self.assertEqual(output.shape, self.apply_output.shape)

    def test_forward_without_enable_mlp_optimze(self):
        """With a non-"gate_up_proj" prefix, ``apply`` receives the raw input
        and the regular all-gather is not used."""
        input_tensor = torch.randn(1, 16)

        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="other_proj")

        output, bias = layer(input_tensor)

        self.quant_method_mock.apply.assert_called_once_with(
            layer, input_tensor, layer.bias)
        self.tensor_model_parallel_all_gather_mock.assert_not_called()
        self.assertEqual(output.shape, self.apply_output.shape)

    def tearDown(self):
        """Stop every patcher started in ``setUp``."""
        self.linear_init_patch.stop()
        self.mlp_world_size_patch.stop()
        self.tensor_world_size_patch.stop()
        self.mlp_rank_patch.stop()
        self.tensor_rank_patch.stop()
        # Stop the patchers, not the mocks they produced: ``.stop()`` on a
        # MagicMock is just an auto-created attribute call and would leave
        # the patch active for later tests.
        self.get_mlp_tp_group_patch.stop()
        self.tensor_model_parallel_all_gather_patch.stop()
|
||||
318
tests/ut/ops/test_rotary_embedding.py
Normal file
318
tests/ut/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
    """Tests for the ``_custom_rotary_embedding_enabled`` predicate."""

    def setUp(self):
        """Create float16 query/key tensors and a mock rotary module."""
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Mock "self" object carrying the attributes of a rotary layer.
        rope_self = MagicMock()
        rope_self.head_size = self.head_size
        rope_self.cos_sin_cache = self.cos_sin_cache
        rope_self.is_neox_style = True
        rope_self.forward_native.return_value = (self.query, self.key)
        self.mock_self = rope_self

    def test_custom_rotary_embedding_enabled(self):
        """The predicate is true only when every condition holds."""
        fp32_query = self.query.to(torch.float32)
        # (custom op enabled, query, neox_style, head_size, expected result)
        cases = (
            # all conditions satisfied
            (True, self.query, True, self.head_size, True),
            # dtype is not float16
            (True, fp32_query, True, self.head_size, False),
            # neox_style is False
            (True, self.query, False, self.head_size, False),
            # head_size is not divisible by 32
            (True, self.query, True, self.head_size + 1, False),
            # custom op is disabled
            (False, self.query, True, self.head_size, False),
        )

        for op_enabled, query, neox_style, head_size, expected in cases:
            with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                       return_value=op_enabled):
                result = _custom_rotary_embedding_enabled(
                    query, neox_style, head_size)
                check = self.assertTrue if expected else self.assertFalse
                check(result)
|
||||
|
||||
|
||||
class TestAscendRotaryEmbedding(unittest.TestCase):
    """Tests for ``RotaryEmbedding.forward`` with the NPU kernels mocked."""

    def setUp(self):
        """Build a float16 ``RotaryEmbedding`` layer and matching inputs."""
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
        self.head_size = 32
        self.rotary_dim = self.head_size
        self.max_position = 16
        self.rope_theta = 10000
        self.is_neox_style = True
        self.cos_sin_cache = torch.randn(3, 1, 32)
        self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
                                     self.max_position, self.rope_theta,
                                     self.is_neox_style, torch.float16)

        # Mock "self" object carrying the attributes of a rotary layer.
        self.mock_self = MagicMock()
        self.mock_self.head_size = self.head_size
        self.mock_self.cos_sin_cache = self.cos_sin_cache
        self.mock_self.is_neox_style = self.is_neox_style

    @patch('torch.ops._C')
    @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=True)
    @patch('torch.ops._npu_rotary_embedding')
    def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                            mock_custom_enabled, mock_is_310p,
                                            mock__c):
        """With the custom op enabled, ``torch.ops._C.rotary_embedding`` is
        called once and output shapes match the inputs."""
        mock__c.rotary_embedding.return_value = self.query, self.key

        result_q, result_k = self.layer.forward(self.positions, self.query,
                                                self.key)

        mock__c.rotary_embedding.assert_called_once()
        self.assertEqual(result_q.shape, self.query.shape)
        self.assertEqual(result_k.shape, self.key.shape)

    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
                                         mock_custom_enabled):
        """With the custom op disabled, non-contiguous inputs go through
        ``torch_npu._npu_rotary_embedding`` and keep their shapes."""
        non_contig_query = self.query.transpose(0, 1)
        non_contig_key = self.key.transpose(0, 1)

        result_q, result_k = self.layer.forward(self.positions,
                                                non_contig_query,
                                                non_contig_key)

        mock_npu_rotary.assert_called_once()
        self.assertEqual(result_q.shape, non_contig_query.shape)
        self.assertEqual(result_k.shape, non_contig_key.shape)

    def test_rope_forward_oot_with_offsets(self):
        """Passing ``offsets`` raises ``NotImplementedError``."""
        offsets = torch.tensor([1, 2, 3])
        with self.assertRaises(NotImplementedError):
            self.layer.forward(self.positions, self.query, self.key, offsets)

    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
                                                  mock_custom_enabled):
        """``is_neox_style_override=False`` reaches the NPU kernel as its
        last positional argument."""
        result_q, result_k = self.layer.forward(self.positions,
                                                self.query,
                                                self.key,
                                                is_neox_style_override=False)

        # The last positional argument of the kernel call must be falsy.
        args, kwargs = mock_npu_rotary.call_args
        self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
    """Bare attribute container used as a test double.

    NOTE(review): the name suggests it stands in for a rotary-embedding
    module; its users are not visible in this part of the file.
    """

    def __init__(self, max_seq_len=2048, is_neox_style=True):
        """Store the two arguments and initialise the remaining attributes
        to fixed values (caches to ``None``, ``rotary_dim``/``base`` to 1)."""
        self.max_seq_len = max_seq_len
        self.is_neox_style = is_neox_style
        self.cos_cached = self.sin_cached = None
        self.rotary_dim = self.base = 1
|
||||
|
||||
|
||||
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
||||
|
||||
def setUp(self):
    """Prepare shared inputs and RoPE settings; the layer itself is created
    later by ``_create_layer`` and starts out as ``None``."""
    self.positions = torch.tensor([1, 2, 3])

    # Random float16 query/key tensors of shape (3, 1, 32).
    qk_shape = (3, 1, 32)
    self.query = torch.randn(*qk_shape, dtype=torch.float16)
    self.key = torch.randn(*qk_shape, dtype=torch.float16)

    # Constructor arguments for the layer under test.
    self.head_size = 32
    self.rotary_dim = self.head_size
    self.max_position = 16
    self.rope_theta = 10000
    self.is_neox_style = True
    self.scaling_factor = 1

    self.layer = None
|
||||
|
||||
def _create_layer(self):
    """Build a float16 ``DeepseekScalingRotaryEmbedding`` from the settings
    stored in ``setUp``, keep it on ``self.layer`` and return it."""
    layer = DeepseekScalingRotaryEmbedding(self.head_size, self.rotary_dim,
                                           self.max_position, self.rope_theta,
                                           self.is_neox_style,
                                           self.scaling_factor, torch.float16)
    self.layer = layer
    return layer
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
|
||||
return_value=(self.query,
|
||||
self.key)) as mock_rope_forward_oot:
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_cache_handling(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
self.layer.max_seq_len = 1024
|
||||
# Test cache situation is true
|
||||
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
|
||||
mock_rope_forward_oot.return_value = (self.query, self.key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
max_seq_len=2048)
|
||||
mock_set_cache.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
key = torch.randn(1, 32)
|
||||
|
||||
mock_rope_forward_oot.return_value = (self.query, key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
mock_rope_forward_oot.return_value = (self.query, self.key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
|
||||
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_basic_case(self, mock_npuplatform):
|
||||
# Test with standard values
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
num_rotations = 100
|
||||
dim = 512
|
||||
base = 10000
|
||||
max_position_embeddings = 2048
|
||||
|
||||
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
|
||||
max_position_embeddings)
|
||||
|
||||
# Calculate expected value manually
|
||||
expected = (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_yarn_get_mscale(self, mock_npuplatform):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
# test_scale_less_than_or_equal_1
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
|
||||
|
||||
# test_scale_greater_than_1:
|
||||
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
|
||||
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
|
||||
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
|
||||
(math.e, 1.0, 1.0 + 0.1)]
|
||||
|
||||
for scale, mscale, expected in test_cases:
|
||||
result = self.layer._yarn_get_mscale(scale, mscale)
|
||||
self.assertAlmostEqual(
|
||||
result,
|
||||
expected,
|
||||
places=6,
|
||||
msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
606
tests/ut/ops/test_token_dispatcher.py
Normal file
606
tests/ut/ops/test_token_dispatcher.py
Normal file
@@ -0,0 +1,606 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
    """Tests for ``TokenDispatcherWithMC2`` with group, context and NPU ops mocked."""

    def setUp(self):
        # Fake MC2 communication group: rank 0 of 8.
        self.mc2_group = MagicMock()
        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
        self.mc2_group.rank_in_group = 0
        self.mc2_group.world_size = 8
        self.mc2_group_patch = patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
            return_value=self.mc2_group)
        self.mc2_group_patch.start()

        self.rank_group_patch = patch("torch.distributed.get_rank",
                                      return_value=0)
        self.rank_group_patch.start()

        # Mock get_forward_context().mc2_mask
        self.forward_context = MagicMock()
        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
        self.forward_context_patch = patch(
            "vllm.forward_context.get_forward_context",
            return_value=self.forward_context)
        self.forward_context_patch.start()

        # Mock get_ascend_soc_version()
        self.ascend_soc_version_patch = patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
            return_value=AscendSocVersion.A3)
        self.ascend_soc_version_patch.start()

        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)
        self.row_idx = torch.arange(10, dtype=torch.int32)

    def tearDown(self):
        # Stop every patch started in setUp. rank_group_patch used to be
        # missing here, which left torch.distributed.get_rank patched for all
        # tests that ran afterwards.
        self.mc2_group_patch.stop()
        self.rank_group_patch.stop()
        self.forward_context_patch.stop()
        self.ascend_soc_version_patch.stop()

    def test_init(self):
        """The constructor picks up rank/world size and the A3-specific flags."""
        self.assertEqual(self.dispatcher.ep_rank_id, 0)
        self.assertEqual(self.dispatcher.ep_world_size, 8)
        self.assertFalse(self.dispatcher.with_quant)
        self.assertTrue(self.dispatcher.enable_dispatch_v2)
        self.assertTrue(self.dispatcher.need_extra_args)
        self.assertTrue(self.dispatcher.a3_need_extra_args)

    def test_get_dispatch_mc2_kwargs_without_quant(self):
        """The dispatch kwargs carry the inputs and the expert count."""
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
            hidden_states, topk_weights, topk_ids, expert_map)
        self.assertIn("x", kwargs)
        self.assertIn("expert_ids", kwargs)
        self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
        """token_dispatch() calls the v2 dispatch op once and reports type 1."""
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)
        topk_ids = torch.randint(0, 8, (10, 1))
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
            output = self.dispatcher.token_dispatch(hidden_states,
                                                    topk_weights, topk_ids,
                                                    self.row_idx, expert_map)
            mock_dispatch.assert_called_once()
            self.assertEqual(output["group_list_type"],
                             1)  # group_list_type == 1

    def test_token_dispatch_with_shared_experts_and_quant(self):
        """token_dispatch() runs when a shared-experts module is supplied."""
        # NOTE(review): despite the test name, with_quant is set to False
        # below; confirm whether the quantized path was meant to be covered.
        self.shared_experts = MagicMock()
        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                         torch.tensor(1.0))
        self.shared_experts.act_fn.return_value = torch.randn(10, 128)
        self.dispatcher.with_quant = False
        self.dispatcher.shared_act = torch.randn(10, 128)
        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
        self.hidden_states = torch.randn(10, 128)
        self.topk_weights = torch.randn(10, 1)

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5):
            self.dispatcher.token_dispatch(self.hidden_states,
                                           self.topk_weights,
                                           torch.randint(0, 8, (10, 1)),
                                           self.row_idx,
                                           torch.tensor(
                                               [0, 1, 2, 3, 4, 5, 6, 7]),
                                           shared_experts=self.shared_experts)

    def test_get_combine_mc_kwargs_with_quant(self):
        """The combine kwargs include ``tp_send_counts`` in this configuration."""
        self.dispatcher.with_quant = True
        hidden_states = torch.randn(10, 128)
        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True
        self.dispatcher.output = torch.randint(0, 8, (10, 1))

        kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
        self.assertIn("tp_send_counts", kwargs)

    def test_token_combine_with_shared_experts(self):
        """token_combine() runs with shared experts and the v2 combine op mocked."""
        self.dispatcher.shared_experts = MagicMock()
        self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
            10, 128), torch.tensor(1.0))
        self.dispatcher.shared_act = torch.randn(10, 128)
        self.dispatcher.with_quant = True
        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True
        self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
        self.dispatcher.output = torch.randint(0, 8, (10, 1))
        self.hidden_states = torch.randn(10, 128)

        with patch("torch_npu.npu_moe_distribute_combine_v2",
                   return_value=torch.randn(10, 128)):
            self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
    """Tests for ``TokenDispatcherWithAllGather`` with the NPU MoE ops mocked."""

    def setUp(self):
        # Dispatcher under test, built without quantization.
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
            "with_quant": False,
        }
        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

        # Mock NPU functions
        self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
        self.mock_moe_init_routing.return_value = (
            torch.randn(6, 128),  # sorted_hidden_states
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
        )

        self.patcher_moe_compute_expert_tokens = patch(
            'torch_npu.npu_moe_compute_expert_tokens')
        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
        )
        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
            [3, 3])  # expert_tokens

        self.patcher_moe_finalize_routing = patch(
            'torch_npu.npu_moe_finalize_routing')
        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
        )
        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
        self.row_idx = torch.arange(10, dtype=torch.int32)

    def tearDown(self):
        # Undo the three NPU patches started in setUp.
        self.patcher_moe_init_routing.stop()
        self.patcher_moe_compute_expert_tokens.stop()
        self.patcher_moe_finalize_routing.stop()

    def test_token_dispatch_without_expert_map(self):
        """Without an expert map, dispatch goes through ``npu_moe_init_routing``."""
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, self.row_idx, None)

        # Verify npu_moe_init_routing is called
        self.mock_moe_init_routing.assert_called_once()

        self.assertEqual(results["group_list_type"], 0)

    def test_token_dispatch_with_quant(self):
        """Dispatch with a dispatcher built without an explicit ``with_quant``."""
        # NOTE(review): the name says "with_quant" but the flag is simply
        # omitted here; confirm which default the constructor applies.
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights, topk_ids,
                                                       self.row_idx, None)

        self.assertEqual(results["group_list_type"], 0)

    def test_token_combine_with_expert_map(self):
        """With an expert map set, token_combine() restores the original shape."""
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.sorted_weights = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
        self.dispatcher.original_shape = (3, 128)
        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
        hidden_states = torch.randn(6, 128)

        final_hidden_states = self.dispatcher.token_combine(hidden_states)

        # Only the output shape is checked here.
        self.assertEqual(final_hidden_states.shape, (3, 128))

    def test_token_combine_without_expert_map(self):
        """Without an expert map, combine goes through ``npu_moe_finalize_routing``."""
        self.dispatcher.with_quant = False
        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.sorted_weights = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
        self.dispatcher.original_shape = (3, 128)
        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
        hidden_states = torch.randn(6, 128)

        final_hidden_states = self.dispatcher.token_combine(hidden_states)

        # Verify npu_moe_finalize_routing is called
        self.mock_moe_finalize_routing.assert_called_once()

        self.assertEqual(final_hidden_states.shape, (3, 128))

    def test_token_dispatch_with_router_weight(self):
        """Dispatch with ``apply_router_weight_on_input`` enabled and top_k == 1."""
        self.dispatcher.apply_router_weight_on_input = True
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
        topk_ids = torch.tensor([[0], [1], [2]])

        # Pass row_idx and expert_map explicitly, like the other dispatch
        # tests in this class; previously None landed in the row_idx slot.
        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, self.row_idx, None)
        # (6, 128) is the shape of the mocked npu_moe_init_routing output.
        self.assertEqual(results["hidden_states"].shape, (6, 128))
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAll2AllV(TestBase):
    """Tests for ``TokenDispatcherWithAll2AllV`` with NPU and collective ops mocked."""

    def setUp(self):

        def activate(patcher):
            # Start the patch, schedule its removal, return what it installed.
            installed = patcher.start()
            self.addCleanup(patcher.stop)
            return installed

        def patch_property(name, value):
            # Replace a dispatcher property with a PropertyMock yielding *value*.
            return activate(
                patch.object(TokenDispatcherWithAll2AllV,
                             name,
                             new_callable=PropertyMock,
                             return_value=value))

        # Expert-parallel group properties: rank 0 in a group of size 2.
        self.mock_ep_group_prop = patch_property('ep_group', MagicMock())
        self.mock_ep_rank_prop = patch_property('ep_rank', 0)
        self.mock_ep_size_prop = patch_property('ep_size', 2)

        # torch_npu.npu_moe_token_permute
        self.mock_npu_moe_token_permute = activate(
            patch('torch_npu.npu_moe_token_permute'))
        self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
                                                        torch.arange(16))

        # torch_npu.npu_moe_token_unpermute
        self.mock_npu_moe_token_unpermute = activate(
            patch('torch_npu.npu_moe_token_unpermute'))
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

        # async_all_to_all
        self.mock_async_all_to_all = activate(
            patch('vllm_ascend.ops.comm_utils.async_all_to_all'))
        self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
                                                   MagicMock())

        # gather_from_sequence_parallel_region
        self.mock_gather_from_sequence_parallel_region = activate(
            patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
            ))
        self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
            [[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)

        # torch.histc
        self.mock_histc = activate(patch('torch.histc'))
        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
                                                    dtype=torch.int64)

        # torch.npu.current_device
        self.mock_current_device = activate(patch('torch.npu.current_device'))
        self.mock_current_device.return_value = 'cpu'

        # torch_npu.npu_dynamic_quant
        self.mock_npu_dynamic_quant = activate(
            patch('torch_npu.npu_dynamic_quant'))
        self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
                                                    torch.randn(16))

        # torch_npu.npu_moe_init_routing_v2
        self.mock_npu_moe_init_routing_v2 = activate(
            patch('torch_npu.npu_moe_init_routing_v2'))
        self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
            16, 16), torch.arange(16), None, torch.randn(16))

        # torch.repeat_interleave
        self.mock_repeat_interleave = activate(patch('torch.repeat_interleave'))
        self.mock_repeat_interleave.return_value = torch.arange(16)

        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2,
                                                      with_quant=False)
        self.row_idx = torch.arange(10, dtype=torch.int32)

    @staticmethod
    def _routing_inputs():
        """Return (hidden_states, topk_weights, topk_ids, expert_map) for 8 tokens."""
        hidden_states = torch.randn(8, 16)
        topk_weights = torch.rand(8, 4)
        topk_ids = torch.randint(0, 4, (8, 2)).long()
        expert_map = torch.tensor([0, 1, 2, 3])
        return hidden_states, topk_weights, topk_ids, expert_map

    def _assign_local_experts(self):
        """Give the current dispatcher the local experts 0 and 1."""
        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

    def test_token_dispatch(self):
        """Plain dispatch yields hidden states, a group list and type 1."""
        states, weights, ids, mapping = self._routing_inputs()
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(hidden_states=states,
                                                    topk_weights=weights,
                                                    topk_ids=ids,
                                                    row_idx=self.row_idx,
                                                    expert_map=mapping)

        self.assertIsNotNone(dispatched["hidden_states"])
        self.assertIsNotNone(dispatched["group_list"])
        self.assertEqual(dispatched["group_list_type"], 1)

    def test_token_combine(self):
        """token_combine() returns a tensor with the pre-dispatch shape."""
        dispatcher = self.dispatcher
        dispatcher.hidden_shape = (8, 16)
        dispatcher.hidden_shape_before_permute = (8, 16)
        dispatcher.reversed_local_input_permutation_mapping = torch.arange(8)
        dispatcher.topk_weights = torch.rand(8, 4)
        dispatcher.input_splits = [4, 4]
        dispatcher.output_splits = [4, 4]
        dispatcher.reversed_global_input_permutation_mapping = torch.arange(16)

        self._assign_local_experts()
        dispatcher.num_global_tokens_per_local_expert = torch.tensor(
            [[2, 2], [2, 2]], dtype=torch.int64)

        combined = dispatcher.token_combine(torch.randn(16, 16))

        self.assertIsNotNone(combined)
        self.assertEqual(combined.shape, (8, 16))

    def test_token_dispatch_with_quant(self):
        """Dispatch with ``with_quant=True`` also returns a dynamic scale."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        states, weights, ids, mapping = self._routing_inputs()
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(hidden_states=states,
                                                    topk_weights=weights,
                                                    topk_ids=ids,
                                                    row_idx=self.row_idx,
                                                    expert_map=mapping,
                                                    with_quant=True)

        self.assertIsNotNone(dispatched["hidden_states"])
        self.assertIsNotNone(dispatched["group_list"])
        self.assertIsNotNone(dispatched["dynamic_scale"])
        self.assertEqual(dispatched["group_list_type"], 1)

    def test_token_dispatch_with_quant_no_active_tokens(self):
        """Quantized dispatch still works when repeat_interleave yields nothing."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        # Simulate the "no active tokens" case with an empty index tensor.
        self.mock_repeat_interleave.return_value = torch.tensor(
            [], dtype=torch.long)

        states, weights, ids, mapping = self._routing_inputs()
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(hidden_states=states,
                                                    topk_weights=weights,
                                                    topk_ids=ids,
                                                    row_idx=self.row_idx,
                                                    expert_map=mapping,
                                                    with_quant=True)

        self.assertIsNotNone(dispatched["hidden_states"])
        self.assertIsNotNone(dispatched["group_list"])
        self.assertIsNotNone(dispatched["dynamic_scale"])
        self.assertEqual(dispatched["group_list_type"], 1)

    def test_token_dispatch_with_log2phy(self):
        """Dispatch accepts a ``log2phy`` tensor alongside the expert map."""
        states, weights, ids, mapping = self._routing_inputs()
        log2phy = torch.tensor([1, 0, 3, 2])
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(hidden_states=states,
                                                    topk_weights=weights,
                                                    topk_ids=ids,
                                                    row_idx=self.row_idx,
                                                    expert_map=mapping,
                                                    log2phy=log2phy)

        self.assertIsNotNone(dispatched["hidden_states"])
        self.assertIsNotNone(dispatched["group_list"])
        self.assertEqual(dispatched["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
    """Tests for the ``_Dispatchers`` registry and ``setup_token_dispatchers``."""

    def setUp(self):
        # Start every test from an empty registry.
        _Dispatchers.clear()

    def tearDown(self):
        # Leave no registered dispatchers behind for other tests.
        _Dispatchers.clear()

    def test_register_and_get_token_dispatcher(self):
        """A registered dispatcher is stored and looked up by its class name."""
        dispatcher = MagicMock()
        dispatcher.__class__.__name__ = "MockDispatcher"

        _register_token_dispatcher(dispatcher)

        self.assertIn("MockDispatcher", _Dispatchers)
        self.assertIs(_Dispatchers["MockDispatcher"], dispatcher)

        found = get_token_dispatcher("MockDispatcher")
        self.assertIs(found, dispatcher)

        # Unknown names resolve to None.
        self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_1_creates_allgather(
            self, mock_register, mock_allgather_class):
        """ep_size=1 builds and registers a single all-gather dispatcher."""
        dispatcher_kwargs = {"top_k": 2, "num_experts": 8}
        allgather_instance = MagicMock()
        mock_allgather_class.return_value = allgather_instance

        self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)

        setup_token_dispatchers(ep_size=1, **dispatcher_kwargs)

        mock_allgather_class.assert_called_once_with(**dispatcher_kwargs)
        mock_register.assert_called_once_with(allgather_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
            self, mock_register, mock_all2allv_class):
        """ep_size=2 builds and registers a single all-to-all-v dispatcher."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 16,
            "num_local_experts": 2
        }
        all2allv_instance = MagicMock()
        mock_all2allv_class.return_value = all2allv_instance

        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)

        setup_token_dispatchers(ep_size=2, **dispatcher_kwargs)

        mock_all2allv_class.assert_called_once_with(**dispatcher_kwargs)
        mock_register.assert_called_once_with(all2allv_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        """ep_size=16 builds and registers both the all-to-all-v and MC2 dispatchers."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        all2allv_instance = MagicMock()
        mc2_instance = MagicMock()
        mock_all2allv_class.return_value = all2allv_instance
        mock_mc2_class.return_value = mc2_instance

        for registry_key in ("TokenDispatcherWithAll2AllV",
                             "TokenDispatcherWithMC2"):
            self.assertNotIn(registry_key, _Dispatchers)

        setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

        mock_all2allv_class.assert_called_once_with(**dispatcher_kwargs)
        mock_mc2_class.assert_called_once_with(**dispatcher_kwargs)
        self.assertEqual(mock_register.call_count, 2)
        mock_register.assert_any_call(all2allv_instance)
        mock_register.assert_any_call(mc2_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        """Dispatchers already in the registry are neither rebuilt nor replaced."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        # Pre-populate the registry with both dispatcher names.
        existing_all2allv = MagicMock()
        existing_mc2 = MagicMock()
        _Dispatchers["TokenDispatcherWithAll2AllV"] = existing_all2allv
        _Dispatchers["TokenDispatcherWithMC2"] = existing_mc2

        setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

        for untouched in (mock_all2allv_class, mock_mc2_class, mock_register):
            untouched.assert_not_called()
        self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
                      existing_all2allv)
        self.assertIs(_Dispatchers["TokenDispatcherWithMC2"], existing_mc2)
|
||||
232
tests/ut/ops/test_vocab_parallel_embedding.py
Normal file
232
tests/ut/ops/test_vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/lora/test_layers.py
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
|
||||
|
||||
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
|
||||
|
||||
|
||||
class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
|
||||
def setUp(self):
    """Store the vocabulary and embedding sizes used to build the layer."""
    self.num_embeddings, self.embedding_dim = 50, 10
    self.org_num_embeddings, self.padding_size = 40, 8
|
||||
|
||||
def _create_layer(self):
|
||||
# Patch methods and dependencies for VocabParallelEmbedding
|
||||
mock_group = MagicMock()
|
||||
mock_group.world_size = 2
|
||||
mock_group.rank_in_group = 0
|
||||
with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
|
||||
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
|
||||
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
|
||||
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
|
||||
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
|
||||
|
||||
# Create an instance of VocabParallelEmbedding
|
||||
layer = AscendVocabParallelEmbedding(
|
||||
num_embeddings=self.num_embeddings,
|
||||
embedding_dim=self.embedding_dim,
|
||||
org_num_embeddings=self.org_num_embeddings,
|
||||
padding_size=self.padding_size,
|
||||
quant_config=None, # Mock quantization config
|
||||
prefix="")
|
||||
|
||||
layer.shard_indices = MagicMock()
|
||||
layer.shard_indices.org_vocab_start_index = 10
|
||||
layer.shard_indices.org_vocab_end_index = 20
|
||||
layer.shard_indices.num_org_vocab_padding = 5
|
||||
layer.shard_indices.added_vocab_start_index = 30
|
||||
layer.shard_indices.added_vocab_end_index = 40
|
||||
|
||||
# Mock the quantization method
|
||||
layer.quant_method.embedding = MagicMock(
|
||||
side_effect=lambda _, x: torch.randn(x.shape[0], self.
|
||||
embedding_dim))
|
||||
return layer
|
||||
|
||||
def test_get_masked_input_and_mask(self):
|
||||
"""Test the mask and offset calculation helper function."""
|
||||
layer = self._create_layer()
|
||||
|
||||
input_ = torch.tensor([5, 15, 25, 35, 45])
|
||||
|
||||
masked_input, mask = layer._get_masked_input_and_mask(
|
||||
input_,
|
||||
org_vocab_start_index=10,
|
||||
org_vocab_end_index=20,
|
||||
num_org_vocab_padding=5,
|
||||
added_vocab_start_index=30,
|
||||
added_vocab_end_index=40)
|
||||
|
||||
expected_mask = torch.tensor([True, False, True, False, True])
|
||||
self.assertTrue(
|
||||
torch.equal(mask, expected_mask),
|
||||
f"Mask mismatch. Expected {expected_mask}, got {mask}")
|
||||
|
||||
expected_masked = torch.tensor([0, 5, 0, 20, 0])
|
||||
self.assertTrue(
|
||||
torch.equal(masked_input, expected_masked),
|
||||
f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
|
||||
)
|
||||
|
||||
def test_forward_with_tp_size_1(self):
|
||||
"""Test forward pass without tensor parallelism."""
|
||||
# Create a fresh mock embedding with tp_size=1
|
||||
layer = self._create_layer()
|
||||
layer.tp_size = 1
|
||||
layer.quant_method.embedding = MagicMock(
|
||||
return_value=torch.randn(3, layer.embedding_dim))
|
||||
|
||||
input_ = torch.tensor([1, 2, 3])
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp1:
|
||||
output = layer.forward(input_)
|
||||
|
||||
# Should just pass through without masking
|
||||
layer.quant_method.embedding.assert_called_once_with(
|
||||
layer, input_.long())
|
||||
self.assertEqual(output.shape, (3, layer.embedding_dim))
|
||||
|
||||
# Verify all_reduce was called once
|
||||
mock_reduce_tp1.assert_called_once()
|
||||
|
||||
def test_forward_with_tp(self):
|
||||
layer = self._create_layer()
|
||||
layer.tp_size = 2
|
||||
|
||||
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x) as mock_reduce_tp:
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
|
||||
# Check that masking was applied correctly
|
||||
layer.quant_method.embedding.assert_called_once()
|
||||
called_input = layer.quant_method.embedding.call_args[0][1]
|
||||
expected_input = torch.tensor([5, 20]) # after offset calculation
|
||||
self.assertTrue(torch.all(called_input == expected_input))
|
||||
|
||||
# Check that all reduce was called
|
||||
mock_reduce_tp.assert_called_once()
|
||||
self.assertEqual(output.shape, (2, self.embedding_dim))
|
||||
|
||||
def test_forward_with_invalid_vocab(self):
|
||||
"""Test that invalid vocab indices are properly masked out."""
|
||||
# Create a fresh embedding layer
|
||||
layer = self._create_layer()
|
||||
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases
|
||||
# Create predictable mock output
|
||||
mock_output = torch.randn(5, self.embedding_dim)
|
||||
layer.quant_method.embedding = MagicMock(
|
||||
return_value=mock_output.clone())
|
||||
|
||||
# Patch tensor_model_parallel_all_reduce to mock its behavior
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
# Check that invalid positions (0, 2, 4) were zeroed out
|
||||
self.assertTrue(torch.all(output[0] == 0))
|
||||
self.assertTrue(torch.all(output[2] == 0))
|
||||
self.assertTrue(torch.all(output[4] == 0))
|
||||
self.assertTrue(torch.all(output[1] == mock_output[1]))
|
||||
self.assertTrue(torch.all(output[3] == mock_output[3]))
|
||||
self.assertEqual(output.shape, (5, self.embedding_dim))
|
||||
|
||||
def test_output_shape(self):
|
||||
"""Test that output shape is correct."""
|
||||
# Create a fresh embedding layer
|
||||
layer = self._create_layer()
|
||||
|
||||
test_cases = [
|
||||
(torch.tensor([15]), (1, self.embedding_dim)),
|
||||
(torch.tensor([15, 35]), (2, self.embedding_dim)),
|
||||
(torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
|
||||
]
|
||||
|
||||
for input_, expected_shape in test_cases:
|
||||
with self.subTest(input=input_):
|
||||
with patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
|
||||
side_effect=lambda x: x):
|
||||
# Call the forward method
|
||||
output = layer.forward(input_)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
|
||||
class TestAscendLogitsProcessor(unittest.TestCase):
    """Tests for AscendLogitsProcessor with the LM-head TP helpers patched.

    ``setUp`` patches ``get_ascend_config``, ``get_lmhead_tp_group`` and
    ``lmhead_tp_enable`` (forced to return True) for the duration of each
    test; the patches are undone through ``addCleanup``.
    """

    def setUp(self):
        """Create the shared mocks and activate the module patches."""
        self.vocab_size = 50
        self.num_embeddings = 50
        self.embedding_dim = 10
        self.org_num_embeddings = 40
        self.padding_size = 8

        self.mock_group = MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0
        self.mock_ascend_config = MagicMock()
        self.mock_quant_method = MagicMock()
        self.mock_quant_method.apply = MagicMock(
            return_value=torch.randn(1, self.vocab_size))
        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
                return_value=self.mock_group),
            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
                  return_value=True),
            # NOTE(review): this target is an attribute of the (already
            # patched) get_lmhead_tp_group function object, not of
            # self.mock_group, which is what the patched function returns.
            # It must stay after the get_lmhead_tp_group patch above;
            # confirm this is the intended target.
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
                return_value=torch.randn(1, self.vocab_size))
        ]

        for p in self.patches:
            p.start()
            # Register the undo right away: unittest skips tearDown when
            # setUp raises, but it always runs registered cleanups, so a
            # failing start() can no longer leak the patches started before
            # it. Cleanups run last-in-first-out, i.e. in reverse start order.
            self.addCleanup(p.stop)

    def test_create_processor(self):
        """The processor keeps the vocab_size it was constructed with."""
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        self.assertEqual(processor.vocab_size, self.vocab_size)

    def test_get_logits(self):
        """_get_logits calls the LM head's quant_method.apply exactly once."""
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim,
                                      prefix="lm_head")
        # Swapping in the mock quant method also supplies its mocked apply().
        lmhead.quant_method = self.mock_quant_method
        hidden_state = torch.randn(1, self.org_num_embeddings)
        processor._get_logits(hidden_state, lmhead)
        self.mock_quant_method.apply.assert_called_once()
|
||||
Reference in New Issue
Block a user