xc-llm-ascend/vllm_ascend/patch/patch_commnicator.py
Mengqing Cao b88443b6c6 [dist] fix communicator patch (#58)
### What this PR does / why we need it?
Fix the communicator patch so that parallelism works correctly.
See #52.

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-02-14 10:45:49 +08:00

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is used to monkey patch communicator in vllm to support ascend.
# Remove this file when vllm support by
# https://github.com/vllm-project/vllm/pull/11324.
import torch
import vllm
from vllm.utils import resolve_obj_by_qualname

class GroupCoordinatorPatch(vllm.distributed.parallel_state.GroupCoordinator):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = torch.device(f"npu:{self.local_rank}")

        from vllm.platforms import current_platform
        device_comm_cls = resolve_obj_by_qualname(
            current_platform.get_device_communicator_cls())

        # We have checked and ensured that reusing the tpu tag here is fine.
        use_custom_device = kwargs.get("use_tpu_communicator", False)
        if use_custom_device and self.world_size > 1:
            self.communicator = device_comm_cls(group=self.device_group,
                                                unique_name=self.unique_name)

    def all_reduce(self, input_):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_
        return self.communicator.all_reduce(input_)

    def gather(self, input_, dst=0, dim=-1):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        return self.communicator.gather(input_, dst, dim)

    def all_gather(self, input_, dim=-1):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        return self.communicator.all_gather(input_, dim)
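

# Apply the patch: replace vllm's GroupCoordinator with the NPU-aware subclass
# so that groups created after this module is imported use the Ascend
# device communicator.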
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
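
For context (this is not part of the file above): a minimal sketch of how the patch takes effect, assuming the module is importable as vllm_ascend.patch.patch_commnicator and that it is imported before any distributed groups are created. The import path and timing are assumptions for illustration, not a prescribed integration point.

import vllm

# Importing the patch module executes it and rebinds
# vllm.distributed.parallel_state.GroupCoordinator to GroupCoordinatorPatch.
from vllm_ascend.patch import patch_commnicator  # assumed import path

# Groups created after this point (e.g. by initialize_model_parallel) are
# instances of the patched class, so all_reduce / gather / all_gather are
# routed through the NPU device communicator.
assert (vllm.distributed.parallel_state.GroupCoordinator
        is patch_commnicator.GroupCoordinatorPatch)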