diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 52bfe13..e73e66a 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -62,7 +62,7 @@ # inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy # platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed # How: -# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup` +# Call `vllm_ascend.distributed.parallel_state` method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup` # Related PR (if no, explain why): no related PR, we want add this ability into vllm # Future Plan: # Remove those patch when vllm merged them @@ -73,7 +73,7 @@ # call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support # stateless process group initialize method # How: -# Call platform method `platform_has_backend_register` to judge if there is a stateless process group initialize +# rewrite stateless_init_torch_distributed_process_group to judge if there is a stateless process group initialize # method and call platform method `platform_register_backend` to initialize them # Related PR (if no, explain why): no related PR, we want add this ability into vllm # Future Plan: diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index ce43836..5dd5c66 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -42,8 +42,9 @@ def ascend_destroy_model_parallel(): if _DP: _DP.destroy() _DP = None - from vllm.platforms import current_platform - current_platform.destroy_platform_model_parallel() + from vllm_ascend.distributed.parallel_state import \ + destory_ascend_model_parallel + destory_ascend_model_parallel() def 
ascend_stateless_init_torch_distributed_process_group( @@ -100,7 +101,6 @@ def ascend_stateless_init_torch_distributed_process_group( group_rank, group_size, ) - from vllm.platforms import current_platform if backend == "gloo": from torch.distributed.distributed_c10d import ProcessGroupGloo backend_class = ProcessGroupGloo(prefix_store, @@ -120,8 +120,18 @@ def ascend_stateless_init_torch_distributed_process_group( backend_options) backend_type = ProcessGroup.BackendType.NCCL device = torch.device("cuda") - elif current_platform.platform_has_backend_register(): - current_platform.platform_register_backend() + elif backend == "hccl": + from torch.distributed import is_hccl_available + assert is_hccl_available() + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, + backend_options) + device = torch.device("npu") + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + pg._register_backend(device, backend_type, backend_class) return pg else: raise RuntimeError(f"Unsupported torch distributed backend: {backend}") diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 6055a85..cea4fe6 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -226,29 +226,3 @@ class NPUPlatform(Platform): model configuration. 
""" return True - - @classmethod - def destroy_platform_model_parallel(cls) -> None: - from vllm_ascend.distributed.parallel_state import \ - destory_ascend_model_parallel - destory_ascend_model_parallel() - - @classmethod - def platform_has_backend_register(cls) -> bool: - return True - - @classmethod - def platform_register_backend(cls, pg, prefix_store, group_rank, - group_size, backend_options, - timeout) -> None: - from torch.distributed import ProcessGroup, is_hccl_available - assert is_hccl_available() - from torch_npu._C._distributed_c10d import ProcessGroupHCCL - backend_options = ProcessGroupHCCL.Options() - backend_options._timeout = timeout - backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, - backend_options) - device = torch.device("npu") - backend_class._set_sequence_number_for_group() - backend_type = ProcessGroup.BackendType.CUSTOM - pg._register_backend(device, backend_type, backend_class)