[Refactor][MoE] Reuse vLLM's all_reduce logic (#5189)
### What this PR does / why we need it?
Move all_reduce logic to AscendFusedMoE.forward, reuse vLLM's logic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -169,15 +169,12 @@ class TestPrepareAndFinalize(unittest.TestCase):

        self.assertEqual(final_result.shape[0], 2)

    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
    )
    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
           return_value=False)
    def test_allgather_prepare_finalize(self, mock_enable_sp,
                                        mock_get_forward_context,
                                        mock_tp_all_reduce,
                                        mock_get_dp_group):
        # Mock forward context
        mock_context = MagicMock()
        mock_context.max_tokens_across_dp = 6
||||
@@ -222,7 +219,5 @@ class TestPrepareAndFinalize(unittest.TestCase):

        self.assertEqual(result.shape[0], 3)

        # Test with TP all-reduce
        mock_tp_all_reduce.return_value = result
        result_with_tp = layer.finalize(h_out, reduce_results=True)
        self.assertEqual(result_with_tp.shape[0], 3)
Reference in New Issue
Block a user