xc-llm-ascend/tests/ut/ops/test_fused_moe.py
LI SHENGYONG f81cf694b2 [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (#5311)
### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when redundant experts are enabled is incorrect.
The community map-generation function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert IDs on the last card index an expert ID that
does not exist. We now ensure that every index points to a real,
existing expert ID, so each expert can be reached (see the sketch after
this list). Moreover, when redundant experts are not enabled, the
output of our function remains consistent with the community function.
2. The map we generated was sized by the number of physical experts,
but in reality only the number of logical experts is needed. Since the
map is padded later anyway, we can simply generate it with the length
of the logical experts.
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.
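
To make points 1 and 2 concrete, below is a minimal sketch (not the actual
vllm-ascend implementation; the function name and the wrap-around placement
policy are illustrative assumptions) of a per-rank map that is sized by the
number of logical experts and whose redundant slots always point at real
expert IDs:

```python
import torch


def build_rank_expert_map(rank: int, ep_size: int,
                          num_logical_experts: int,
                          num_redundant_experts: int) -> torch.Tensor:
    """Map global (logical) expert id -> local expert id for one rank.

    -1 means the expert is not hosted on this rank. Illustrative sketch only.
    """
    # Each rank hosts an equal share of logical + redundant expert slots.
    local_num = (num_logical_experts + num_redundant_experts) // ep_size
    expert_map = torch.full((num_logical_experts, ), -1, dtype=torch.int32)
    start = rank * local_num
    for local_id in range(local_num):
        # Wrap around so slots on the last ranks point at real logical
        # experts instead of indexing past num_logical_experts.
        global_id = (start + local_id) % num_logical_experts
        expert_map[global_id] = local_id
    return expert_map


# ep_size=16, 256 logical experts, 16 redundant -> 17 local slots per rank;
# the map length stays 256 (logical experts), not 272 (physical slots).
print(build_rank_expert_map(rank=15, ep_size=16,
                            num_logical_experts=256,
                            num_redundant_experts=16))
```

With `num_redundant_experts=0` this degenerates to the usual contiguous
partition, which is the consistency property claimed in point 1.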

**Before refactoring**

- map path is not None:

  expert map: `get_rank_placement_map` from `expert_load_balancer.py`;
  maintains the map for all ranks and all layers.

  log2phy: `get_rank_log2phy_map` from `expert_load_balancer.py`;
  maintains the map for all ranks and all layers.

- map path is None:

  expert map: `determine_expert_map` from `vllm.layer`. The function does
  not support the redundant experts of vllm-ascend.

  log2phy: `determine_default_log2phy_map` from `eplb_utils.py`. The
  function does not support the redundant experts of vllm-ascend.

**After refactoring**

`eplb_utils.py`
- `init_eplb_config`
  - generate placement
  - generate expert map
  - generate log2phy
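
A rough sketch of how the unified entry point could chain these steps (the
name `init_eplb_config` comes from the outline above, but the signature,
the placement policy, and the log2phy convention below are assumptions,
not the real implementation; the per-rank expert-map step is the one
sketched earlier):

```python
import torch


def init_eplb_config_sketch(num_logical_experts: int,
                            num_redundant_experts: int):
    """Illustrative flow: placement -> log2phy (expert map sketched above)."""
    num_physical = num_logical_experts + num_redundant_experts
    # Placement: physical slot -> logical expert id. Redundant slots simply
    # replicate the first logical experts (a trivial policy, for illustration).
    placement = torch.arange(num_physical,
                             dtype=torch.int64) % num_logical_experts
    # log2phy: logical expert id -> one physical slot hosting that expert.
    log2phy = torch.full((num_logical_experts, ), -1, dtype=torch.int64)
    for phys_slot, logical_id in enumerate(placement.tolist()):
        if log2phy[logical_id] == -1:  # keep the first replica found
            log2phy[logical_id] = phys_slot
    return placement, log2phy


placement, log2phy = init_eplb_config_sketch(256, 16)
assert (log2phy >= 0).all()  # every logical expert has a physical home
```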

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Expert mapping generation test:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert mapping for rank 15 (entries other than -1 indicate the experts
this rank is responsible for):
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Expert mapping generation test:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert mapping for rank 15 (entries other than -1 indicate the experts
this rank is responsible for):
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

dsr1 baseline:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |


- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-29 09:26:14 +08:00


#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
unified_apply_mlp)
from vllm_ascend.utils import AscendDeviceType, adapt_patch
adapt_patch(True)
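
# Helpers below build minimal mocked communication groups (EP/MC2 and DP/TP)
# so the fused-MoE layers can be exercised without a real distributed setup.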
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock()
mock_hf_config.model_type = "llama"
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
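
# Patches distributed groups, the ascend/EPLB config and the forward context
# so the fused-MoE code paths under test run against fully mocked state.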
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method = MagicMock()
def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits
mock_moe_comm_method.prepare.side_effect = mock_prepare
mock_fused_experts_result = torch.randn(16, 2)
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
def mock_finalize(hidden_states, **kwargs):
return hidden_states
mock_moe_comm_method.finalize.side_effect = mock_finalize
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
mock_weight_prefetch_method = MagicMock()
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
mc2_mask=torch.zeros(
16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock(
enable_multistream_moe=False,
expert_map_path=None
)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.init_eplb_config',
return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
return_value=mock_weight_prefetch_method):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_moe_comm_method': mock_moe_comm_method,
}
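
# Stubs the torch_npu MoE kernels (gating, routing, dispatch/combine, grouped
# matmul, swiglu) with random tensors of the shapes the callers expect.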
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
moe.moe_parallel_config.use_ep = False
moe.moe_parallel_config.dp_size = 1
return AscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
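
# Minimal quant-method stand-in: apply() returns (routed output, shared-expert
# output) when shared experts are present, otherwise only the routed output.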
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [256, 128])
def test_select_experts(self, mock_dist_env, mock_moe_env,
global_num_experts):
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
use_grouped_topk=False,
renormalize=True,
topk_group=None,
num_expert_group=None,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
global_num_experts=global_num_experts)
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
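
# cumsum_group_list supports three encodings: type 0 is already cumulative,
# type 1 holds per-expert token counts, and type 2 holds (index, count) pairs;
# all of them should normalize to the same cumulative group list.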
class TestCumsumGroupList(TestBase):
def setUp(self):
self.active_num = 8
self.expert_num = 128
self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
self.experts[:self.active_num] = 1
self.experts = self.experts[torch.randperm(self.expert_num)]
self.group_list = self.experts.cumsum(dim=0)
def test_cumsum_group_list_with_type_0(self):
group_list = self.experts.cumsum(dim=0)
group_list_type = 0
result = cumsum_group_list(group_list, group_list_type, 0)
self.assertTrue(torch.equal(result, self.group_list))
def test_cumsum_group_list_with_type_1(self):
group_list = self.experts
group_list_type = 1
result = cumsum_group_list(group_list, group_list_type, 0)
self.assertTrue(torch.equal(result, self.group_list))
def test_cumsum_group_list_with_type_2(self):
tokens = torch.arange(self.expert_num, dtype=torch.int64)
group_list = torch.cat([
tokens.reshape(self.expert_num, 1),
self.experts.reshape(self.expert_num, 1)
],
dim=1)
group_list_type = 2
result = cumsum_group_list(group_list,
group_list_type,
0,
active_num=self.active_num,
expert_num=self.expert_num)
self.assertTrue(torch.equal(result, self.group_list))
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_soc_version,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.moe_comm_type = MoECommType.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 20),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_dequant.return_value = (torch.randn(10,
40,
dtype=torch.bfloat16),
torch.randn(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
w1_scale = torch.randn(5, 40, dtype=torch.float32)
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
mock_npu_dynamic_quant.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_dequant.assert_called_once()
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(self,
mock_npu_dynamic_quant,
mock_npu_swiglu,
mock_npu_grouped_matmul,
mock_soc_version):
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.float16)
], [torch.randn(10, 20, dtype=torch.float16)]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.bfloat16)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
hidden_states_shape = hidden_states.shape
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
mock_npu_dynamic_quant.assert_called_once()
self.assertEqual(result.shape, hidden_states_shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._310P)
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_soc_version):
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
[mock_gmm2_out]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch_npu.npu_dynamic_quant")
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
mock_npu_swiglu, mock_npu_grouped_matmul,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
-128, 127, (10, 40),
dtype=torch.int8), torch.rand(
10, 1,
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 20, dtype=torch.bfloat16)
]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
hidden_states_shape = hidden_states.shape
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None,
with_quant=True,
fusion=True)
mock_get_forward_context.assert_called()
mock_npu_grouped_matmul.assert_called_once()
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states_shape)
self.assertEqual(result.dtype, torch.bfloat16)