Revert "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483)

### What this PR does / why we need it?
This reverts commit 4f937f561d ("[MoE] [Refactor] Remove manual memory cleanup", #3365), restoring the previous `prepare()`/`finalize()` signatures, which do not thread `context_metadata` through the fused-MoE prepare/finalize path.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
e2e & unit tests.
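
For context, here is a minimal sketch of the interface change this revert undoes. The class name and the `prepare`/`finalize` signatures mirror the test diff below; the bodies are illustrative no-ops, not the real vllm-ascend pad/split/all-gather logic:

```python
import torch


class FusedMoEPrepareAndFinalize:
    """Hypothetical stand-in for the layer exercised by the tests below.

    Post-#3365 (now reverted), prepare() returned a fourth value that
    finalize() consumed:

        h_out, r_out, mask, context_metadata = layer.prepare(...)
        layer.finalize(h_out, reduce_results=False,
                       context_metadata=context_metadata)

    This revert restores the three-value contract sketched here.
    """

    def prepare(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor, **kwargs):
        # The real layer pads/splits across TP ranks and all-gathers
        # across DP ranks; this sketch passes tensors through unchanged.
        mask = torch.ones(hidden_states.shape[0], dtype=torch.int64)
        return hidden_states, router_logits, mask

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool = False) -> torch.Tensor:
        # The real layer unpads/concats back to the original batch size;
        # this sketch returns its input unchanged.
        return hidden_states
```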

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
weichen committed (via GitHub) on 2025-10-15 22:25:46 +08:00
parent f69a83b7ba · commit cec1fab509
8 changed files with 500 additions and 572 deletions

@@ -44,8 +44,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mask, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
         # Check padding and split
         self.assertEqual(h_out.shape[0], 4)
@@ -53,9 +52,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(mask.tolist(), [1, 0, 1])
 
         # Finalize
-        result = layer.finalize(h_out,
-                                reduce_results=False,
-                                context_metadata=context_metadata)
+        result = layer.finalize(h_out, reduce_results=False)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
@@ -80,11 +77,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
-        h_out, r_out, mask, context_metadata = layer.prepare(
-            hidden_states,
-            router_logits,
-            enable_shared_expert_dp=False,
-            replace_allreduce=False)
+        h_out, r_out, mask = layer.prepare(hidden_states,
+                                           router_logits,
+                                           enable_shared_expert_dp=False,
+                                           replace_allreduce=False)
         # With TP=2, should split into 2 parts
         self.assertEqual(h_out.shape[0], 2)
@@ -100,9 +96,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
             torch.zeros_like(h_out),
             torch.zeros_like(h_out)
         ]
-        final_result = layer.finalize(h_out,
-                                      reduce_results=False,
-                                      context_metadata=context_metadata)
+        final_result = layer.finalize(h_out, reduce_results=False)
         # Should concat back to original size
         self.assertEqual(final_result.shape[0], 4)
@@ -118,15 +112,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
         # Pad to tp_size=1, so no change
         self.assertEqual(h_out.shape[0], 3)
-        result = layer.finalize(h_out,
-                                reduce_results=False,
-                                context_metadata=context_metadata)
+        result = layer.finalize(h_out, reduce_results=False)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
@@ -142,11 +133,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(2, 8)
         router_logits = torch.randn(2, 2)
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states,
-            router_logits,
-            enable_shared_expert_dp=False,
-            replace_allreduce=False)
+        h_out, r_out, _ = layer.prepare(hidden_states,
+                                        router_logits,
+                                        enable_shared_expert_dp=False,
+                                        replace_allreduce=False)
         # Split due to TP=2
         self.assertEqual(h_out.shape[0], 1)
@@ -162,9 +152,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
             torch.zeros_like(h_out),
             torch.zeros_like(h_out)
         ]
-        final_result = layer.finalize(h_out,
-                                      reduce_results=False,
-                                      context_metadata=context_metadata)
+        final_result = layer.finalize(h_out, reduce_results=False)
         # Should concat back
         self.assertEqual(final_result.shape[0], 2)
@@ -207,9 +195,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         mock_gate = MagicMock()
         mock_gate.return_value = (router_logits.repeat(2, 1), None)
-        h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
-                                                          router_logits,
-                                                          gate=mock_gate)
+        h_out, r_out, _ = layer.prepare(hidden_states,
+                                        router_logits,
+                                        gate=mock_gate)
         # After all-gather with DP=2, should double the batch size
         self.assertEqual(h_out.shape[0], 12)
@@ -221,9 +209,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
             return tensor[:3]
 
         mock_dp_group.reduce_scatter = mock_reduce_scatter_func
-        result = layer.finalize(h_out,
-                                reduce_results=False,
-                                context_metadata=context_metadata)
+        result = layer.finalize(h_out, reduce_results=False)
         self.assertEqual(result.shape[0], 3)
@@ -277,9 +263,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         mock_gate.return_value = (torch.randn(7, 2), None)
 
         # Run prepare
-        h_out, r_out, _, _ = layer.prepare(hidden_states,
-                                           router_logits,
-                                           gate=mock_gate)
+        h_out, r_out, _ = layer.prepare(hidden_states,
+                                        router_logits,
+                                        gate=mock_gate)
         # Should be global tensor: [7, 8] and [7, 2]
         self.assertEqual(h_out.shape, (7, 8))
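
All ten hunks make the same mechanical change. As a usage summary (reusing the hypothetical sketch class from above, not the real layer), the restored call pattern that the updated tests exercise is:

```python
import torch

layer = FusedMoEPrepareAndFinalize()  # hypothetical sketch class from above
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)

# Restored contract: no context_metadata round-trip between the two calls.
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
result = layer.finalize(h_out, reduce_results=False)
assert result.shape[0] == 3
```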