[v0.11.0] [Bugfix] [MoE] Fix error in DeepSeek when using allgather (#3827)
### What this PR does / why we need it?

After refactoring `vllm_ascend/models` and FusedMoE, we are unable to pass `gate` from `deepseekv2.py` to `AscendFusedMoE.forward`, which results in an error when running DeepSeek V3/R1 with allgather. Hence, this PR removes the `gate`-related computations from the FusedMoE module in eager/aclgraph mode.

### Does this PR introduce _any_ user-facing change?

`rm_router_logits` is deprecated in eager/aclgraph mode.

### How was this patch tested?

End-to-end (e2e) tests and unit tests (ut).

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -191,13 +191,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock the gate function for rm_router_logits=False case
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
@@ -258,14 +252,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock gate for router logits recomputation
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
hidden_states, router_logits, False, False)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
hidden_states, router_logits, False, False)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
hidden_states, router_logits, False, False)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
|
||||
Reference in New Issue
Block a user