[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
|
||||
patcher6 = patch(
|
||||
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
|
||||
Reference in New Issue
Block a user