From 4544e99d88aed9247381a420c896d39fced69096 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 17 Feb 2025 11:42:33 +0800 Subject: [PATCH] [dist] revert communicator patch (#66) ### What this PR does / why we need it? Revert communicator patch as https://github.com/vllm-project/vllm/pull/13208 has been merged. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? test locally by https://github.com/vllm-project/vllm-ascend/pull/30#issuecomment-2650251266 Signed-off-by: MengqingCao --- vllm_ascend/communicator.py | 70 ++++++-------------------- vllm_ascend/patch/__init__.py | 18 ------- vllm_ascend/patch/patch_commnicator.py | 69 ------------------------- vllm_ascend/worker.py | 2 - 4 files changed, 14 insertions(+), 145 deletions(-) delete mode 100644 vllm_ascend/patch/__init__.py delete mode 100644 vllm_ascend/patch/patch_commnicator.py diff --git a/vllm_ascend/communicator.py b/vllm_ascend/communicator.py index afb39f7..543b639 100644 --- a/vllm_ascend/communicator.py +++ b/vllm_ascend/communicator.py @@ -14,65 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import Optional import torch import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm.distributed.device_communicators.base_device_communicator import \ + DeviceCommunicatorBase -class NPUCommunicator: +class NPUCommunicator(DeviceCommunicatorBase): - def __init__(self, group, unique_name=""): - self.group = group - self.unique_name = unique_name - self.rank = dist.get_rank(group) - self.world_size = dist.get_world_size(self.group) - self.ranks = dist.get_process_group_ranks(self.group) - global_rank = dist.get_rank() - self.rank_in_group = dist.get_group_rank(self.group, global_rank) - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - dist.all_reduce(x, group=self.group) - return x - - def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1): - # NOTE: We assume that the input tensor is on the same device across - # all the ranks. - # NOTE: `dst` is the local rank of the destination rank. - # Allocate output tensor. - if self.rank_in_group == dst: - gather_list = [ - torch.empty_like(input_) for _ in range(self.world_size) - ] - else: - gather_list = None - # Gather. - dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # NOTE: we have to use concat-style all-gather here, - # stack-style all-gather has compatibility issues with - # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * self.world_size, ) + input_size[1:] - # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - dist.all_gather_into_tensor(output_tensor, input_, group=self.group) - # Reshape - output_tensor = output_tensor.reshape((self.world_size, ) + input_size) - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + # init device according to local rank + local_rank = dist.get_rank(device_group) + self.device = torch.device(f"npu:{local_rank}") diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py deleted file mode 100644 index f03d4b4..0000000 --- a/vllm_ascend/patch/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from vllm_ascend.patch import patch_commnicator # noqa diff --git a/vllm_ascend/patch/patch_commnicator.py b/vllm_ascend/patch/patch_commnicator.py deleted file mode 100644 index 15a8563..0000000 --- a/vllm_ascend/patch/patch_commnicator.py +++ /dev/null @@ -1,69 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is used to monkey patch communicator in vllm to support ascend. -# Remove this file when vllm support by -# https://github.com/vllm-project/vllm/pull/11324. - -import torch -import vllm -from vllm.utils import resolve_obj_by_qualname - - -class GroupCoordinatorPatch(vllm.distributed.parallel_state.GroupCoordinator): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.device = torch.device(f"npu:{self.local_rank}") - - from vllm.platforms import current_platform - device_comm_cls = resolve_obj_by_qualname( - current_platform.get_device_communicator_cls()) - # we have checked and ensure that reusing tpu tag here is fine. - use_custom_device = kwargs.get("use_tpu_communicator", False) - if use_custom_device and self.world_size > 1: - self.communicator = device_comm_cls(group=self.device_group, - unique_name=self.unique_name) - - def all_reduce(self, input_): - # Bypass the function if we are using only 1 device. - if self.world_size == 1: - return input_ - - return self.communicator.all_reduce(input_) - - def gather(self, input_, dst=0, dim=-1): - # Bypass the function if we are using only 1 device. - if self.world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - - return self.communicator.gather(input_, dst, dim) - - def all_gather(self, input_, dim=-1): - # Bypass the function if we are using only 1 device. - if self.world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - return self.communicator.all_gather(input_, dim) - - -vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py index cecff11..c5884e3 100644 --- a/vllm_ascend/worker.py +++ b/vllm_ascend/worker.py @@ -457,8 +457,6 @@ def init_worker_distributed_environment( backend: str = "hccl") -> None: """Initialize the distributed environment.""" set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - # register communicator patch before init dist env - from vllm_ascend import patch # noqa: F401 init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend)