### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when enabling the redundancy expert is incorrect.
The community generation map function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert ID of the last card will index to an expert ID
that does not exist. Now we ensure that the index points to a real
existing expert ID, and each expert can be accessed. Moreover, when
redundant experts are not enabled, the output of our function remains
consistent with the community's function.
2. The map we generate is based on the length of the physical expert,
but in reality, we only need to use the length of the logical expert.
Later on, we will need to pad it accordingly, so we can simply generate
a map with the length of the logical [expert.]
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.
**Before refactoring**
- map path is not None:
expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.
log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.
- map path is None:
expert map: determine_expert_map from '_vllm.laye_r', The function does
not support the redundant experts of vllm-ascend.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.
**Refactoring**
eplb_utils.py
init_eplb_config
generate placement
generate expert map
generate log2phy
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8
9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
dsr1 baselie:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |
dsr1 eplb:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
594 lines
25 KiB
Python
594 lines
25 KiB
Python
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
from typing import List, TypedDict
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch_npu
|
|
from pytest_mock import MockerFixture
|
|
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.ascend_forward_context import MoECommType
|
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
|
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
|
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
|
|
unified_apply_mlp)
|
|
from vllm_ascend.utils import AscendDeviceType, adapt_patch
|
|
|
|
adapt_patch(True)
|
|
|
|
|
|
def mock_ep_and_mc2_group(mocker):
|
|
mock_group = mocker.MagicMock()
|
|
mock_group.rank_in_group = 0
|
|
mock_group.rank = 0
|
|
mock_group.world_size = 4
|
|
mock_group.device_group = "mock_group_ep"
|
|
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
|
|
return mock_group
|
|
|
|
|
|
def mock_dp_and_tp_group(mocker):
|
|
mock_group = mocker.MagicMock()
|
|
mock_group.rank_in_group = 0
|
|
mock_group.world_size = 2
|
|
mock_group.device_group = "mock_group"
|
|
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
|
|
return mock_group
|
|
|
|
|
|
def mock_npu_format_cast(weight_data, format):
|
|
return weight_data
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def setup_vllm_config_mock(mocker: MockerFixture):
|
|
mock_hf_config = MagicMock()
|
|
mock_hf_config.model_type = "llama"
|
|
|
|
mock_model_config = MagicMock()
|
|
mock_model_config.hf_config = mock_hf_config
|
|
|
|
mock_vllm_config = MagicMock()
|
|
mock_vllm_config.model_config = mock_model_config
|
|
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
|
|
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
|
mock_vllm_config.model_config.max_model_len = 2048
|
|
|
|
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
|
|
return_value=mock_vllm_config)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_dist_env(mocker: MockerFixture):
|
|
mock_moe_comm_method = MagicMock()
|
|
|
|
def mock_prepare(hidden_states, router_logits, **kwargs):
|
|
return hidden_states, router_logits
|
|
|
|
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
|
|
|
mock_fused_experts_result = torch.randn(16, 2)
|
|
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
|
|
|
|
def mock_finalize(hidden_states, **kwargs):
|
|
return hidden_states
|
|
|
|
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
|
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
|
mock_weight_prefetch_method = MagicMock()
|
|
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
|
moe_comm_type=MoECommType.MC2,
|
|
max_tokens_across_dp=10,
|
|
dp_metadata=dp_metadata,
|
|
mc2_mask=torch.zeros(
|
|
16, dtype=torch.bool),
|
|
padded_num_tokens=16,
|
|
with_quant=False)
|
|
|
|
with patch('torch.distributed.get_rank', return_value=0), \
|
|
patch('torch.distributed.get_world_size', return_value=4), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
|
patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
|
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
|
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
|
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
|
return_value=mock_dp_and_tp_group(mocker)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
|
|
return_value=MagicMock(
|
|
enable_multistream_moe=False,
|
|
expert_map_path=None
|
|
)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.init_eplb_config',
|
|
return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
|
|
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
|
|
return_value=mock_forward_context_obj), \
|
|
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
|
|
return_value=mock_forward_context_obj), \
|
|
patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
|
|
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
|
|
return_value=mock_forward_context_obj), \
|
|
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
|
return_value=None), \
|
|
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
|
return_value=None), \
|
|
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
|
return_value=None), \
|
|
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
|
return_value=mock_weight_prefetch_method):
|
|
|
|
yield {
|
|
'mock_forward_context_obj': mock_forward_context_obj,
|
|
'mock_moe_comm_method': mock_moe_comm_method,
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_moe_env(mocker: MockerFixture):
|
|
|
|
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
|
|
torch.randn(8, 2),
|
|
torch.randint(0, 8, (8, 2)),
|
|
None
|
|
)), \
|
|
patch('torch_npu.npu_moe_init_routing', return_value=(
|
|
torch.randn(8, 2),
|
|
torch.randint(0, 8, (8, 2)),
|
|
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
|
|
)), \
|
|
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
|
|
torch.randn(8, 2)
|
|
)), \
|
|
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
|
|
torch.randn(16, 2)
|
|
)), \
|
|
patch("torch_npu.npu_moe_distribute_combine", return_value=(
|
|
torch.randn(16, 2)
|
|
)), \
|
|
patch("torch_npu.npu_grouped_matmul", return_value=(
|
|
[torch.randn(16, 2)]
|
|
)), \
|
|
patch("torch_npu.npu_swiglu", return_value=(
|
|
torch.randn(16, 2)
|
|
)), \
|
|
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
|
|
torch.randn(8, 2),
|
|
torch.randint(0, 8, (8, 2)),
|
|
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
|
|
)), \
|
|
patch("torch_npu.npu_moe_finalize_routing", return_value=(
|
|
torch.randn(16, 2)
|
|
)):
|
|
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
|
|
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
|
|
torch.randn(16, 2))), \
|
|
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
|
|
torch.randn(16, 2))):
|
|
yield
|
|
else:
|
|
yield
|
|
|
|
|
|
@pytest.fixture
|
|
def default_moe_config():
|
|
return {
|
|
'num_experts': 8,
|
|
'top_k': 2,
|
|
'hidden_size': 512,
|
|
'intermediate_size': 1024
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def moe_method(mock_dist_env):
|
|
moe = MagicMock()
|
|
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
|
|
moe.moe_parallel_config.use_ep = False
|
|
moe.moe_parallel_config.dp_size = 1
|
|
return AscendUnquantizedFusedMoEMethod(moe)
|
|
|
|
|
|
class Device(TypedDict):
|
|
device_id: int
|
|
device_expert: List[int]
|
|
|
|
|
|
class Layer(TypedDict):
|
|
layer_id: int
|
|
device_count: int
|
|
device_list: List[Device]
|
|
|
|
|
|
class MockData(TypedDict):
|
|
moe_layer_count: int
|
|
layer_list: List[Layer]
|
|
|
|
|
|
class MockQuantMethod(nn.Module):
|
|
|
|
def __init__(self, shared_experts, num_tokens):
|
|
super().__init__()
|
|
if shared_experts:
|
|
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
|
|
torch.randn(num_tokens, 10)))
|
|
else:
|
|
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
|
|
|
|
|
|
class MockFusedMoEMethod(FusedMoEMethodBase):
|
|
moe = MagicMock()
|
|
|
|
def __init__(self):
|
|
super().__init__(self.moe)
|
|
|
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
|
hidden_size: int, intermediate_size_per_partition: int,
|
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
|
pass
|
|
|
|
def apply(self, hidden_states: torch.Tensor,
|
|
expert_weights: torch.Tensor) -> torch.Tensor:
|
|
pass
|
|
|
|
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
|
pass
|
|
|
|
|
|
class TestExpertsSelector:
|
|
|
|
@pytest.mark.parametrize("global_num_experts", [256, 128])
|
|
def test_select_experts(self, mock_dist_env, mock_moe_env,
|
|
global_num_experts):
|
|
|
|
x = torch.randn(8, 2)
|
|
router_logits = torch.randn(8, 2)
|
|
topk_weights, topk_ids = select_experts(
|
|
hidden_states=x,
|
|
router_logits=router_logits,
|
|
top_k=2,
|
|
use_grouped_topk=False,
|
|
renormalize=True,
|
|
topk_group=None,
|
|
num_expert_group=None,
|
|
custom_routing_function=None,
|
|
scoring_func="softmax",
|
|
e_score_correction_bias=None,
|
|
global_num_experts=global_num_experts)
|
|
|
|
assert topk_weights.shape == (8, 2)
|
|
assert topk_ids.shape == (8, 2)
|
|
|
|
|
|
class TestCumsumGroupList(TestBase):
|
|
|
|
def setUp(self):
|
|
self.active_num = 8
|
|
self.expert_num = 128
|
|
self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
|
|
self.experts[:self.active_num] = 1
|
|
self.experts = self.experts[torch.randperm(self.expert_num)]
|
|
self.group_list = self.experts.cumsum(dim=0)
|
|
|
|
def test_cumsum_group_list_with_type_0(self):
|
|
group_list = self.experts.cumsum(dim=0)
|
|
group_list_type = 0
|
|
result = cumsum_group_list(group_list, group_list_type, 0)
|
|
self.assertTrue(torch.equal(result, self.group_list))
|
|
|
|
def test_cumsum_group_list_with_type_1(self):
|
|
group_list = self.experts
|
|
group_list_type = 1
|
|
result = cumsum_group_list(group_list, group_list_type, 0)
|
|
self.assertTrue(torch.equal(result, self.group_list))
|
|
|
|
def test_cumsum_group_list_with_type_2(self):
|
|
tokens = torch.arange(self.expert_num, dtype=torch.int64)
|
|
group_list = torch.cat([
|
|
tokens.reshape(self.expert_num, 1),
|
|
self.experts.reshape(self.expert_num, 1)
|
|
],
|
|
dim=1)
|
|
group_list_type = 2
|
|
result = cumsum_group_list(group_list,
|
|
group_list_type,
|
|
0,
|
|
active_num=self.active_num,
|
|
expert_num=self.expert_num)
|
|
self.assertTrue(torch.equal(result, self.group_list))
|
|
|
|
|
|
class TestUnifiedApplyMLP(TestBase):
|
|
|
|
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
|
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
return_value=AscendDeviceType.A3)
|
|
@patch('torch_npu.npu_grouped_matmul')
|
|
@patch('torch_npu.npu_dynamic_quant')
|
|
@patch('torch_npu.npu_dequant_swiglu_quant')
|
|
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
|
|
mock_npu_dynamic_quant,
|
|
mock_npu_grouped_matmul,
|
|
mock_soc_version,
|
|
mock_get_forward_context):
|
|
|
|
mock_forward_context = MagicMock()
|
|
mock_forward_context.moe_comm_type = MoECommType.MC2
|
|
mock_get_forward_context.return_value = mock_forward_context
|
|
|
|
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
|
127, (10, 20),
|
|
dtype=torch.int8),
|
|
torch.rand(10,
|
|
1,
|
|
dtype=torch.float32))
|
|
|
|
mock_npu_grouped_matmul.side_effect = [[
|
|
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
|
|
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
|
|
|
|
mock_npu_dequant.return_value = (torch.randn(10,
|
|
40,
|
|
dtype=torch.bfloat16),
|
|
torch.randn(10,
|
|
1,
|
|
dtype=torch.float32))
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
|
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
|
|
w1_scale = torch.randn(5, 40, dtype=torch.float32)
|
|
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
|
|
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=w1_scale,
|
|
w2=w2,
|
|
w2_scale=w2_scale,
|
|
group_list=group_list,
|
|
dynamic_scale=None,
|
|
group_list_type=1,
|
|
w1_scale_bias=None,
|
|
w2_scale_bias=None,
|
|
topk_scales=None,
|
|
with_quant=True)
|
|
|
|
mock_get_forward_context.assert_called()
|
|
|
|
mock_npu_dynamic_quant.assert_called()
|
|
|
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
|
|
|
mock_npu_dequant.assert_called_once()
|
|
|
|
self.assertEqual(result.dtype, torch.bfloat16)
|
|
|
|
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
return_value=AscendDeviceType.A3)
|
|
@patch('torch_npu.npu_grouped_matmul')
|
|
@patch('torch_npu.npu_swiglu')
|
|
@patch('torch_npu.npu_dynamic_quant')
|
|
def test_unified_apply_mlp_without_quantization(self,
|
|
mock_npu_dynamic_quant,
|
|
mock_npu_swiglu,
|
|
mock_npu_grouped_matmul,
|
|
mock_soc_version):
|
|
mock_npu_grouped_matmul.side_effect = [[
|
|
torch.randn(10, 40, dtype=torch.float16)
|
|
], [torch.randn(10, 20, dtype=torch.float16)]]
|
|
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
|
|
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.float16)
|
|
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
|
|
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=None,
|
|
w2=w2,
|
|
w2_scale=None,
|
|
group_list=group_list,
|
|
dynamic_scale=None,
|
|
group_list_type=1,
|
|
w1_scale_bias=None,
|
|
w2_scale_bias=None,
|
|
topk_scales=topk_scales,
|
|
with_quant=False)
|
|
|
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
|
mock_npu_swiglu.assert_called_once()
|
|
|
|
self.assertEqual(result.shape, hidden_states.shape)
|
|
self.assertEqual(result.dtype, torch.float16)
|
|
|
|
@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
|
|
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
|
@patch('torch_npu.npu_grouped_matmul')
|
|
@patch('torch_npu.npu_swiglu')
|
|
@patch('torch_npu.npu_dynamic_quant')
|
|
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
|
|
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
|
mock_npu_grouped_matmul, mock_get_forward_context):
|
|
|
|
mock_forward_context = MagicMock()
|
|
mock_forward_context.with_quant = True
|
|
mock_forward_context.fused_moe_state = "NOT_MC2"
|
|
mock_get_forward_context.return_value = mock_forward_context
|
|
|
|
mock_npu_grouped_matmul.side_effect = [[
|
|
torch.randn(10, 40, dtype=torch.bfloat16)
|
|
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
|
|
|
|
mock_npu_swiglu.return_value = torch.randn(10,
|
|
40,
|
|
dtype=torch.bfloat16)
|
|
|
|
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
|
127, (10, 40),
|
|
dtype=torch.int8),
|
|
torch.rand(10,
|
|
1,
|
|
dtype=torch.float32))
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
|
hidden_states_shape = hidden_states.shape
|
|
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
|
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
|
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=w1_scale,
|
|
w2=w2,
|
|
w2_scale=w2_scale,
|
|
group_list=group_list,
|
|
dynamic_scale=provided_dynamic_scale,
|
|
group_list_type=1,
|
|
w1_scale_bias=w1_scale_bias,
|
|
w2_scale_bias=w2_scale_bias,
|
|
topk_scales=None,
|
|
with_quant=True)
|
|
|
|
mock_get_forward_context.assert_called()
|
|
|
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
|
mock_npu_swiglu.assert_called_once()
|
|
mock_npu_dynamic_quant.assert_called_once()
|
|
|
|
self.assertEqual(result.shape, hidden_states_shape)
|
|
self.assertEqual(result.dtype, torch.bfloat16)
|
|
|
|
@patch('vllm_ascend.utils.get_ascend_device_type',
|
|
return_value=AscendDeviceType._310P)
|
|
@patch('torch_npu.npu_grouped_matmul')
|
|
@patch('torch_npu.npu_swiglu')
|
|
@patch('torch_npu.npu_dynamic_quant')
|
|
def test_unified_apply_mlp_without_quantization_310p(
|
|
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
|
mock_npu_grouped_matmul, mock_soc_version):
|
|
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
|
|
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
|
|
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
|
|
[mock_gmm2_out]]
|
|
|
|
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
|
|
|
|
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.float16)
|
|
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
|
|
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=None,
|
|
w2=w2,
|
|
w2_scale=None,
|
|
group_list=group_list,
|
|
dynamic_scale=None,
|
|
group_list_type=1,
|
|
w1_scale_bias=None,
|
|
w2_scale_bias=None,
|
|
topk_scales=topk_scales,
|
|
with_quant=False)
|
|
|
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
|
mock_npu_swiglu.assert_called_once()
|
|
|
|
self.assertEqual(result.shape, hidden_states.shape)
|
|
self.assertEqual(result.dtype, torch.float16)
|
|
|
|
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
|
|
@patch("torch_npu.npu_grouped_matmul")
|
|
@patch("torch_npu.npu_swiglu")
|
|
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
|
@patch("torch_npu.npu_dynamic_quant")
|
|
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
|
|
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
|
|
mock_npu_swiglu, mock_npu_grouped_matmul,
|
|
mock_get_forward_context):
|
|
|
|
mock_forward_context = MagicMock()
|
|
mock_forward_context.with_quant = True
|
|
mock_forward_context.fused_moe_state = "NOT_MC2"
|
|
mock_get_forward_context.return_value = mock_forward_context
|
|
|
|
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
|
|
-128, 127, (10, 40),
|
|
dtype=torch.int8), torch.rand(
|
|
10, 1,
|
|
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
|
|
mock_npu_grouped_matmul.side_effect = [[
|
|
torch.randn(10, 20, dtype=torch.bfloat16)
|
|
]]
|
|
mock_npu_swiglu.return_value = torch.randn(10,
|
|
40,
|
|
dtype=torch.bfloat16)
|
|
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
|
127, (10, 40),
|
|
dtype=torch.int8),
|
|
torch.rand(10,
|
|
1,
|
|
dtype=torch.float32))
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
|
hidden_states_shape = hidden_states.shape
|
|
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
|
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
|
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=w1_scale,
|
|
w2=w2,
|
|
w2_scale=w2_scale,
|
|
group_list=group_list,
|
|
dynamic_scale=provided_dynamic_scale,
|
|
group_list_type=1,
|
|
w1_scale_bias=w1_scale_bias,
|
|
w2_scale_bias=w2_scale_bias,
|
|
topk_scales=None,
|
|
with_quant=True,
|
|
fusion=True)
|
|
|
|
mock_get_forward_context.assert_called()
|
|
mock_npu_grouped_matmul.assert_called_once()
|
|
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
|
|
|
|
self.assertTrue(mock_forward_context.with_quant)
|
|
self.assertEqual(result.shape, hidden_states_shape)
|
|
self.assertEqual(result.dtype, torch.bfloat16)
|