[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
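For orientation, here is a minimal sketch of the new typed request objects as the tests in the diff below construct them. The field names are taken directly from the diff; the `@dataclass` form, types, and defaults are assumptions, and the authoritative definitions live in `vllm_ascend/ops/fused_moe/moe_runtime_args.py`:

```python
# Hypothetical sketch reconstructed from how the tests below build these
# objects. Types and defaults are assumptions, not the real definitions.
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class MoEPrepareOutput:
    """Replaces the old 4-tuple returned by comm_impl.prepare()."""
    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    mc2_mask: Optional[torch.Tensor] = None
    padded_hidden_states_shape: Optional[torch.Size] = None


@dataclass
class MoEWeights:
    """Expert weights, previously passed as loose w1=/w2= kwargs."""
    w1: list[torch.Tensor]
    w2: list[torch.Tensor]


@dataclass
class MoERoutingParams:
    """Routing-stage knobs previously fetched via kwargs.get(...)."""
    expert_map: Optional[torch.Tensor] = None
    global_redundant_expert_num: int = 0
    mc2_mask: Optional[torch.Tensor] = None
    apply_router_weight_on_input: bool = False


@dataclass
class MoEQuantParams:
    """Quant-stage parameters. The tests assert quant.is_mxfp is False,
    so an is_mxfp flag is assumed here."""
    is_mxfp: bool = False


@dataclass
class MoEFusedExpertsInput:
    """Single request object consumed by comm_impl.fused_experts()."""
    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    weights: MoEWeights
    routing: MoERoutingParams
    activation: str = "silu"
    need_trans: bool = False
    dynamic_eplb: bool = False
    quant: MoEQuantParams = field(default_factory=MoEQuantParams)
```

With these shapes in mind, each hunk below is a mechanical migration: tuple unpacking becomes attribute access, and keyword sprawl becomes one request object per stage.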
```diff
@@ -4,12 +4,21 @@ import torch
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

 from tests.ut.base import TestBase
-from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
-                                                       AlltoAllCommImpl,
-                                                       MC2CommImpl)
-from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
-                                                        TokenDispatchResult)
+from vllm_ascend.ops.fused_moe.moe_comm_method import (
+    AllGatherCommImpl,
+    AlltoAllCommImpl,
+    MC2CommImpl,
+)
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEAllGatherCombineMetadata,
+    MoEFusedExpertsInput,
+    MoEPrepareOutput,
+    MoEQuantParams,
+    MoERoutingParams,
+    MoEWeights,
+)
+from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput
+from vllm_ascend.quantization.methods.base import QuantType


 class TestMoECommMethod(TestBase):
@@ -45,8 +54,11 @@ class TestMoECommMethod(TestBase):

         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None, None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance

@@ -60,8 +72,9 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        prepare_output = comm_impl.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -70,7 +83,7 @@ class TestMoECommMethod(TestBase):
         # Test finalize method
         comm_impl.finalize(h_out,
                            reduce_results=True,
-                           context_metadata=context_metadata)
+                           padded_hidden_states_shape=padded_hidden_states_shape)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -86,10 +99,11 @@ class TestMoECommMethod(TestBase):

         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2),
-                                                 torch.tensor([1, 0, 1,
-                                                               0]), None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=torch.tensor([1, 0, 1, 0]),
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance

@@ -103,8 +117,9 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        prepare_output = comm_impl.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -113,7 +128,7 @@ class TestMoECommMethod(TestBase):
         # Test finalize method
         comm_impl.finalize(h_out,
                            reduce_results=True,
-                           context_metadata=context_metadata)
+                           padded_hidden_states_shape=padded_hidden_states_shape)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -133,8 +148,11 @@ class TestMoECommMethod(TestBase):

         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None, None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance

@@ -148,8 +166,7 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        _ = comm_impl.prepare(hidden_states, router_logits)

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -174,19 +191,27 @@ class TestMoECommMethod(TestBase):

         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance

         # Mock token dispatcher
         mock_td_instance = MagicMock()
-        mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
-            hidden_states=torch.randn(6, 8),
-            group_list=torch.tensor([2, 2, 2]),
-            group_list_type=1)
-        mock_td_instance.token_combine.return_value = TokenCombineResult(
-            routed_out=torch.randn(4, 8))
+        dispatch_topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], [0.6, 0.4]])
+        mock_td_instance.token_dispatch.return_value = MoETokenDispatchOutput(
+            hidden_states=torch.randn(6, 8),
+            group_list=torch.tensor([2, 2, 2]),
+            group_list_type=1,
+            combine_metadata=MoEAllGatherCombineMetadata(
+                topk_weights=dispatch_topk_weights,
+                expanded_row_idx=torch.arange(8, dtype=torch.int32),
+                restore_shape=torch.Size([4, 8]),
+            ))
+        mock_td_instance.token_combine.return_value = torch.randn(4, 8)
         mock_token_dispatcher.return_value = mock_td_instance

         # Mock unified_apply_mlp
@@ -199,8 +224,7 @@ class TestMoECommMethod(TestBase):
         hidden_states = torch.randn(4, 8).contiguous()
         w1 = torch.randn(16, 8).contiguous()
         w2 = torch.randn(16, 8).contiguous()
-        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
-                                     [0.6, 0.4]])
+        topk_weights = dispatch_topk_weights
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])

         # Make sure tensors are contiguous and have correct strides
@@ -208,12 +232,25 @@ class TestMoECommMethod(TestBase):
         w1 = w1.contiguous()
         w2 = w2.contiguous()

-        result = comm_impl.fused_experts(hidden_states=hidden_states,
-                                         w1=[w1],
-                                         w2=[w2],
-                                         topk_weights=topk_weights,
-                                         topk_ids=topk_ids,
-                                         activation="silu")
+        result = comm_impl.fused_experts(fused_experts_input=MoEFusedExpertsInput(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            weights=MoEWeights(
+                w1=[w1],
+                w2=[w2],
+            ),
+            routing=MoERoutingParams(
+                expert_map=None,
+                global_redundant_expert_num=0,
+                mc2_mask=None,
+                apply_router_weight_on_input=False,
+            ),
+            activation="silu",
+            need_trans=False,
+            dynamic_eplb=False,
+            quant=MoEQuantParams(),
+        ))

         # Verify result shape
         self.assertEqual(result.routed_out.shape, (4, 8))
@@ -223,6 +260,12 @@ class TestMoECommMethod(TestBase):

         # Verify unified_apply_mlp was called
         mock_unified_apply_mlp.assert_called_once()
+        mlp_compute_input = mock_unified_apply_mlp.call_args.kwargs["mlp_compute_input"]
+        self.assertFalse(mlp_compute_input.fusion)
+        self.assertFalse(mlp_compute_input.quant.is_mxfp)

         # Verify token_combine was called
-        mock_td_instance.token_combine.assert_called_once()
+        mock_td_instance.token_combine.assert_called_once_with(
+            hidden_states=mock_unified_apply_mlp.return_value,
+            combine_metadata=mock_td_instance.token_dispatch.return_value.combine_metadata,
+        )
```
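As a closing illustration, here is a tiny runnable version of the new UT pattern this diff enables: stage-level fields are asserted directly on the typed objects instead of unpacking positional tuples. It reuses the hypothetical `MoEPrepareOutput` sketch from the top of this page (an assumption, not the real class) and mocks the comm impl the same way the tests above do:

```python
# Miniature of the refactored test style, assuming the MoEPrepareOutput
# sketch above is defined in the same session.
from unittest.mock import MagicMock

import torch

comm_impl = MagicMock()
# The mock now returns a typed object rather than a 4-tuple.
comm_impl.prepare.return_value = MoEPrepareOutput(
    hidden_states=torch.randn(3, 8),
    router_logits=torch.randn(3, 2),
)

prepare_output = comm_impl.prepare(torch.randn(3, 8), torch.randn(3, 2))

# Assert stage-level fields by name, as the refactored UTs do.
assert prepare_output.mc2_mask is None
assert prepare_output.padded_hidden_states_shape is None
assert prepare_output.hidden_states.shape == (3, 8)
```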