[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -28,9 +28,10 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import \
|
||||
TokenDispatcherWithAllGather
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1]
|
||||
@@ -182,7 +183,7 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
):
|
||||
context_mock = MagicMock()
|
||||
context_mock.fused_moe_state = 0
|
||||
with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
|
||||
with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
|
||||
return_value=context_mock):
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
|
||||
@@ -282,9 +283,9 @@ def test_select_experts(
|
||||
dtype=torch.int32)
|
||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||
|
||||
with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
|
||||
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
|
||||
) as mock_native_grouped_topk, \
|
||||
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=MagicMock(weight_prefetch_method=MagicMock())):
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
@@ -318,7 +319,7 @@ def test_select_experts(
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
def test_select_experts_invalid_scoring_func(device: str):
|
||||
with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=MagicMock(weight_prefetch_method=MagicMock())), \
|
||||
pytest.raises(ValueError,
|
||||
match="Unsupported scoring function: invalid"):
|
||||
|
||||
@@ -90,9 +90,9 @@ def mock_distributed():
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
|
||||
with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
||||
with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
||||
patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
_gather_along_first_dim, async_all_to_all,
|
||||
gather_from_sequence_parallel_region)
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
@@ -24,9 +24,11 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import (
|
||||
AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
|
||||
unified_apply_mlp)
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
adapt_patch(True)
|
||||
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch(
|
||||
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -105,37 +108,37 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
|
||||
yield {
|
||||
@@ -319,8 +322,8 @@ class TestCumsumGroupList(TestBase):
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
@@ -384,7 +387,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -426,7 +429,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -486,7 +489,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -531,7 +534,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
||||
@@ -595,3 +598,39 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
@@ -4,8 +4,9 @@ import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl, MC2CommImpl)
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
|
||||
|
||||
class TestMoECommMethod(TestBase):
|
||||
@@ -24,12 +25,14 @@ class TestMoECommMethod(TestBase):
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.num_global_redundant_experts = 0
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
@@ -72,12 +75,11 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
@@ -121,12 +123,14 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
@@ -163,13 +167,15 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
|
||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||
mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
|
||||
@@ -4,13 +4,12 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
|
||||
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
|
||||
|
||||
|
||||
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
class TestPrepareAndFinalize(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock FusedMoEConfig
|
||||
@@ -24,14 +23,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.original_num_experts = 8
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
@@ -39,7 +36,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
@@ -59,14 +56,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_mc2_tp_split_allgather(self, mock_all_gather,
|
||||
mock_get_forward_context, mock_tp_rank,
|
||||
@@ -76,7 +71,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
@@ -108,13 +103,13 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
@@ -130,15 +125,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
@@ -169,14 +164,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_allgather_prepare_finalize(self, mock_get_forward_context,
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
||||
return_value=False)
|
||||
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
||||
mock_get_forward_context,
|
||||
mock_tp_all_reduce, mock_get_dp_group):
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
@@ -198,7 +194,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = mock_dp_group
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
@@ -232,13 +228,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape[0], 3)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce,
|
||||
mock_get_dp_group):
|
||||
@@ -266,7 +260,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
|
||||
# Local inputs
|
||||
hidden_states = torch.randn(3, 8)
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
|
||||
patcher6 = patch(
|
||||
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
|
||||
@@ -5,8 +5,8 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
@@ -758,7 +758,7 @@ class TestSelectExperts(TestBase):
|
||||
self.mock_ctx = MagicMock()
|
||||
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||
patcher = patch(
|
||||
'vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=self.mock_ctx)
|
||||
self.addCleanup(patcher.stop)
|
||||
patcher.start()
|
||||
@@ -831,7 +831,7 @@ class TestSelectExperts(TestBase):
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk')
|
||||
@patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
|
||||
@@ -87,7 +87,8 @@ def set_ascend_forward_context(
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
|
||||
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import \
|
||||
get_moe_comm_method
|
||||
forward_context.moe_comm_type = moe_comm_type
|
||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.linear import AscendLinearBase
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm_ascend.ops.common_fused_moe # noqa
|
||||
import vllm_ascend.ops.fused_moe.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.register_custom_ops # noqa
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
|
||||
@@ -35,8 +35,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
determine_default_log2phy_map)
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
|
||||
is_enable_nz, npu_stream_switch,
|
||||
shared_expert_dp_enabled,
|
||||
@@ -24,15 +24,13 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
TokenDispatcherWithMoge)
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
|
||||
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2, TokenDispatcherWithMoge)
|
||||
|
||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||
|
||||
@@ -59,8 +57,7 @@ class MoECommMethod(ABC):
|
||||
self.moe_config = moe_config
|
||||
|
||||
self.token_dispatcher = self._get_token_dispatcher()
|
||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||
)
|
||||
self.prepare_finalize = self._get_prepare_finalize()
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@@ -71,7 +68,7 @@ class MoECommMethod(ABC):
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
@@ -80,8 +77,9 @@ class MoECommMethod(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
hidden_states = self.fused_moe_prepare_finalize.finalize(
|
||||
hidden_states, reduce_results, context_metadata)
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states,
|
||||
reduce_results,
|
||||
context_metadata)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
@@ -169,9 +167,9 @@ class MoECommMethod(ABC):
|
||||
"_get_token_dispatcher function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
def _get_prepare_finalize(self):
|
||||
raise NotImplementedError(
|
||||
"_get_fused_moe_prepare_finalize function not implemented.")
|
||||
"_get_prepare_finalize function not implemented.")
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
@@ -205,8 +203,8 @@ class AllGatherCommImpl(MoECommMethod):
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
|
||||
@@ -222,8 +220,8 @@ class MC2CommImpl(MoECommMethod):
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithMC2()
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
|
||||
@@ -242,8 +240,8 @@ class AlltoAllCommImpl(MoECommMethod):
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
|
||||
class NaiveMulticastCommImpl(MoECommMethod):
|
||||
@@ -271,5 +269,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
class PrepareAndFinalize(ABC):
|
||||
"""
|
||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||
in distributed environments. Subclasses implement specific communication strategies
|
||||
@@ -103,7 +103,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
raise NotImplementedError("Finalize function not implemented.")
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
@@ -195,7 +195,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
|
||||
All2All and share the same finalize method.
|
||||
@@ -275,7 +275,7 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
|
||||
There are two sets of prepare and finalize:
|
||||
@@ -429,7 +429,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using Naive Multicast (point-to-point broadcast).
|
||||
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
|
||||
@@ -28,7 +28,7 @@ import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
is_hierarchical_communication_enabled)
|
||||
@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
|
||||
oproj_tp_enable)
|
||||
|
||||
@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
|
||||
vllm_version_is)
|
||||
|
||||
|
||||
@@ -538,8 +538,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
|
||||
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
||||
AscendSharedFusedMoE)
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
|
||||
AscendSharedFusedMoE)
|
||||
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
|
||||
AscendQuantRMSNorm, AscendRMSNorm)
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
|
||||
Reference in New Issue
Block a user