### What this PR does / why we need it? 1. Replace manual memory cleanup with passing parameter. 2. FusedMoEPrepareAndFinalizeWithMC2 inherits All2All avoid duplicated code. 3. Fix MC2 bug introduced in https://github.com/vllm-project/vllm-ascend/pull/3365 4. Unify aclgraph & eager in W8A8_dynamic. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? e2e & ut - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -44,7 +44,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# Check padding and split
|
||||
self.assertEqual(h_out.shape[0], 4)
|
||||
@@ -52,7 +53,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||
|
||||
# Finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@@ -77,10 +80,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# With TP=2, should split into 2 parts
|
||||
self.assertEqual(h_out.shape[0], 2)
|
||||
@@ -96,7 +100,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
|
||||
# Should concat back to original size
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
@@ -112,12 +118,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# Pad to tp_size=1, so no change
|
||||
self.assertEqual(h_out.shape[0], 3)
|
||||
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@@ -133,10 +142,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# Split due to TP=2
|
||||
self.assertEqual(h_out.shape[0], 1)
|
||||
@@ -152,7 +162,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
@@ -195,9 +207,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
@@ -209,7 +221,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
return tensor[:3]
|
||||
|
||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@@ -263,9 +277,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
|
||||
Reference in New Issue
Block a user