diff --git a/tests/ut/ops/test_ascend_forward_context.py b/tests/ut/ops/test_ascend_forward_context.py new file mode 100644 index 0000000..17e3c6f --- /dev/null +++ b/tests/ut/ops/test_ascend_forward_context.py @@ -0,0 +1,22 @@ +import os +import unittest +from unittest import mock + +from vllm_ascend.ascend_forward_context import get_dispatcher_name + + +class TestGetDispatcherName(unittest.TestCase): + + def test_get_dispatcher_name(self): + result = get_dispatcher_name(1, False) + assert result == "TokenDispatcherWithAllGather" + result = get_dispatcher_name(4, False) + assert result == "TokenDispatcherWithAll2AllV" + result = get_dispatcher_name(16, True) + assert result == "TokenDispatcherWithAll2AllV" + result = get_dispatcher_name(16, False) + assert result == "TokenDispatcherWithMC2" + with mock.patch.dict(os.environ, + {"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}): + result = get_dispatcher_name(16, False) + assert result == "TokenDispatcherWithAllGather" diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 601f33a..31822af 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -45,13 +45,12 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str: if ep_size == 1: return "TokenDispatcherWithAllGather" - - if ep_size < 16: + elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1: + return "TokenDispatcherWithAllGather" + elif ep_size < 16 or with_prefill: return "TokenDispatcherWithAll2AllV" - - if with_prefill: - return "TokenDispatcherWithAll2AllV" - return "TokenDispatcherWithMC2" + else: + return "TokenDispatcherWithMC2" @contextmanager diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 90b2209..a5c5566 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ 
b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -28,6 +28,7 @@ import torch import torch_npu from vllm.distributed.parallel_state import get_ep_group +import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.tensor_parallel import \ gather_from_sequence_parallel_region @@ -50,6 +51,9 @@ def setup_token_dispatchers(ep_size: int, **kwargs): if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers: _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) + elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \ + and "TokenDispatcherWithAllGather" not in existing_dispatchers: + _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers: _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) elif ep_size >= 16: