[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
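In essence, the call-site shape changes as sketched below. This is a minimal illustration distilled from the updated unit tests in the diff that follows; the `MoETokenDispatchInput`, `MoERoutingParams`, and `MoEQuantParams` constructors are shown only with the fields those tests exercise, not their full dataclass definitions:

```python
# Sketch only: field lists mirror the UT fixture in the diff below,
# not the complete dataclasses.

# Before: business arguments scattered across positional args and **kwargs.
results = dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids,
                                    expert_map, with_quant=True)

# After: a single typed request object with explicit routing and quant params.
token_dispatch_input = MoETokenDispatchInput(
    hidden_states=hidden_states,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    routing=MoERoutingParams(
        expert_map=expert_map,
        global_redundant_expert_num=0,
        mc2_mask=None,
        apply_router_weight_on_input=False,
        pertoken_scale=None,
    ),
    quant=MoEQuantParams(
        quant_type=QuantType.W8A8,  # QuantType.NONE on the unquantized path
        comm_quant_mode=None,
        mxfp=None,
    ),
)
results = dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)

# UTs can now assert stage-level fields directly, e.g.:
# results.group_list_type, results.combine_metadata, results.hidden_states
```

The `build_token_dispatch_input_fixture` helper added in the first hunk wraps exactly this construction for the UTs.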
```diff
@@ -17,14 +17,62 @@
 from unittest.mock import MagicMock, PropertyMock, patch

+import numpy as np
 import pytest
 import torch

 from tests.ut.base import TestBase
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEAllGatherCombineMetadata,
+    MoEAllToAllCombineMetadata,
+    MoEMC2CombineMetadata,
+    MoEQuantParams,
+    MoERoutingParams,
+    MoETokenDispatchInput,
+)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (  # isort: skip
-    AscendDeviceType, TokenDispatcherWithAll2AllV,
-    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
+    AscendDeviceType,
+    TokenDispatcherWithAll2AllV,
+    TokenDispatcherWithAllGather,
+    TokenDispatcherWithMC2,
+)
+from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
+from vllm_ascend.quantization.quant_type import QuantType
+
+
+def build_token_dispatch_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    expert_map: torch.Tensor | None = None,
+    global_redundant_expert_num: int = 0,
+    apply_router_weight_on_input: bool = False,
+    pertoken_scale: torch.Tensor | None = None,
+    quant_type: QuantType = QuantType.NONE,
+    comm_quant_mode: int | None = None,
+    act_quant_type: torch.dtype | None = None,
+) -> MoETokenDispatchInput:
+    mxfp_spec = None
+    if quant_type == QuantType.MXFP8:
+        mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
+    return MoETokenDispatchInput(
+        hidden_states=hidden_states,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        routing=MoERoutingParams(
+            expert_map=expert_map,
+            global_redundant_expert_num=global_redundant_expert_num,
+            mc2_mask=None,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            pertoken_scale=pertoken_scale,
+        ),
+        quant=MoEQuantParams(
+            quant_type=quant_type,
+            comm_quant_mode=comm_quant_mode,
+            mxfp=mxfp_spec,
+        ),
+    )
+
+
 class TestTokenDispatcherWithMC2(TestBase):
@@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase):
     def test_init(self):
         self.assertEqual(self.dispatcher.ep_rank_id, 0)
         self.assertEqual(self.dispatcher.ep_world_size, 8)
-        self.assertFalse(self.dispatcher.with_quant)
         self.assertTrue(self.dispatcher.enable_dispatch_v2)
         self.assertTrue(self.dispatcher.need_extra_args)

@@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase):
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        mc2_mask = None

-        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
-            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            expert_map=expert_map,
+            global_redundant_expert_num=0,
+            apply_router_weight_on_input=False,
+            pertoken_scale=None,
+        )
+        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
         self.assertIn("x", kwargs)
         self.assertIn("expert_ids", kwargs)
         self.assertEqual(kwargs["moe_expert_num"], 8)
@@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase):
         with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                    return_value=(torch.randn(10, 128), ) * 5 +
                    (None, None)) as mock_dispatch:
-            output = self.dispatcher.token_dispatch(hidden_states,
-                                                    topk_weights, topk_ids,
-                                                    expert_map)
+            token_dispatch_input = build_token_dispatch_input_fixture(
+                hidden_states=hidden_states,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                expert_map=expert_map,
+            )
+            output = self.dispatcher.token_dispatch(
+                token_dispatch_input=token_dispatch_input)
             mock_dispatch.assert_called_once()
             self.assertEqual(output.group_list_type, 0)
+            self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)

     def test_get_combine_mc_kwargs_with_quant(self):
-        self.dispatcher.with_quant = True
         hidden_states = torch.randn(10, 128)
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
         ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
         tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        mc2_mask = None
         assist_info_for_combine = torch.arange(10)

-        context_metadata = {
-            "topk_ids": topk_ids,
-            "topk_weights": topk_weights,
-            "expert_map": expert_map,
-            "ep_recv_counts": ep_recv_counts,
-            "mc2_mask": mc2_mask,
-            "assist_info_for_combine": assist_info_for_combine,
-            "expand_scales": None,
-            "tp_recv_counts": tp_recv_counts
-        }
+        combine_metadata = MoEMC2CombineMetadata(
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+            expert_map=expert_map,
+            ep_recv_counts=ep_recv_counts,
+            tp_recv_counts=tp_recv_counts,
+            assist_info_for_combine=assist_info_for_combine,
+            expand_scales=None,
+            dispatch_with_quant=True,
+        )

         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
         self.dispatcher.moe_expert_num = len(expert_map)
         kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
-                                                       context_metadata)
+                                                       combine_metadata)
         self.assertIn("tp_send_counts", kwargs)

@@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_custom.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_custom.call_args

         self.assertEqual(results.group_list_type, 1)
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_custom.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_custom.call_args

         self.assertEqual(results.group_list_type, 1)
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        results = self.dispatcher_quant.token_dispatch(hidden_states,
-                                                       topk_weights, topk_ids,
-                                                       None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher_quant.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         self.assertEqual(results.group_list_type, 1)

@@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        results = self.dispatcher_quant.token_dispatch(hidden_states,
-                                                       topk_weights,
-                                                       topk_ids,
-                                                       None,
-                                                       with_quant=True)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type=QuantType.W8A8,
+        )
+        results = self.dispatcher_quant.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         self.assertIsNotNone(results.hidden_states)
         self.assertIsNotNone(results.group_list)
@@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase):
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_combine_with_expert_map(self):
         hidden_states = torch.randn(6, 128)
-        context_metadata = {
-            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
-            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
-        }
-        self.dispatcher.original_shape = (6, 128)
-        final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata).routed_out
+        combine_metadata = MoEAllGatherCombineMetadata(
+            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
+            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
+            restore_shape=torch.Size([6, 128]),
+        )
+        final_hidden_states = self.dispatcher.token_combine(
+            hidden_states, combine_metadata)
         self.assertEqual(final_hidden_states.shape, (6, 128))

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_combine_without_expert_map(self):
         hidden_states = torch.randn(6, 128)
-        context_metadata = {
-            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
-            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
-        }
-        self.dispatcher.original_shape = (6, 128)
-        final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata).routed_out
+        combine_metadata = MoEAllGatherCombineMetadata(
+            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
+            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
+            restore_shape=torch.Size([6, 128]),
+        )
+        final_hidden_states = self.dispatcher.token_combine(
+            hidden_states, combine_metadata)
         self.mock_npu_moe_token_unpermute.assert_called_once()
         self.assertEqual(final_hidden_states.shape, (6, 128))

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_dispatch_with_router_weight(self):
-        self.dispatcher.apply_router_weight_on_input = True
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
         topk_ids = torch.tensor([[0], [1], [2]])

-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=True,
+        )
+        results = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)
         self.assertEqual(results.hidden_states.shape, (6, 128))
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)


 class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
             [0, 1], dtype=torch.int32)
         self.dispatcher.local_expert_indices = [0, 1]

-        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
-                                                topk_weights=topk_weights,
-                                                topk_ids=topk_ids,
-                                                expert_map=expert_map)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            expert_map=expert_map,
+        )
+        result = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         self.assertIsNotNone(result.hidden_states)
         self.assertIsNotNone(result.group_list)
         self.assertEqual(result.group_list_type, 1)
+        self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_combine(self):
         hidden_states = torch.randn(16, 16)
-        context_metadata = {
-            "input_splits": [4, 4],
-            "output_splits": [4, 4],
-            "topk_weights": torch.rand(8, 4),
-            "reversed_local_input_permutation_mapping": torch.arange(8),
-            "reversed_global_input_permutation_mapping": torch.arange(16),
-        }
-        self.dispatcher.hidden_shape = (8, 16)
-        self.dispatcher.hidden_shape_before_permute = (8, 16)
+        combine_metadata = MoEAllToAllCombineMetadata(
+            input_splits=np.array([4, 4]),
+            output_splits=np.array([4, 4]),
+            topk_weights=torch.rand(8, 4),
+            reversed_local_input_permutation_mapping=torch.arange(8),
+            reversed_global_input_permutation_mapping=torch.arange(16),
+            hidden_shape=torch.Size([8, 16]),
+            hidden_shape_before_permute=torch.Size([8, 16]),
+        )
         self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
             [0, 1], dtype=torch.int32)
         self.dispatcher.local_expert_indices = [0, 1]

-        output = self.dispatcher.token_combine(hidden_states, context_metadata)
+        output = self.dispatcher.token_combine(hidden_states, combine_metadata)
         self.assertIsNotNone(output)
-        self.assertEqual(output.routed_out.shape, (8, 16))
+        self.assertEqual(output.shape, (8, 16))

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
             [0, 1], dtype=torch.int32)
         self.dispatcher.local_expert_indices = [0, 1]

-        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
-                                                topk_weights=topk_weights,
-                                                topk_ids=topk_ids,
-                                                expert_map=expert_map,
-                                                with_quant=True)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            expert_map=expert_map,
+            quant_type=QuantType.W8A8,
+        )
+        result = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         self.assertIsNotNone(result.hidden_states)
         self.assertIsNotNone(result.group_list)
         self.assertIsNotNone(result.dynamic_scale)
         self.assertEqual(result.group_list_type, 1)
+        self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)

     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
             [0, 1], dtype=torch.int32)
         self.dispatcher.local_expert_indices = [0, 1]

-        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
-                                                topk_weights=topk_weights,
-                                                topk_ids=topk_ids,
-                                                expert_map=expert_map,
-                                                with_quant=True)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            expert_map=expert_map,
+            quant_type=QuantType.W8A8,
+        )
+        result = self.dispatcher.token_dispatch(
+            token_dispatch_input=token_dispatch_input)

         self.assertIsNotNone(result.hidden_states)
         self.assertIsNotNone(result.group_list)
         self.assertIsNotNone(result.dynamic_scale)
         self.assertEqual(result.group_list_type, 1)
```