[v0.11.0] [Bugfix] [MoE]fix error in deepseek when using allgather (#3827)

### What this PR does / why we need it?
After refactoring vllm_ascend/models and FusedMoE, we are unable to pass
`gate` from deepseekv2.py to `AscendFusedMoE.forward`, which will result
in error when running deepseek v3/r1 with allgather.
Hence, this pr removes `gate` related computations from FusedMoE module
in eager/aclgraph mode.
### Does this PR introduce _any_ user-facing change?
`rm_router_logits` is deprecated in eager/aclgraph.
### How was this patch tested?
e2e & ut

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-30 14:59:46 +08:00
committed by GitHub
parent 211d4b9da4
commit c506ba60fb
7 changed files with 98 additions and 115 deletions

View File

@@ -191,13 +191,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock the gate function for rm_router_logits=False case
mock_gate = MagicMock()
mock_gate.return_value = (router_logits.repeat(2, 1), None)
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
# After all-gather with DP=2, should double the batch size
self.assertEqual(h_out.shape[0], 12)
@@ -258,14 +252,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock gate for router logits recomputation
mock_gate = MagicMock()
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))

View File

@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
hidden_states, router_logits, False, False)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
hidden_states, router_logits, False, False)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
hidden_states, router_logits, False, False)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")