add gatherep select. (#2740)
### What this PR does / why we need it?
add gatherep select.
- vLLM version: v0.10.1.1
- vLLM main:
e599e2c65e
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
22
tests/ut/ops/test_ascend_forwad_context.py
Normal file
22
tests/ut/ops/test_ascend_forwad_context.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import get_dispatcher_name
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDispatcherName(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_get_dispatcher_name(self):
|
||||||
|
result = get_dispatcher_name(1, False)
|
||||||
|
assert result == "TokenDispatcherWithAllGather"
|
||||||
|
result = get_dispatcher_name(4, False)
|
||||||
|
assert result == "TokenDispatcherWithAll2AllV"
|
||||||
|
result = get_dispatcher_name(16, True)
|
||||||
|
assert result == "TokenDispatcherWithAll2AllV"
|
||||||
|
result = get_dispatcher_name(16, False)
|
||||||
|
assert result == "TokenDispatcherWithMC2"
|
||||||
|
with mock.patch.dict(os.environ,
|
||||||
|
{"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}):
|
||||||
|
result = get_dispatcher_name(16, False)
|
||||||
|
assert result == "TokenDispatcherWithAllGather"
|
||||||
@@ -45,13 +45,12 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
|||||||
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
|
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
|
||||||
if ep_size == 1:
|
if ep_size == 1:
|
||||||
return "TokenDispatcherWithAllGather"
|
return "TokenDispatcherWithAllGather"
|
||||||
|
elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1:
|
||||||
if ep_size < 16:
|
return "TokenDispatcherWithAllGather"
|
||||||
|
elif ep_size < 16 or with_prefill:
|
||||||
return "TokenDispatcherWithAll2AllV"
|
return "TokenDispatcherWithAll2AllV"
|
||||||
|
else:
|
||||||
if with_prefill:
|
return "TokenDispatcherWithMC2"
|
||||||
return "TokenDispatcherWithAll2AllV"
|
|
||||||
return "TokenDispatcherWithMC2"
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed.parallel_state import get_ep_group
|
from vllm.distributed.parallel_state import get_ep_group
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.distributed.tensor_parallel import \
|
from vllm_ascend.distributed.tensor_parallel import \
|
||||||
gather_from_sequence_parallel_region
|
gather_from_sequence_parallel_region
|
||||||
@@ -50,6 +51,9 @@ def setup_token_dispatchers(ep_size: int, **kwargs):
|
|||||||
|
|
||||||
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
||||||
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
||||||
|
elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \
|
||||||
|
and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
||||||
|
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
||||||
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||||
elif ep_size >= 16:
|
elif ep_size >= 16:
|
||||||
|
|||||||
Reference in New Issue
Block a user