[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -22,11 +22,15 @@ import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
AscendUnquantizedFusedMoEMethod,
|
||||
unified_apply_mlp)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
@@ -56,7 +60,73 @@ def mock_npu_format_cast(weight_data, format):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
# init dist env patch
|
||||
mock_setup_token_dispatchers = MagicMock()
|
||||
mock_token_dispatcher_with_allgather = MagicMock()
|
||||
mock_token_dispatcher_with_all2allv = MagicMock()
|
||||
mock_token_dispatcher_with_mc2 = MagicMock()
|
||||
|
||||
mock_dispatch_result_allgather = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([8, 16], dtype=torch.int64),
|
||||
"group_list_type": 0,
|
||||
}
|
||||
mock_combine_result_allgather = torch.randn(16, 2)
|
||||
|
||||
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
|
||||
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
|
||||
|
||||
mock_dispatch_result_all2allv = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
}
|
||||
mock_combine_result_all2allv = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
|
||||
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
|
||||
|
||||
mock_dispatch_result_mc2 = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
"assist_info_for_combine": torch.randn(16, 2),
|
||||
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
|
||||
}
|
||||
mock_combine_result_mc2 = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
|
||||
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
|
||||
|
||||
captured_dispatchers = {}
|
||||
|
||||
def capture_register(dispatcher_instance):
|
||||
key = dispatcher_instance.__class__.__name__
|
||||
captured_dispatchers[key] = dispatcher_instance
|
||||
if key == 'TokenDispatcherWithAllGather':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
|
||||
elif key == 'TokenDispatcherWithAll2AllV':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
|
||||
elif key == 'TokenDispatcherWithMC2':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
|
||||
|
||||
mock_register_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
|
||||
side_effect=capture_register)
|
||||
|
||||
mock_get_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
|
||||
side_effect=lambda name: captured_dispatchers.get(name))
|
||||
|
||||
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
|
||||
|
||||
mock_forward_context_obj = MagicMock(
|
||||
fused_moe_state=FusedMoEState.AllGather,
|
||||
token_dispatcher=default_mock_token_dispatcher,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
@@ -66,12 +136,10 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
|
||||
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
|
||||
return_value=torch.randn(5, 32)), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
|
||||
return_value=torch.randn(5, 32)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
@@ -82,22 +150,31 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=MagicMock(
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
||||
)), \
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)):
|
||||
yield
|
||||
)), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
'mock_token_dispatcher_with_allgather':
|
||||
mock_token_dispatcher_with_allgather,
|
||||
'mock_token_dispatcher_with_all2allv':
|
||||
mock_token_dispatcher_with_all2allv,
|
||||
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
|
||||
}
|
||||
|
||||
mock_register_token_dispatcher_patcher.stop()
|
||||
mock_get_token_dispatcher_patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_moe_env(mocker: MockerFixture):
|
||||
# init moe env patch
|
||||
|
||||
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
|
||||
torch.randn(8, 2),
|
||||
@@ -144,7 +221,6 @@ def mock_moe_env(mocker: MockerFixture):
|
||||
|
||||
@pytest.fixture
|
||||
def default_moe_config():
|
||||
"""default moe config"""
|
||||
return {
|
||||
'num_experts': 8,
|
||||
'top_k': 2,
|
||||
@@ -188,7 +264,6 @@ class MockQuantMethod(nn.Module):
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
# TODO(bnell): also pass quant_config?
|
||||
moe = MagicMock()
|
||||
|
||||
def __init__(self):
|
||||
@@ -223,13 +298,11 @@ class TestAscendFusedMoe:
|
||||
assert hasattr(layer, 'w13_weight')
|
||||
assert hasattr(layer, 'w2_weight')
|
||||
|
||||
# check group_topk
|
||||
with pytest.raises(AssertionError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['use_grouped_topk'] = True
|
||||
layer = AscendFusedMoE(**error_config)
|
||||
|
||||
# check scoring_func
|
||||
with pytest.raises(ValueError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['scoring_func'] = "random"
|
||||
@@ -254,14 +327,7 @@ class TestAscendFusedMoe:
|
||||
[None, None, False, 1, None], [None, None, True, 5, 1],
|
||||
[None, None, False, 5, 1]])
|
||||
def test_forward(self, mock_dist_env, default_moe_config, others_param):
|
||||
"""
|
||||
1 test has shared_experts
|
||||
2 test has top_k
|
||||
3 test is_prefill is true
|
||||
4 test single num_tokens(decode)
|
||||
5 test ep_size is 1 and is_prefill is true
|
||||
6 test ep_size is 1 and is_prefill is False
|
||||
"""
|
||||
|
||||
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
|
||||
inputs = torch.randn(num_tokens, 32)
|
||||
router_logits = torch.randn(num_tokens, 8)
|
||||
@@ -327,25 +393,42 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
|
||||
2 test use_select_experts and fused_experts
|
||||
3 test use select_gating_topk_softmax_experts and fused_experts
|
||||
4 test use select_experts and fused_experts_with_all2all_buffer
|
||||
"""
|
||||
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||
ep_size, is_prefill, is_deepseek_v3_r1),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
layer = MagicMock()
|
||||
layer.w13_weight = torch.randn(8, 16, 1)
|
||||
layer.w2_weight = torch.randn(16, 8, 1)
|
||||
local_num_experts = 2
|
||||
hidden_size = 2
|
||||
intermediate_size_per_partition = 4
|
||||
|
||||
layer.w13_weight = torch.randn(local_num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size)
|
||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
||||
intermediate_size_per_partition)
|
||||
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
@@ -354,29 +437,38 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
expected_shape = (16, 2)
|
||||
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param",
|
||||
[[16, False], [1, True], [1, False], [4, False]])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
1 test use_select_experts and use fused_expters_with_mc2
|
||||
2 test use_select_experts and fused_experts_with_all2all_buffer
|
||||
3 test use_select_experts and fused_experts_with_all2all
|
||||
4 test use_select_experts and fused_experts
|
||||
"""
|
||||
|
||||
ep_size, alltoall_buffer = others_param
|
||||
is_prefill = False
|
||||
forward_context = MagicMock(
|
||||
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, True),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||
alltoall_buffer), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
@@ -386,8 +478,16 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
if alltoall_buffer:
|
||||
moe_method.max_model_len = 1
|
||||
layer = MagicMock()
|
||||
layer.w13_weight = torch.randn(8, 16, 1)
|
||||
layer.w2_weight = torch.randn(16, 8, 1)
|
||||
|
||||
local_num_experts = 2
|
||||
hidden_size = 2
|
||||
intermediate_size_per_partition = 4
|
||||
layer.w13_weight = torch.randn(local_num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size)
|
||||
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
|
||||
intermediate_size_per_partition)
|
||||
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
@@ -397,10 +497,9 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 16 or ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
expected_shape = (16, 2)
|
||||
|
||||
assert result.shape == expected_shape
|
||||
|
||||
|
||||
class TestExpertsSelector:
|
||||
@@ -426,3 +525,239 @@ class TestExpertsSelector:
|
||||
|
||||
assert topk_weights.shape == (8, 2)
|
||||
assert topk_ids.shape == (8, 2)
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.get_mc2_group')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
|
||||
mock_npu_dynamic_quant,
|
||||
mock_npu_grouped_matmul,
|
||||
mock_is_310p,
|
||||
mock_get_mc2_group,
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = True
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_mc2_group = MagicMock()
|
||||
mock_get_mc2_group.return_value = mock_mc2_group
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
||||
127, (10, 20),
|
||||
dtype=torch.int8),
|
||||
torch.rand(10,
|
||||
1,
|
||||
dtype=torch.float32))
|
||||
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
|
||||
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
|
||||
|
||||
mock_npu_dequant.return_value = (torch.randn(10,
|
||||
40,
|
||||
dtype=torch.bfloat16),
|
||||
torch.randn(10,
|
||||
1,
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.float32)
|
||||
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
|
||||
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=None,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=None)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||
FusedMoEState.MC2)
|
||||
|
||||
mock_npu_dynamic_quant.assert_called()
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
|
||||
mock_npu_dequant.assert_called_once()
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
def test_unified_apply_mlp_without_quantization(
|
||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
torch.randn(10, 40, dtype=torch.float16)
|
||||
], [torch.randn(10, 20, dtype=torch.float16)]]
|
||||
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
|
||||
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.float16)
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=None,
|
||||
w2=w2,
|
||||
w2_scale=None,
|
||||
group_list=group_list,
|
||||
dynamic_scale=None,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=topk_scales)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertFalse(mock_forward_context.with_quant)
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
|
||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul, mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = True
|
||||
mock_forward_context.fused_moe_state = "NOT_MC2"
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
torch.randn(10, 40, dtype=torch.bfloat16)
|
||||
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
|
||||
|
||||
mock_npu_swiglu.return_value = torch.randn(10,
|
||||
40,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
||||
127, (10, 40),
|
||||
dtype=torch.int8),
|
||||
torch.rand(10,
|
||||
1,
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
||||
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=provided_dynamic_scale,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=None)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
mock_npu_dynamic_quant.assert_called_once()
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
def test_unified_apply_mlp_without_quantization_310p(
|
||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = True
|
||||
|
||||
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
|
||||
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
|
||||
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
|
||||
[mock_gmm2_out]]
|
||||
|
||||
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
|
||||
|
||||
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.float16)
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
topk_scales = torch.randn(10, 1, dtype=torch.float16)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=None,
|
||||
w2=w2,
|
||||
w2_scale=None,
|
||||
group_list=group_list,
|
||||
dynamic_scale=None,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=topk_scales)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertFalse(mock_forward_context.with_quant)
|
||||
mock_is_310p.assert_called_once()
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
Reference in New Issue
Block a user