[main] [refactor] refactor common_fused_moe.py (#2706)

### What this PR does / why we need it?
1. Move the prepare/finalize operations out of `moe_comm_method` and into
`/ops/moe/fused_moe_prepare_and_finalize`.
2. Adapt to token_dispatcher in moe_comm_method
3. Move the `moe_comm_method`, `experts_selector`, `token_dispatcher`, and
`fused_moe_prepare_and_finalize` modules
to `/ops/moe`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
End-to-end (e2e) tests and unit tests (UT).

- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-08 20:09:50 +08:00
committed by GitHub
parent 1a82b16355
commit a041d4f328
21 changed files with 1052 additions and 932 deletions

View File

@@ -22,14 +22,14 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
@@ -110,11 +110,11 @@ def mock_dist_env(mocker: MockerFixture):
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
@@ -158,7 +158,7 @@ def mock_dist_env(mocker: MockerFixture):
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
yield {
@@ -562,8 +562,8 @@ class TestCumsumGroupList(TestBase):
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
@@ -629,7 +629,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -671,7 +671,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -731,7 +731,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -776,7 +776,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context")
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")