[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
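For orientation, here is a minimal sketch of the kind of typed stage output the updated tests exercise. The class name `MoEPrepareOutput` and the dataclass form are assumptions for illustration only; the field names (`hidden_states`, `router_logits`, `mc2_mask`, `padded_hidden_states_shape`) are taken directly from the test assertions in the diff below.

```python
# Illustrative sketch only -- not the actual vllm_ascend definition.
# The class name and dataclass form are assumed; the fields mirror
# what the updated tests read off layer.prepare(...)'s return value.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEPrepareOutput:
    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    # Token mask used by the MC2 dispatch path (assumed semantics).
    mc2_mask: Optional[torch.Tensor] = None
    # Pre-split/padded shape, needed to restore the batch in finalize;
    # None on paths that do not pad (see the DP all-gather test below).
    padded_hidden_states_shape: Optional[torch.Size] = None
```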
```diff
@@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, mask, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Check padding and split
         self.assertEqual(h_out.shape[0], 4)
         self.assertEqual(r_out.shape[0], 4)
         self.assertEqual(mask.tolist(), [1, 0, 1])
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
 
         # Finalize
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
```
```diff
@@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
 
-        h_out, r_out, mask, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
             hidden_states,
             router_logits,
             enable_shared_expert_dp=False,
             replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # With TP=2, should split into 2 parts
         self.assertEqual(h_out.shape[0], 2)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
 
         # Mock all_gather behavior
         def mock_all_gather_func(tensor_list, tensor, group=None):
```
```diff
@@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)
 
         # Should concat back to original size
         self.assertEqual(final_result.shape[0], 4)
```
```diff
@@ -117,15 +126,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Pad to tp_size=1, so no change
         self.assertEqual(h_out.shape[0], 3)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))
 
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
```
```diff
@@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(2, 8)
         router_logits = torch.randn(2, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
             hidden_states,
             router_logits,
             enable_shared_expert_dp=False,
             replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Split due to TP=2
         self.assertEqual(h_out.shape[0], 1)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))
 
         # Mock all_gather
         def mock_all_gather_func(tensor_list, tensor, group=None):
```
```diff
@@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)
 
         # Should concat back
         self.assertEqual(final_result.shape[0], 2)
```
```diff
@@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # After all-gather with DP=2, should double the batch size
         self.assertEqual(h_out.shape[0], 12)
         self.assertEqual(r_out.shape[0], 12)
+        self.assertIsNone(padded_hidden_states_shape)
 
         # Finalize with reduce_scatter
         def mock_reduce_scatter_func(tensor, dim):
```
```diff
@@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         mock_dp_group.reduce_scatter = mock_reduce_scatter_func
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
 
         self.assertEqual(result.shape[0], 3)
```
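Taken together, the hunks apply one mechanical migration pattern at every call site. A condensed before/after sketch, not runnable on its own: variable names are as in the tests, and `layer` stands for the prepare/finalize object under test.

```python
# Before: positional tuple unpacking plus an opaque metadata handle.
h_out, r_out, mask, context_metadata = layer.prepare(hidden_states, router_logits)
result = layer.finalize(h_out,
                        reduce_results=False,
                        context_metadata=context_metadata)

# After: a typed result object with named stage-level fields,
# and finalize taking the explicit padded shape instead of a handle.
prepare_output = layer.prepare(hidden_states, router_logits)
result = layer.finalize(prepare_output.hidden_states,
                        reduce_results=False,
                        padded_hidden_states_shape=prepare_output.padded_hidden_states_shape)
```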