init v0.11.0rc0
This commit is contained in:
@@ -22,15 +22,13 @@ import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
@@ -58,122 +56,94 @@ def mock_npu_format_cast(weight_data, format):
|
||||
return weight_data
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.model_type = "llama"
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
mock_setup_token_dispatchers = MagicMock()
|
||||
mock_token_dispatcher_with_allgather = MagicMock()
|
||||
mock_token_dispatcher_with_all2allv = MagicMock()
|
||||
mock_token_dispatcher_with_mc2 = MagicMock()
|
||||
mock_moe_comm_method = MagicMock()
|
||||
|
||||
mock_dispatch_result_allgather = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([8, 16], dtype=torch.int64),
|
||||
"group_list_type": 0,
|
||||
}
|
||||
mock_combine_result_allgather = torch.randn(16, 2)
|
||||
def mock_prepare(hidden_states, router_logits, **kwargs):
|
||||
return hidden_states, router_logits
|
||||
|
||||
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
|
||||
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
|
||||
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
||||
|
||||
mock_dispatch_result_all2allv = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
}
|
||||
mock_combine_result_all2allv = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
|
||||
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
|
||||
mock_fused_experts_result = torch.randn(16, 2)
|
||||
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
|
||||
|
||||
mock_dispatch_result_mc2 = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
"assist_info_for_combine": torch.randn(16, 2),
|
||||
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
|
||||
}
|
||||
mock_combine_result_mc2 = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
|
||||
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
|
||||
def mock_finalize(hidden_states, **kwargs):
|
||||
return hidden_states
|
||||
|
||||
captured_dispatchers = {}
|
||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||
|
||||
def capture_register(dispatcher_instance):
|
||||
key = dispatcher_instance.__class__.__name__
|
||||
captured_dispatchers[key] = dispatcher_instance
|
||||
if key == 'TokenDispatcherWithAllGather':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
|
||||
elif key == 'TokenDispatcherWithAll2AllV':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
|
||||
elif key == 'TokenDispatcherWithMC2':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
|
||||
|
||||
mock_register_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
|
||||
side_effect=capture_register)
|
||||
|
||||
mock_get_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
|
||||
side_effect=lambda name: captured_dispatchers.get(name))
|
||||
|
||||
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
|
||||
|
||||
mock_forward_context_obj = MagicMock(
|
||||
fused_moe_state=FusedMoEState.AllGather,
|
||||
token_dispatcher=default_mock_token_dispatcher,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
if vllm_version_is("0.10.2"):
|
||||
dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
||||
else:
|
||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(
|
||||
16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
|
||||
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
'mock_token_dispatcher_with_allgather':
|
||||
mock_token_dispatcher_with_allgather,
|
||||
'mock_token_dispatcher_with_all2allv':
|
||||
mock_token_dispatcher_with_all2allv,
|
||||
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
|
||||
'mock_moe_comm_method': mock_moe_comm_method,
|
||||
}
|
||||
|
||||
mock_register_token_dispatcher_patcher.stop()
|
||||
mock_get_token_dispatcher_patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_moe_env(mocker: MockerFixture):
|
||||
@@ -235,6 +205,8 @@ def default_moe_config():
|
||||
def moe_method(mock_dist_env):
|
||||
moe = MagicMock()
|
||||
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
|
||||
moe.moe_parallel_config.use_ep = False
|
||||
moe.moe_parallel_config.dp_size = 1
|
||||
return AscendUnquantizedFusedMoEMethod(moe)
|
||||
|
||||
|
||||
@@ -280,6 +252,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
expert_weights: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class TestAscendFusedMoe:
|
||||
|
||||
@@ -339,9 +314,7 @@ class TestAscendFusedMoe:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||
dtype=torch.bool),
|
||||
padded_num_tokens=num_tokens)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
@@ -395,25 +368,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
@@ -439,35 +397,22 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, True),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
@@ -494,8 +439,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
|
||||
@@ -524,10 +471,47 @@ class TestExpertsSelector:
|
||||
assert topk_ids.shape == (8, 2)
|
||||
|
||||
|
||||
class TestCumsumGroupList(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.active_num = 8
|
||||
self.expert_num = 128
|
||||
self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
|
||||
self.experts[:self.active_num] = 1
|
||||
self.experts = self.experts[torch.randperm(self.expert_num)]
|
||||
self.group_list = self.experts.cumsum(dim=0)
|
||||
|
||||
def test_cumsum_group_list_with_type_0(self):
|
||||
group_list = self.experts.cumsum(dim=0)
|
||||
group_list_type = 0
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_1(self):
|
||||
group_list = self.experts
|
||||
group_list_type = 1
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_2(self):
|
||||
tokens = torch.arange(self.expert_num, dtype=torch.int64)
|
||||
group_list = torch.cat([
|
||||
tokens.reshape(self.expert_num, 1),
|
||||
self.experts.reshape(self.expert_num, 1)
|
||||
],
|
||||
dim=1)
|
||||
group_list_type = 2
|
||||
result = cumsum_group_list(group_list,
|
||||
group_list_type,
|
||||
active_num=self.active_num,
|
||||
expert_num=self.expert_num)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
@@ -538,7 +522,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_forward_context.moe_comm_type = MoECommType.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
@@ -582,8 +566,6 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
with_quant=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||
FusedMoEState.MC2)
|
||||
|
||||
mock_npu_dynamic_quant.assert_called()
|
||||
|
||||
@@ -593,7 +575,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -635,7 +617,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -695,7 +677,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -739,3 +721,68 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
|
||||
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
|
||||
mock_npu_swiglu, mock_npu_grouped_matmul,
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = True
|
||||
mock_forward_context.fused_moe_state = "NOT_MC2"
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
|
||||
-128, 127, (10, 40),
|
||||
dtype=torch.int8), torch.rand(
|
||||
10, 1,
|
||||
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
]]
|
||||
mock_npu_swiglu.return_value = torch.randn(10,
|
||||
40,
|
||||
dtype=torch.bfloat16)
|
||||
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
||||
127, (10, 40),
|
||||
dtype=torch.int8),
|
||||
torch.rand(10,
|
||||
1,
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
||||
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=provided_dynamic_scale,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=None,
|
||||
with_quant=True,
|
||||
fusion=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
mock_npu_grouped_matmul.assert_called_once()
|
||||
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
|
||||
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user