[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly (see the sketch below).
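
As a rough illustration of the new stage boundary, here is a minimal sketch of what a typed prepare-stage output can look like. The field names mirror the ones exercised in the updated unit tests below; the class name, the defaults, and the use of a dataclass are assumptions for illustration, not the actual `vllm_ascend` definitions.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEPrepareOutput:
    """Hypothetical typed result of the prepare stage (field names taken from the UTs below)."""
    hidden_states: torch.Tensor                               # padded / split activations
    router_logits: torch.Tensor                               # routing scores aligned with hidden_states
    mc2_mask: Optional[torch.Tensor] = None                   # MC2 dispatch mask, when present
    padded_hidden_states_shape: Optional[torch.Size] = None   # shape recorded for finalize


# Callers read named fields instead of unpacking a positional tuple:
#   prepare_output = layer.prepare(hidden_states, router_logits)
#   h_out = prepare_output.hidden_states
```

Carrying these values as named fields is what allows the unit tests to assert each one directly instead of relying on tuple positions or `kwargs` lookups.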

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Authored by linfeng-yuan on 2026-03-20 23:23:57 +08:00, committed by GitHub
commit 88d03a783f (parent c860535246)
33 changed files with 2146 additions and 947 deletions


@@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mask, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # Check padding and split
         self.assertEqual(h_out.shape[0], 4)
         self.assertEqual(r_out.shape[0], 4)
         self.assertEqual(mask.tolist(), [1, 0, 1])
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))

         # Finalize
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)

     @patch(
@@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
-        h_out, r_out, mask, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
             hidden_states,
             router_logits,
             enable_shared_expert_dp=False,
             replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # With TP=2, should split into 2 parts
         self.assertEqual(h_out.shape[0], 2)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))

         # Mock all_gather behavior
         def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)

         # Should concat back to original size
         self.assertEqual(final_result.shape[0], 4)
@@ -117,15 +126,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # Pad to tp_size=1, so no change
         self.assertEqual(h_out.shape[0], 3)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))

         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)

     @patch(
@@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(2, 8)
         router_logits = torch.randn(2, 2)
-        h_out, r_out, _, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
             hidden_states,
             router_logits,
             enable_shared_expert_dp=False,
             replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # Split due to TP=2
         self.assertEqual(h_out.shape[0], 1)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))

         # Mock all_gather
         def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)

         # Should concat back
         self.assertEqual(final_result.shape[0], 2)
@@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape

         # After all-gather with DP=2, should double the batch size
         self.assertEqual(h_out.shape[0], 12)
         self.assertEqual(r_out.shape[0], 12)
+        self.assertIsNone(padded_hidden_states_shape)

         # Finalize with reduce_scatter
         def mock_reduce_scatter_func(tensor, dim):
@@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         mock_dp_group.reduce_scatter = mock_reduce_scatter_func
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)
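
For readers skimming the diff, the behavior these tests pin down can be reduced to a small standalone sketch: prepare pads the token dimension to an alignment boundary and records the padded shape, and finalize receives that shape explicitly instead of an opaque `context_metadata` blob. This is an illustrative toy under assumed names and an assumed padding rule, not the `vllm_ascend` implementation.

```python
import torch

PAD_MULTIPLE = 4  # assumed alignment; the real value depends on the comm backend


def toy_prepare(hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Pad the token dimension up to PAD_MULTIPLE and record the padded shape."""
    pad = -hidden_states.shape[0] % PAD_MULTIPLE
    if pad:
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
        router_logits = torch.nn.functional.pad(router_logits, (0, 0, 0, pad))
    return hidden_states, router_logits, hidden_states.shape


def toy_finalize(hidden_states: torch.Tensor, num_tokens: int,
                 padded_hidden_states_shape: torch.Size) -> torch.Tensor:
    """Check the explicit stage boundary, then strip the padding again."""
    assert hidden_states.shape == padded_hidden_states_shape
    return hidden_states[:num_tokens]


h, r = torch.randn(3, 8), torch.randn(3, 2)
h_pad, r_pad, padded_shape = toy_prepare(h, r)
assert padded_shape == torch.Size([4, 8])               # matches the UT expectation above
assert toy_finalize(h_pad, 3, padded_shape).shape[0] == 3
```

Passing the padded shape as an explicit argument is what lets the updated UTs assert it directly instead of poking into a metadata dict.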