[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-20 23:23:57 +08:00
committed by GitHub
parent c860535246
commit 88d03a783f
33 changed files with 2146 additions and 947 deletions

View File

@@ -17,14 +17,62 @@
from unittest.mock import MagicMock, PropertyMock, patch
import numpy as np
import pytest
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEMC2CombineMetadata,
MoEQuantParams,
MoERoutingParams,
MoETokenDispatchInput,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendDeviceType, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
AscendDeviceType,
TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType
def build_token_dispatch_input_fixture(
    *,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_map: torch.Tensor | None = None,
    global_redundant_expert_num: int = 0,
    apply_router_weight_on_input: bool = False,
    pertoken_scale: torch.Tensor | None = None,
    quant_type: QuantType = QuantType.NONE,
    comm_quant_mode: int | None = None,
    act_quant_type: torch.dtype | None = None,
) -> MoETokenDispatchInput:
    """Build a ``MoETokenDispatchInput`` request object for dispatcher UTs.

    All arguments are keyword-only so call sites read like the typed
    request they construct. Routing and quant sub-params are assembled
    separately and then wrapped into the dispatch input.

    Note: the mxfp spec is attached only when ``quant_type`` is MXFP8;
    for every other quant type the ``mxfp`` slot stays ``None``.
    """
    # MXFP8 is the only quant type that carries an activation-quant
    # dtype spec — presumably mirroring the production prepare stage.
    mxfp_spec = (MoEMxfpParams(act_quant_type=act_quant_type)
                 if quant_type == QuantType.MXFP8 else None)
    routing_params = MoERoutingParams(
        expert_map=expert_map,
        global_redundant_expert_num=global_redundant_expert_num,
        mc2_mask=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
        pertoken_scale=pertoken_scale,
    )
    quant_params = MoEQuantParams(
        quant_type=quant_type,
        comm_quant_mode=comm_quant_mode,
        mxfp=mxfp_spec,
    )
    return MoETokenDispatchInput(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        routing=routing_params,
        quant=quant_params,
    )
class TestTokenDispatcherWithMC2(TestBase):
@@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase):
def test_init(self):
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
@@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
global_redundant_expert_num=0,
apply_router_weight_on_input=False,
pertoken_scale=None,
)
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
@@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase):
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5 +
(None, None)) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
expert_map)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
)
output = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
mock_dispatch.assert_called_once()
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
assist_info_for_combine = torch.arange(10)
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": mc2_mask,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"tp_recv_counts": tp_recv_counts
}
combine_metadata = MoEMC2CombineMetadata(
topk_ids=topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
ep_recv_counts=ep_recv_counts,
tp_recv_counts=tp_recv_counts,
assist_info_for_combine=assist_info_for_combine,
expand_scales=None,
dispatch_with_quant=True,
)
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.moe_expert_num = len(expert_map)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
context_metadata)
combine_metadata)
self.assertIn("tp_send_counts", kwargs)
@@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1)
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1)
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
None)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertEqual(results.group_list_type, 1)
@@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights,
topk_ids,
None,
with_quant=True)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_type=QuantType.W8A8,
)
results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list)
@@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase):
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128)
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata).routed_out
combine_metadata = MoEAllGatherCombineMetadata(
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
restore_shape=torch.Size([6, 128]),
)
final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
self.assertEqual(final_hidden_states.shape, (6, 128))
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_without_expert_map(self):
hidden_states = torch.randn(6, 128)
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata).routed_out
combine_metadata = MoEAllGatherCombineMetadata(
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
restore_shape=torch.Size([6, 128]),
)
final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128))
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=True,
)
results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertEqual(results.hidden_states.shape, (6, 128))
self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
context_metadata = {
"input_splits": [4, 4],
"output_splits": [4, 4],
"topk_weights": torch.rand(8, 4),
"reversed_local_input_permutation_mapping": torch.arange(8),
"reversed_global_input_permutation_mapping": torch.arange(16),
}
self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16)
combine_metadata = MoEAllToAllCombineMetadata(
input_splits=np.array([4, 4]),
output_splits=np.array([4, 4]),
topk_weights=torch.rand(8, 4),
reversed_local_input_permutation_mapping=torch.arange(8),
reversed_global_input_permutation_mapping=torch.arange(16),
hidden_shape=torch.Size([8, 16]),
hidden_shape_before_permute=torch.Size([8, 16]),
)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
output = self.dispatcher.token_combine(hidden_states, context_metadata)
output = self.dispatcher.token_combine(hidden_states, combine_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.routed_out.shape, (8, 16))
self.assertEqual(output.shape, (8, 16))
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
with_quant=True)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_type=QuantType.W8A8,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
@pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
with_quant=True)
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_type=QuantType.W8A8,
)
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)