[MoE] [Refactor] Remove manual memory cleanup (#3365)

### What this PR does / why we need it?
1. Replace manual memory cleanup with passing parameter.
2. FusedMoEPrepareAndFinalizeWithMC2 inherits All2All avoid duplicated
code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-15 12:36:24 +08:00
committed by GitHub
parent 4e720936d8
commit 4f937f561d
8 changed files with 562 additions and 492 deletions

View File

@@ -44,7 +44,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
h_out, r_out, mask, context_metadata = layer.prepare(
hidden_states, router_logits)
# Check padding and split
self.assertEqual(h_out.shape[0], 4)
@@ -52,7 +53,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(mask.tolist(), [1, 0, 1])
# Finalize
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@patch(
@@ -77,10 +80,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2)
h_out, r_out, mask = layer.prepare(hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
h_out, r_out, mask, context_metadata = layer.prepare(
hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
# With TP=2, should split into 2 parts
self.assertEqual(h_out.shape[0], 2)
@@ -96,7 +100,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back to original size
self.assertEqual(final_result.shape[0], 4)
@@ -112,12 +118,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
h_out, r_out, _, context_metadata = layer.prepare(
hidden_states, router_logits)
# Pad to tp_size=1, so no change
self.assertEqual(h_out.shape[0], 3)
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@patch(
@@ -133,10 +142,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2)
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
h_out, r_out, _, context_metadata = layer.prepare(
hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
# Split due to TP=2
self.assertEqual(h_out.shape[0], 1)
@@ -152,7 +162,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back
self.assertEqual(final_result.shape[0], 2)
@@ -195,9 +207,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate = MagicMock()
mock_gate.return_value = (router_logits.repeat(2, 1), None)
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
# After all-gather with DP=2, should double the batch size
self.assertEqual(h_out.shape[0], 12)
@@ -209,7 +221,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
return tensor[:3]
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@@ -263,9 +277,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
h_out, r_out, _, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))