init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
side_effect=lambda x: None)
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done, mock_swiglu,
is_310p_return, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul()
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
else:
expected_arg = dummy_tensor
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(
actual_arg,

View File

@@ -0,0 +1,98 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import pytest
import torch
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.ops.moe.comm_utils import (
_gather_along_first_dim, async_all_to_all,
gather_from_sequence_parallel_region)
class TestDistributedCommunication(PytestBase):
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.npu.current_device", return_value="cpu")
mocker.patch("torch.distributed.get_world_size", return_value=4)
mocker.patch("torch.distributed.get_rank", return_value=0)
@pytest.mark.parametrize(
"input_tensor, output_split_sizes, input_split_sizes",
[(torch.randn(8, 16), [2, 2, 2, 2], [2, 2, 2, 2]),
(torch.randn(16, 32), None, None)])
def test_async_all_to_all(self, input_tensor, output_split_sizes,
input_split_sizes, mocker: MockerFixture):
"""Test async_all_to_all"""
mock_group = mocker.MagicMock()
mocker.patch("torch.distributed.all_to_all_single",
return_value=mocker.MagicMock())
_, a2a_out, handle = async_all_to_all(input_tensor, output_split_sizes,
input_split_sizes, mock_group)
# Check if the output tensor is created properly
if output_split_sizes is None:
assert a2a_out.shape == input_tensor.shape
else:
total_output_size = sum(output_split_sizes)
expected_shape = [total_output_size] + list(
input_tensor.size())[1:]
assert a2a_out.shape == torch.Size(expected_shape)
# Ensure handle is returned from async operation
assert handle is not None
assert isinstance(handle, mocker.MagicMock)
@pytest.mark.parametrize("world_size, test_tensor, expected",
[(1, torch.randn(8, 16), (8, 16)),
(4, torch.randn(8, 16), (32, 16))])
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
mocker: MockerFixture):
"""Test _gather_along_first_dim"""
mocker.patch("torch.distributed.get_world_size",
return_value=world_size)
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
assert result.shape == expected
@pytest.mark.parametrize("input_tensor, output_split_sizes",
[(torch.randn(8, 16), None),
(torch.randn(8, 16), [2, 2, 2, 2])])
def test_gather_from_sequence_parallel_region(self, input_tensor,
output_split_sizes,
mocker: MockerFixture):
"""Test gather_from_sequence_parallel_region"""
mock_group = mocker.MagicMock()
result = gather_from_sequence_parallel_region(input_tensor, mock_group,
output_split_sizes)
# If output_split_sizes is not provided, result should have expanded first dimension by world size
if output_split_sizes is None:
expected_shape = [input_tensor.shape[0] * 4] + list(
input_tensor.shape[1:])
assert result.shape == torch.Size(expected_shape)
else:
# If output_split_sizes is provided, result shape is dictated by sum of output_split_sizes
expected_shape = [sum(output_split_sizes)] + list(
input_tensor.shape[1:])
assert result.shape == torch.Size(expected_shape)

View File

@@ -17,53 +17,40 @@ from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestFusedExpertsMoGE(TestBase):
class TestLoadWeight(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
mock_is_310p.return_value = False
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
mock_swiglu.side_effect = lambda x: x
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

View File

@@ -0,0 +1,289 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.utils import vllm_version_is
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
def setUp(self):
# Mock FusedMoEConfig
self.moe_config = MagicMock(spec=FusedMoEConfig)
self.moe_config.tp_group = MagicMock()
self.moe_config.tp_group.device_group = MagicMock()
self.moe_config.dp_size = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
self.moe_config.dp_group = MagicMock()
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
return_value=1)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
mock_tp_size):
mock_context = MagicMock()
mock_context.mc2_mask = torch.tensor([1, 0, 1])
mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
# Check padding and split
self.assertEqual(h_out.shape[0], 4)
self.assertEqual(r_out.shape[0], 4)
self.assertEqual(mask.tolist(), [1, 0, 1])
# Finalize
result = layer.finalize(h_out, reduce_results=False)
self.assertEqual(result.shape[0], 3)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
return_value=2)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
@patch("torch.distributed.all_gather")
def test_mc2_tp_split_allgather(self, mock_all_gather,
mock_get_forward_context, mock_tp_rank,
mock_tp_size):
mock_context = MagicMock()
mock_context.mc2_mask = torch.tensor([1, 0, 1, 0])
mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2)
h_out, r_out, mask = layer.prepare(hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
# With TP=2, should split into 2 parts
self.assertEqual(h_out.shape[0], 2)
# Mock all_gather behavior
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
mock_all_gather.side_effect = mock_all_gather_func
layer.split_hidden_states = [
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
# Should concat back to original size
self.assertEqual(final_result.shape[0], 4)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
return_value=1)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
return_value=0)
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
# Pad to tp_size=1, so no change
self.assertEqual(h_out.shape[0], 3)
result = layer.finalize(h_out, reduce_results=False)
self.assertEqual(result.shape[0], 3)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
return_value=2)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch("torch.distributed.all_gather")
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2)
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
# Split due to TP=2
self.assertEqual(h_out.shape[0], 1)
# Mock all_gather
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
mock_all_gather.side_effect = mock_all_gather_func
layer.split_hidden_states = [
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
# Should concat back
self.assertEqual(final_result.shape[0], 2)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_allgather_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce, mock_get_dp_group):
# Mock forward context
mock_context = MagicMock()
mock_context.max_tokens_across_dp = 6
mock_get_forward_context.return_value = mock_context
# Create a proper mock for DP group with working all_gather
mock_dp_group = MagicMock()
def mock_all_gather_func(tensor, dim):
# Simulate DP=2: repeat the tensor along the specified dimension
return torch.cat([tensor, tensor], dim=dim)
mock_dp_group.all_gather = mock_all_gather_func
mock_get_dp_group.return_value = mock_dp_group
self.moe_config.dp_size = 2
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
self.moe_config.dp_group = mock_dp_group
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock the gate function for rm_router_logits=False case
mock_gate = MagicMock()
mock_gate.return_value = (router_logits.repeat(2, 1), None)
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# After all-gather with DP=2, should double the batch size
self.assertEqual(h_out.shape[0], 12)
self.assertEqual(r_out.shape[0], 12)
# Finalize with reduce_scatter
def mock_reduce_scatter_func(tensor, dim):
# Simulate reduce_scatter: take first half
return tensor[:3]
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
result = layer.finalize(h_out, reduce_results=False)
self.assertEqual(result.shape[0], 3)
# Test with TP all-reduce
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce,
mock_get_dp_group):
# Mock forward context with DP metadata
mock_context = MagicMock()
if vllm_version_is("0.10.2"):
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[2, 5, 7])
else:
mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
[2, 5, 7])
mock_get_forward_context.return_value = mock_context
# Setup DP group mock
mock_dp_group = MagicMock()
mock_dp_group.broadcast = MagicMock()
mock_dp_group.all_reduce = MagicMock()
mock_get_dp_group.return_value = mock_dp_group
# Mock all_reduce to just return input (simulate sum)
def mock_all_reduce(tensor):
return tensor * 2
mock_dp_group.all_reduce.side_effect = mock_all_reduce
# Setup config
self.moe_config.dp_size = 3
self.moe_config.dp_rank = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
# Local inputs
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock gate for router logits recomputation
mock_gate = MagicMock()
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))
self.assertEqual(r_out.shape, (7, 2))
# Run finalize
result = layer.finalize(h_out, reduce_results=False)
# Should slice back to local: [3, 8]
self.assertEqual(result.shape, (3, 8))
# Test with reduce_results=True and TP/EP > 1
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape, (3, 8))

View File

@@ -22,15 +22,13 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is
adapt_patch(True)
@@ -58,122 +56,94 @@ def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock()
mock_hf_config.model_type = "llama"
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_moe_comm_method = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_moe_comm_method.prepare.side_effect = mock_prepare
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_fused_experts_result = torch.randn(16, 2)
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
def mock_finalize(hidden_states, **kwargs):
return hidden_states
captured_dispatchers = {}
mock_moe_comm_method.finalize.side_effect = mock_finalize
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
if vllm_version_is("0.10.2"):
dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
else:
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
mc2_mask=torch.zeros(
16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
'mock_moe_comm_method': mock_moe_comm_method,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
@@ -235,6 +205,8 @@ def default_moe_config():
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
moe.moe_parallel_config.use_ep = False
moe.moe_parallel_config.dp_size = 1
return AscendUnquantizedFusedMoEMethod(moe)
@@ -280,6 +252,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
expert_weights: torch.Tensor) -> torch.Tensor:
pass
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
class TestAscendFusedMoe:
@@ -339,9 +314,7 @@ class TestAscendFusedMoe:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
@@ -395,25 +368,10 @@ class TestAscendUnquantizedFusedMoEMethod:
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
@@ -439,35 +397,22 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts=global_num_experts,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
@@ -494,8 +439,10 @@ class TestAscendUnquantizedFusedMoEMethod:
expert_map=expert_map,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@@ -524,10 +471,47 @@ class TestExpertsSelector:
assert topk_ids.shape == (8, 2)
class TestCumsumGroupList(TestBase):
def setUp(self):
self.active_num = 8
self.expert_num = 128
self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
self.experts[:self.active_num] = 1
self.experts = self.experts[torch.randperm(self.expert_num)]
self.group_list = self.experts.cumsum(dim=0)
def test_cumsum_group_list_with_type_0(self):
group_list = self.experts.cumsum(dim=0)
group_list_type = 0
result = cumsum_group_list(group_list, group_list_type)
self.assertTrue(torch.equal(result, self.group_list))
def test_cumsum_group_list_with_type_1(self):
group_list = self.experts
group_list_type = 1
result = cumsum_group_list(group_list, group_list_type)
self.assertTrue(torch.equal(result, self.group_list))
def test_cumsum_group_list_with_type_2(self):
tokens = torch.arange(self.expert_num, dtype=torch.int64)
group_list = torch.cat([
tokens.reshape(self.expert_num, 1),
self.experts.reshape(self.expert_num, 1)
],
dim=1)
group_list_type = 2
result = cumsum_group_list(group_list,
group_list_type,
active_num=self.active_num,
expert_num=self.expert_num)
self.assertTrue(torch.equal(result, self.group_list))
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
@@ -538,7 +522,7 @@ class TestUnifiedApplyMLP(TestBase):
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_forward_context.moe_comm_type = MoECommType.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
@@ -582,8 +566,6 @@ class TestUnifiedApplyMLP(TestBase):
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()
@@ -593,7 +575,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -635,7 +617,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -695,7 +677,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -739,3 +721,68 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch_npu.npu_dynamic_quant")
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
mock_npu_swiglu, mock_npu_grouped_matmul,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
-128, 127, (10, 40),
dtype=torch.int8), torch.rand(
10, 1,
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 20, dtype=torch.bfloat16)
]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None,
with_quant=True,
fusion=True)
mock_get_forward_context.assert_called()
mock_npu_grouped_matmul.assert_called_once()
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)

View File

@@ -1,13 +1,18 @@
from unittest.mock import patch
import unittest
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.layernorm import RMSNorm
from tests.ut.base import PytestBase
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]
return residual
def mock_rms_norm(x, weight, eps):
@@ -18,36 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
return 2 * x, None, 2 * residual
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
residual, dummy_tensor):
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
epsilon):
x_out = 2 * x
residual_out = 2 * residual
x_out_quant = x_out.to(torch.int8)
residual_out_quant = residual_out.to(torch.int8)
return x_out_quant, None, residual_out_quant
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = RMSNorm(hidden_size=32, eps=1e-05)
class TestAscendRMSNorm(PytestBase):
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm",
side_effect=mock_add_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm_quant",
side_effect=mock_add_rms_norm_quant)
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
side_effect=lambda x: None)
# Test case for the most common and basic scenario
@pytest.mark.parametrize(
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
def test_forward_oot_basic(self, residual):
layer = RMSNorm(hidden_size=8, eps=1e-05)
x = torch.randn(4, 8, dtype=torch.float16)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
x_out, residual_out = layer.forward_oot(x, residual)
if is_310p_return:
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)
x_out_expected = 2 * x
residual_out_expected = 2 * residual
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_add_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
assert torch.allclose(x_out, x_out_expected)
assert torch.allclose(residual_out, residual_out_expected)
else:
out_x = layer.forward(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1
x_out = layer.forward(x, residual)
x_out_expected = x + 1
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(x_out, x_out_expected)
# Test case for flashcomm_v1 scenario
def test_forward_oot_with_flashcomm_v1(self):
layer = RMSNorm(hidden_size=512, eps=1e-05)
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
x_out, residual_out = layer.forward_oot(x, residual)
x_out_expected = 2 * x
residual_out_expected = 2 * residual[:4]
assert residual_out.size(0) == 4
assert torch.allclose(x_out, x_out_expected)
assert torch.allclose(residual_out, residual_out_expected)
# Test case for addrmsnorm + w8a8 quant fusion
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
mock_is_310p.return_value = False
mock_get_forward_context = mocker.patch(
"vllm_ascend.ops.layernorm.get_forward_context")
# Simulating a scenario with quant_fusion enabled
mock_forward_context = mocker.MagicMock()
mock_model_instance = mocker.MagicMock()
mock_forward_context.model_instance = mock_model_instance
mock_model_instance.model.layers = [
mocker.MagicMock() for _ in range(2)
]
mock_layer_0 = mock_model_instance.model.layers[0]
mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()
mock_layer_1 = mock_model_instance.model.layers[1]
mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()
mock_quant_method_0_qkv = mocker.MagicMock()
mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
mock_quant_method_0_gate_up = mocker.MagicMock()
mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up
mock_quant_method_1_qkv = mocker.MagicMock()
mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
mock_quant_method_1_gate_up = mocker.MagicMock()
mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up
mock_get_forward_context.return_value = mock_forward_context
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
mock_forward_context.prefetch_mlp_enabled = False
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = 2
mock_forward_context.fusion_linear = "gate_up_dense"
# Ensure fusion and layer_idx increment are handled correctly
x = torch.randn(4, 8, dtype=torch.float16)
residual = torch.randn(4, 8, dtype=torch.float16)
layer = RMSNorm(hidden_size=8, eps=1e-05)
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 1
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 1
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 2
assert mock_forward_context.fusion_linear == "gate_up_dense"
assert mock_forward_context.layer_idx == 1
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 3
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 4
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2
if __name__ == '__main__':
unittest.main()

View File

@@ -1,363 +1,96 @@
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear, LinearBase,
QuantizationConfig)
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
class TestAscendMlpRowParallelLinear(unittest.TestCase):
class BaseLinearTest(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
self.tensor_parallel_world_size = 2
self.tensor_parallel_rank = 0
self.mlp_tensor_parallel_world_size = 2
self.mlp_tensor_parallel_rank = 1
self.mock_group = mock.MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
self.get_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
return_value=self.tensor_parallel_world_size)
self.get_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
return_value=self.tensor_parallel_rank)
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
return_value=self.mlp_tensor_parallel_world_size)
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
return_value=self.mlp_tensor_parallel_rank)
parallel_state._MLP_TP = self.mock_group
parallel_state._OTP = self.mock_group
self.get_tensor_model_parallel_world_size_mock = \
self.get_tensor_model_parallel_world_size_patch.start()
self.get_tensor_model_parallel_rank_mock = \
self.get_tensor_model_parallel_rank_patch.start()
self.get_mlp_tensor_model_parallel_world_size_mock = \
self.get_mlp_tensor_model_parallel_world_size_patch.start()
self.get_mlp_tensor_model_parallel_rank_mock = \
self.get_mlp_tensor_model_parallel_rank_patch.start()
self.mock_ascend_config = MagicMock()
self.mock_ascend_config.oproj_tensor_parallel_size = 2
self.split_tensor_along_last_dim_patch = mock.patch(
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
self.tensor_model_parallel_all_reduce_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_reduce_mock = \
self.tensor_model_parallel_all_reduce_patch.start()
self.split_tensor_along_last_dim_mock = \
self.split_tensor_along_last_dim_patch.start()
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
mock.MagicMock()
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch("vllm_ascend.distributed.parallel_state.get_otp_group",
return_value=self.mock_group),
patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.linear_op.get_tp_group",
return_value=self.mock_group),
patch(
"vllm.distributed.parallel_state.get_tp_group",
return_value=self.mock_group,
),
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
]
for p in self.patches:
p.start()
def tearDown(self):
self.get_tensor_model_parallel_world_size_patch.stop()
self.get_tensor_model_parallel_rank_patch.stop()
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
self.get_mlp_tensor_model_parallel_rank_patch.stop()
self.split_tensor_along_last_dim_patch.stop()
self.tensor_model_parallel_all_reduce_patch.stop()
self.get_mlp_tp_group_patch.stop()
for p in self.patches:
p.stop()
def test_init_with_down_proj_prefix(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
prefix="down_proj")
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
self.assertTrue(layer.enable_mlp_optimze)
def test_forward_with_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
layer(input_tensor)
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
input_tensor = torch.randn(16, 8)
linear(input_tensor)
def test_forward_without_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
def test_oproj_tp(self):
ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="other",
input_is_parallel=False,
prefix="o_proj",
)
self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
linear(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
def test_skip_bias_add(self):
layer = AscendMlpRowParallelLinear(
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
def test_merged_mlp_tp_init(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
linear = AscendMergedColumnParallelLinear(
input_size=16,
output_size=8,
skip_bias_add=True,
output_sizes=[8, 8],
prefix="gate_up_proj",
)
input_tensor = torch.randn(16, 8)
output, bias = layer(input_tensor)
self.assertIsNotNone(bias)
def test_no_reduce_results(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
def test_input_not_parallel(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
input_is_parallel=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once()
def test_exception_when_reduce_false_and_bias(self):
with self.assertRaises(ValueError):
AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=True,
skip_bias_add=False)
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock distributed functions
self.mlp_tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
self.mlp_tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
self.tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
self.tp_size_mock = self.tp_size_patch.start()
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
self.tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
self.tp_rank_mock = self.tp_rank_patch.start()
self.tp_rank_mock.return_value = 1 # Current GPU rank
# Mock divide function (assumed to be in your module)
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
self.divide_mock = self.divide_patch.start()
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
# Mock QuantizationConfig and QuantMethod
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
# Mock LinearBase initialization
self.linear_base_init_patch = mock.patch.object(
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
self.linear_base_init_patch.start()
self.quant_method_mock = mock.MagicMock()
def mock_linear_base_init(self, instance, *args, **kwargs):
instance.quant_method = self.quant_method_mock
instance.params_dtype = mock.MagicMock()
instance.input_size = 16
instance.output_size = 8
instance.output_size_per_partition = 4
instance.params_dtype = torch.float32
def tearDown(self):
self.mlp_tp_size_patch.stop()
self.mlp_tp_rank_patch.stop()
self.tp_size_patch.stop()
self.tp_rank_patch.stop()
self.divide_patch.stop()
self.linear_base_init_patch.stop()
def test_mlp_optimize_initialization(self):
# Test when prefix contains "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.gate_up_proj",
bias=False,
)
# Verify MLP optimization flags
self.assertTrue(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 2)
self.assertEqual(layer.tp_rank, 0)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_regular_parallel_initialization(self):
# Test when prefix does NOT contain "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.q_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify regular TP flags
self.assertFalse(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 4)
self.assertEqual(layer.tp_rank, 1)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_output_sizes_handling(self):
# Test when output_sizes is provided
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
output_sizes=[4, 4],
prefix="model.layers.0.qkv_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify output_partition_sizes
self.assertEqual(layer.output_partition_sizes, [2])
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
self.mlp_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
self.tensor_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
self.mlp_world_size_patch.start()
self.tensor_world_size_patch.start()
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
self.mlp_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
self.tensor_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
self.mlp_rank_patch.start()
self.tensor_rank_patch.start()
# Mock all_gather methods
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
self.tensor_model_parallel_all_gather_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_gather_mock = \
self.tensor_model_parallel_all_gather_patch.start()
# Mock AscendMlpColumnParallelLinear's __init__
self.linear_init_patch = mock.patch.object(
AscendMlpColumnParallelLinear,
"__init__",
side_effect=self.mock_linear_init)
self.linear_init_patch.start()
# Create mock objects
self.quant_method_mock = mock.MagicMock()
self.apply_output = torch.randn(2, 8)
self.quant_method_mock.apply.return_value = self.apply_output
def mock_linear_init(self, instance, *args, **kwargs):
torch.nn.Module.__init__(instance)
# Set quant_method and other attributes
instance.quant_method = self.quant_method_mock
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
instance.input_size = 16
instance.output_size = 8
instance.gather_output = False
instance.skip_bias_add = False
instance.return_bias = True
def test_forward_with_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.assertEqual(output.shape, self.apply_output.shape)
def test_forward_without_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix not containing "gate_up_proj"
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.quant_method_mock.apply.assert_called_once_with(
layer, input_tensor, layer.bias)
self.tensor_model_parallel_all_gather_mock.assert_not_called()
self.assertEqual(output.shape, self.apply_output.shape)
def tearDown(self):
self.linear_init_patch.stop()
self.mlp_world_size_patch.stop()
self.tensor_world_size_patch.stop()
self.mlp_rank_patch.stop()
self.tensor_rank_patch.stop()
self.get_mlp_tp_group_mock.stop()
self.tensor_model_parallel_all_gather_mock.stop()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,232 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from tests.ut.base import TestBase
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl)
class TestMoECommMethod(TestBase):
def setUp(self):
# Mock FusedMoEConfig
self.moe_config = MagicMock(spec=FusedMoEConfig)
self.moe_config.num_experts = 8
self.moe_config.num_local_experts = 2
self.moe_config.experts_per_token = 2
self.moe_config.tp_group = MagicMock()
self.moe_config.tp_group.device_group = MagicMock()
self.moe_config.dp_size = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
mock_get_forward_context.return_value = mock_context
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_token_dispatcher.return_value = mock_td_instance
# Create instance
comm_impl = AllGatherCommImpl(self.moe_config)
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
mock_get_forward_context.return_value = mock_context
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2),
torch.tensor([1, 0, 1, 0]))
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_token_dispatcher.return_value = mock_td_instance
# Create instance
comm_impl = MC2CommImpl(self.moe_config)
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
mock_get_forward_context.return_value = mock_context
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_token_dispatcher.return_value = mock_td_instance
# Create instance
comm_impl = AlltoAllCommImpl(self.moe_config)
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
mock_get_forward_context.return_value = mock_context
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = {
"hidden_states": torch.randn(6, 8),
"group_list": torch.tensor([2, 2, 2]),
"group_list_type": 1
}
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_token_dispatcher.return_value = mock_td_instance
# Mock unified_apply_mlp
mock_unified_apply_mlp.return_value = torch.randn(6, 8)
# Create instance
comm_impl = AllGatherCommImpl(self.moe_config)
# Test fused_experts method
hidden_states = torch.randn(4, 8).contiguous()
w1 = torch.randn(16, 8).contiguous()
w2 = torch.randn(16, 8).contiguous()
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
[0.6, 0.4]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
row_idx = torch.arange(4)
# Make sure tensors are contiguous and have correct strides
hidden_states = hidden_states.contiguous()
w1 = w1.contiguous()
w2 = w2.contiguous()
result = comm_impl.fused_experts(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
activation="silu")
# Verify result shape
self.assertEqual(result.shape, (4, 8))
# Verify token_dispatch was called
mock_td_instance.token_dispatch.assert_called_once()
# Verify unified_apply_mlp was called
mock_unified_apply_mlp.assert_called_once()
# Verify token_combine was called
mock_td_instance.token_combine.assert_called_once()

View File

@@ -3,12 +3,18 @@ import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from transformers.configuration_utils import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
MODEL = "Qwen3-0.6B"
MAX_NUM_BATCHED_TOKEND = 10000
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -88,11 +94,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = self.is_neox_style
@patch('torch.ops._C')
@patch('torch.ops._C_ascend')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock__c):
@@ -102,9 +112,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
@@ -113,6 +129,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
@@ -121,15 +141,25 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
@@ -137,26 +167,78 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
self.layer.forward(self.positions, self.query, self.key,
offsets)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
result_q, result_k = self.layer.forward(self.positions,
self.query,
self.key,
is_neox_style_override=False)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
# restore rotary_dim
self.layer.rotary_dim = org_rotary_dim
class MockRopeModule:
@@ -207,28 +289,6 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_cache_handling(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
self.layer.max_seq_len = 1024
# Test cache situation is true
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions,
self.query,
self.key,
max_seq_len=2048)
mock_set_cache.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))

View File

@@ -20,10 +20,10 @@ from unittest.mock import MagicMock, PropertyMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
class TestTokenDispatcherWithMC2(TestBase):
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
# Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
@@ -98,7 +98,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.row_idx, expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
1) # group_list_type == 1
0) # group_list_type == 0
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
# Mock NPU functions
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
self.mock_moe_init_routing.return_value = (
self.patcher_npu_moe_init_routing_v2 = patch(
'torch_npu.npu_moe_init_routing_v2')
self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
)
self.mock_npu_moe_init_routing_v2.return_value = (
torch.randn(6, 128), # sorted_hidden_states
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
)
self.patcher_moe_compute_expert_tokens = patch(
'torch_npu.npu_moe_compute_expert_tokens')
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
)
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
[3, 3]) # expert_tokens
self.patcher_moe_finalize_routing = patch(
'torch_npu.npu_moe_finalize_routing')
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
torch.tensor([0, 1, 0, 1, 0, 1]))
self.row_idx = torch.arange(10, dtype=torch.int32)
self.patcher_npu_moe_token_unpermute = patch(
'torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
)
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
def tearDown(self):
self.patcher_moe_init_routing.stop()
self.patcher_moe_compute_expert_tokens.stop()
self.patcher_moe_finalize_routing.stop()
self.patcher_npu_moe_init_routing_v2.stop()
self.patcher_npu_moe_token_unpermute.stop()
def test_token_dispatch_without_expert_map(self):
hidden_states = torch.randn(3, 128)
@@ -207,12 +200,27 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 0)
self.assertEqual(results["group_list_type"], 1)
def test_token_dispatch_with_quant(self):
def test_token_dispatch_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 1)
def test_token_dispatch_without_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
@@ -230,7 +238,33 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights, topk_ids,
self.row_idx, None)
self.assertEqual(results["group_list_type"], 0)
self.assertEqual(results["group_list_type"], 1)
def test_token_dispatch_with_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights,
topk_ids,
self.row_idx,
None,
with_quant=True)
self.assertIsNotNone(results["hidden_states"])
self.assertIsNotNone(results["group_list"])
self.assertIsNotNone(results["dynamic_scale"])
self.assertEqual(results["group_list_type"], 1)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -242,9 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify index_add_ is applied correctly
self.assertEqual(final_hidden_states.shape, (3, 128))
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
@@ -260,10 +292,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify npu_moe_finalize_routing is called
self.mock_moe_finalize_routing.assert_called_once()
args, kwargs = self.mock_moe_finalize_routing.call_args
self.mock_npu_moe_token_unpermute.assert_called_once()
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
self.assertEqual(final_hidden_states.shape, (3, 128))
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_dispatch_with_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
@@ -315,7 +347,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
@@ -323,7 +355,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)
@@ -488,119 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
@@ -31,6 +32,9 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
mock_vllm_config = MagicMock()
mock_vllm_config.additional_config = {}
init_ascend_config(mock_vllm_config)
def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
@@ -206,7 +210,15 @@ class TestAscendLogitsProcessor(unittest.TestCase):
return_value=True),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
return_value=torch.randn(1, self.vocab_size))
return_value=torch.randn(1, self.vocab_size)),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
return_value=torch.randn(1, self.vocab_size)),
patch(
"vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
return_value=MagicMock(max_num_batched_tokens=1000,
max_model_len=512,
enable_chunked_prefill=False))
]
for p in self.patches: