[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)

### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut


- vLLM version: v0.10.1.1
- vLLM main:
704432af3c

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
weichen
2025-08-28 10:13:35 +08:00
committed by GitHub
parent 936c102105
commit 320edde2df
10 changed files with 1066 additions and 1639 deletions

View File

@@ -22,7 +22,6 @@ from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2DecoderLayer, CustomDeepseekV2ForCausalLM,
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
@@ -115,7 +114,8 @@ def mock_distributed():
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
patch("torch.npu.current_device", return_value=0):
yield
@@ -266,54 +266,3 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128))
base_config.n_routed_experts = 4
layer = CustomDeepseekV2DecoderLayer(config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
cache_config=CacheConfig(),
quant_config=None)
assert isinstance(layer.mlp, CustomDeepseekV2MoE)
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
hidden_states, residual = layer(positions, x, None)
assert hidden_states.shape == (2, 4, 128)
base_config.n_routed_experts = None
layer = CustomDeepseekV2DecoderLayer(config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
quant_config=None)
assert isinstance(layer.mlp, CustomDeepseekV2MLP)
def test_custom_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
model = CustomDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4))
positions = torch.arange(4).repeat(2, 1)
with patch.object(model.model,
"forward",
return_value=torch.randn(2, 4, 128)):
output = model(input_ids, positions)
assert output.shape == (2, 4, 128)
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
with patch(
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
assert loaded is not None

View File

@@ -22,11 +22,15 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
AscendUnquantizedFusedMoEMethod,
unified_apply_mlp)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
@@ -56,7 +60,73 @@ def mock_npu_format_cast(weight_data, format):
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
@@ -66,12 +136,10 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
return_value=torch.randn(5, 32)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -82,22 +150,31 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
)), \
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)):
yield
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
# init moe env patch
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
@@ -144,7 +221,6 @@ def mock_moe_env(mocker: MockerFixture):
@pytest.fixture
def default_moe_config():
"""default moe config"""
return {
'num_experts': 8,
'top_k': 2,
@@ -188,7 +264,6 @@ class MockQuantMethod(nn.Module):
class MockFusedMoEMethod(FusedMoEMethodBase):
# TODO(bnell): also pass quant_config?
moe = MagicMock()
def __init__(self):
@@ -223,13 +298,11 @@ class TestAscendFusedMoe:
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
# check group_topk
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
# check scoring_func
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
@@ -254,14 +327,7 @@ class TestAscendFusedMoe:
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
"""
1 test has shared_experts
2 test has top_k
3 test is_prefill is true
4 test single num_tokens(decode)
5 test ep_size is 1 and is_prefill is true
6 test ep_size is 1 and is_prefill is False
"""
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
@@ -327,25 +393,42 @@ class TestAscendUnquantizedFusedMoEMethod:
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
@@ -354,29 +437,38 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts=global_num_experts,
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param",
[[16, False], [1, True], [1, False], [4, False]])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test use_select_experts and use fused_expters_with_mc2
2 test use_select_experts and fused_experts_with_all2all_buffer
3 test use_select_experts and fused_experts_with_all2all
4 test use_select_experts and fused_experts
"""
ep_size, alltoall_buffer = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
alltoall_buffer), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
@@ -386,8 +478,16 @@ class TestAscendUnquantizedFusedMoEMethod:
if alltoall_buffer:
moe_method.max_model_len = 1
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
@@ -397,10 +497,9 @@ class TestAscendUnquantizedFusedMoEMethod:
expert_map=expert_map,
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector:
@@ -426,3 +525,239 @@ class TestExpertsSelector:
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.get_mc2_group')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_is_310p,
mock_get_mc2_group,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_mc2_group = MagicMock()
mock_get_mc2_group.return_value = mock_mc2_group
mock_is_310p.return_value = False
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 20),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_dequant.return_value = (torch.randn(10,
40,
dtype=torch.bfloat16),
torch.randn(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
w1_scale = torch.randn(5, 40, dtype=torch.float32)
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None)
mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_dequant.assert_called_once()
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.float16)
], [torch.randn(10, 20, dtype=torch.float16)]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)
mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.bfloat16)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None)
mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
mock_npu_dynamic_quant.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = True
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
[mock_gmm2_out]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)
mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
mock_is_310p.assert_called_once()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)

View File

@@ -25,8 +25,8 @@ from tests.ut.base import PytestBase, TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
from vllm_ascend.utils import adapt_patch # noqa E402
TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
get_token_dispatcher, setup_token_dispatchers)
class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
@@ -90,7 +90,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
@@ -100,28 +100,18 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
# Mock get_ascend_config()
self.ascend_config = MagicMock()
self.ascend_config.torchair_graph_config.enabled = False
self.ascend_config_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
return_value=self.ascend_config)
self.ascend_config_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop()
self.ascend_config_patch.stop()
def test_init(self):
# self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.torchair_graph_enabled)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
@@ -149,9 +139,10 @@ class TestTokenDispatcherWithMC2(TestBase):
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
expert_map)
self.row_idx, expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output[0], 1) # group_list_type == 1
self.assertEqual(output["group_list_type"],
1) # group_list_type == 1
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
@@ -166,20 +157,13 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
autospec=True):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
autospec=True) as mock_wait:
self.dispatcher.token_dispatch(
self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
mock_wait.assert_any_call(self.hidden_states,
self.topk_weights)
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
@@ -213,13 +197,7 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
autospec=True):
with patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
autospec=True):
self.dispatcher.token_combine(self.hidden_states)
self.dispatcher.token_combine(self.hidden_states)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -257,6 +235,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.patcher_moe_init_routing.stop()
@@ -268,14 +247,14 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
self.assertEqual(group_list_type, 0)
self.assertEqual(results["group_list_type"], 0)
def test_token_dispatch_with_quant(self):
kwargs = {
@@ -292,11 +271,11 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
# Verify quant mode returns group_list_type=1
self.assertEqual(group_list_type, 0)
self.assertEqual(results["group_list_type"], 0)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -337,19 +316,9 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
hidden_states, topk_weights, topk_ids, None)
self.assertEqual(sorted_hidden_states.shape, (6, 128))
def test_token_dispatch_invalid_topk_when_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
with self.assertRaises(AssertionError):
self.dispatcher.token_dispatch(
hidden_states, topk_weights,
torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -443,6 +412,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@@ -457,6 +427,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -504,6 +475,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -532,6 +504,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@@ -553,9 +526,126 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -1,82 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
row_idx_len = self.num_tokens * 8
row_idx = (torch.arange(
0,
row_idx_len,
dtype=torch.int32,
).view(8, -1).permute(1, 0).contiguous())
result = fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
row_idx=row_idx,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))