diff --git a/setup.py b/setup.py
index 2553521..d278ef9 100644
--- a/setup.py
+++ b/setup.py
@@ -95,7 +95,7 @@ setup(
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Scientific/Engineering :: Information Analysis",
     ],
-    packages=find_packages(exclude=("docs", "examples", "tests*", "patch")),
+    packages=find_packages(exclude=("docs", "examples", "tests*")),
     python_requires=">=3.9",
     install_requires=get_requirements(),
     extras_require={},
diff --git a/vllm_ascend/patch/patch_commnicator.py b/vllm_ascend/patch/patch_commnicator.py
index 45f3495..15a8563 100644
--- a/vllm_ascend/patch/patch_commnicator.py
+++ b/vllm_ascend/patch/patch_commnicator.py
@@ -19,11 +19,11 @@
 # https://github.com/vllm-project/vllm/pull/11324.
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator
+import vllm
 from vllm.utils import resolve_obj_by_qualname
 
 
-class GroupCoordinatorPatch(GroupCoordinator):
+class GroupCoordinatorPatch(vllm.distributed.parallel_state.GroupCoordinator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -66,4 +66,4 @@ class GroupCoordinatorPatch(GroupCoordinator):
         return self.communicator.all_gather(input_, dim)
 
 
-GroupCoordinator = GroupCoordinatorPatch
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 242cf52..2b847de 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -88,9 +88,8 @@ class NPUPlatform(Platform):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        # Register ops and patch when setup.
+        # Register ops when setup.
         from vllm_ascend import ops  # noqa: F401
-        from vllm_ascend import patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py
index c5884e3..cecff11 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker.py
@@ -457,6 +457,8 @@ def init_worker_distributed_environment(
         backend: str = "hccl") -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+    # register communicator patch before init dist env
+    from vllm_ascend import patch  # noqa: F401
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank,
                                  backend)