diff --git a/tests/multicard/test_pyhccl_distributed.py b/tests/multicard/test_pyhccl_distributed.py new file mode 100644 index 0000000..1b35d0f --- /dev/null +++ b/tests/multicard/test_pyhccl_distributed.py @@ -0,0 +1,111 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def distributed_run(fn, world_size):
    """Run ``fn`` once per rank in ``world_size`` child processes.

    Every child receives, as its single positional argument, a dict of
    environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``) describing its
    place in the job.

    Args:
        fn: callable taking the environment dict; executed in each child.
        world_size: number of processes (and ranks) to launch.

    Raises:
        AssertionError: if any child finishes with a non-zero exit code.
    """
    number_of_processes = world_size
    processes: list[multiprocessing.Process] = []
    for i in range(number_of_processes):
        env: dict[str, str] = {
            'RANK': str(i),
            'LOCAL_RANK': str(i),
            'WORLD_SIZE': str(number_of_processes),
            'LOCAL_WORLD_SIZE': str(number_of_processes),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Report which rank failed instead of a bare "assert False".
    for rank, p in enumerate(processes):
        assert p.exitcode == 0, (
            f"rank {rank} exited with code {p.exitcode}")


def worker_fn_wrapper(fn):
    """Decorate a zero-argument worker so it can be launched per rank.

    The returned function takes the environment dict built by
    ``distributed_run``, applies it, binds the process to its NPU, initializes
    the distributed environment with the ``hccl`` backend and then calls
    ``fn()``.
    """
    import functools

    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function.
    #
    # `functools.wraps` copies `__name__`/`__qualname__` from `fn`, so the
    # module-level name refers to an object whose qualname matches it. Without
    # it the closure is a "<locals>" object, which cannot be pickled by
    # reference (needed by the "spawn" start method).
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"npu:{local_rank}")
        torch.npu.set_device(device)
        init_distributed_environment(backend="hccl")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones and check every element equals world_size."""
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).npu(pyhccl_comm.rank)
    tensor = pyhccl_comm.all_reduce(tensor)
    # make sure the collective has finished before reading the result
    torch.npu.synchronize()
    assert torch.all(tensor == pyhccl_comm.world_size).cpu().item()


# def test_pyhccl():
#     distributed_run(worker_fn, 2)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every rank in turn and check each received tensor."""
    # Test broadcast for every root rank.
    # Essentially this is an all-gather operation.
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pyhccl_comm.device)
        for _ in range(pyhccl_comm.world_size)
    ]
    # This rank's own slot carries its rank number as payload.
    recv_tensors[pyhccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pyhccl_comm.device) * pyhccl_comm.rank

    for i in range(pyhccl_comm.world_size):
        pyhccl_comm.broadcast(recv_tensors[i], src=i)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.npu.synchronize()
        assert torch.all(recv_tensors[i] == i).cpu().item()


# def test_pyhccl_broadcast():
#     distributed_run(broadcast_worker_fn, 4)
def test_hcclGetUniqueId():
    """Smoke test: on NPU 0 the HCCL binding loads and returns a unique id."""
    torch.npu.set_device(0)
    assert HCCLLibrary().hcclGetUniqueId() is not None
class PyHcclCommunicator:
    """HCCL communicator driven through the ctypes ``HCCLLibrary`` binding.

    Attributes set by ``__init__``:
        rank, world_size: position and size within ``group``.
        group: the process group used to exchange the HCCL unique id.
        available / disabled: ``available`` is False and ``disabled`` is True
            when the group has a single rank or the HCCL library cannot be
            loaded; in that case no other attribute below is set.
        hccl: the loaded ``HCCLLibrary``.
        device: the ``torch.device`` this communicator is bound to.
        unique_id: the HCCL unique id shared by all ranks.
        comm: the HCCL communicator handle.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. A torch ``ProcessGroup``
                must not use the HCCL backend; it is only used to broadcast
                the unique id between ranks.
            device: the device to bind the PyHcclCommunicator to, given as
                an NPU index, a device string or a ``torch.device``.
            library_path: the path to the HCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """

        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.HCCL, (
                "PyHcclCommunicator should be attached to a non-HCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return

        try:
            self.hccl = HCCLLibrary(library_path)
        except Exception:
            # disable because of missing HCCL library
            # e.g. in a non-NPU environment
            # (deliberate best-effort: the communicator degrades to a no-op)
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using pyhccl")

        if isinstance(device, int):
            device = torch.device(f"npu:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        if self.rank == 0:
            # get the unique id from HCCL
            with torch.npu.device(device):
                self.unique_id = self.hccl.hcclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = hcclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        # hccl communicator and stream will use this device
        # `torch.npu.device` is a context manager that changes the
        # current npu device to the specified one
        with torch.npu.device(device):
            self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """Out-of-place all-reduce of ``in_tensor`` across the communicator.

        Args:
            in_tensor: input tensor; must live on ``self.device``.
            op: reduction operator, translated via ``hcclRedOpTypeEnum``.
            stream: stream to issue the call on; defaults to
                ``current_stream()``.

        Returns:
            A new tensor shaped like ``in_tensor`` that receives the result,
            or ``None`` when the communicator is disabled.
            NOTE(review): the call is issued on ``stream``; callers in the
            tests synchronize before reading the result.
        """
        if self.disabled:
            return None
        # hccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert in_tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {in_tensor.device}")

        out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                hcclDataTypeEnum.from_torch(in_tensor.dtype),
                                hcclRedOpTypeEnum.from_torch(op), self.comm,
                                aclrtStream_t(stream.npu_stream))
        return out_tensor

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """In-place broadcast of ``tensor`` from rank ``src``.

        Args:
            tensor: buffer on ``self.device``; it is the payload on the
                ``src`` rank and the destination on every other rank.
            src: root rank within this communicator.
            stream: stream to issue the call on; defaults to
                ``current_stream()``.

        Returns:
            None. Does nothing when the communicator is disabled.
        """
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        # HcclBroadcast takes a single buffer that serves as send buffer on
        # the root and receive buffer elsewhere, so no per-rank branch is
        # needed.
        buffer = buffer_type(tensor.data_ptr())
        self.hccl.hcclBroadcast(buffer, tensor.numel(),
                                hcclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, aclrtStream_t(stream.npu_stream))
# export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48

hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p


class hcclUniqueId(ctypes.Structure):
    """ctypes mirror of the root-info blob passed to ``HcclGetRootInfo``."""
    _fields_ = [("internal", ctypes.c_byte * 4108)]


aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

hcclDataType_t = ctypes.c_int


class hcclDataTypeEnum:
    """Integer codes of the HCCL data types and their torch dtype mapping."""
    hcclInt8 = 0
    hcclInt16 = 1
    hcclInt32 = 2
    hcclFloat16 = 3
    hcclFloat32 = 4
    hcclInt64 = 5
    hcclUint64 = 6
    hcclUint8 = 7
    hcclUint16 = 8
    hcclUint32 = 9
    hcclFloat64 = 10
    hcclBfloat16 = 11
    hcclInt128 = 12

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Return the HCCL data type code for ``dtype``.

        Raises:
            ValueError: if ``dtype`` has no mapping here.
        """
        if dtype == torch.int8:
            return cls.hcclInt8
        if dtype == torch.uint8:
            return cls.hcclUint8
        if dtype == torch.int32:
            return cls.hcclInt32
        if dtype == torch.int64:
            return cls.hcclInt64
        if dtype == torch.float16:
            return cls.hcclFloat16
        if dtype == torch.float32:
            return cls.hcclFloat32
        if dtype == torch.float64:
            return cls.hcclFloat64
        if dtype == torch.bfloat16:
            return cls.hcclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


hcclRedOp_t = ctypes.c_int


class hcclRedOpTypeEnum:
    """Integer codes of the HCCL reduction ops and their torch mapping."""
    hcclSum = 0
    hcclProd = 1
    hcclMax = 2
    hcclMin = 3

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Return the HCCL reduction code for a torch ``ReduceOp``.

        Raises:
            ValueError: if ``op`` is not SUM, PRODUCT, MAX or MIN.
        """
        if op == ReduceOp.SUM:
            return cls.hcclSum
        if op == ReduceOp.PRODUCT:
            return cls.hcclProd
        if op == ReduceOp.MAX:
            return cls.hcclMax
        if op == ReduceOp.MIN:
            return cls.hcclMin
        raise ValueError(f"Unsupported op: {op}")


@dataclass
class Function:
    """Signature of one C function to bind: symbol name, restype, argtypes."""
    name: str
    restype: Any
    argtypes: List[Any]


class HCCLLibrary:
    """ctypes binding for the subset of libhccl used by the communicator.

    Loaded libraries and their bound function tables are cached per library
    path at class level, so constructing several instances for the same path
    loads and binds only once.
    """
    exported_functions = [
        # const char* HcclGetErrorString(HcclResult code);
        Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),

        # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
        Function("HcclGetRootInfo", hcclResult_t,
                 [ctypes.POINTER(hcclUniqueId)]),

        # HcclResult HcclCommInitRootInfo(
        #   uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
        # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
        Function("HcclCommInitRootInfo", hcclResult_t, [
            ctypes.c_int,
            ctypes.POINTER(hcclUniqueId),
            ctypes.c_int,
            ctypes.POINTER(hcclComm_t),
        ]),

        # HcclResult HcclAllReduce(
        #   void *sendBuf, void *recvBuf, uint64_t count,
        #   HcclDataType dataType, HcclReduceOp op, HcclComm comm,
        #   aclrtStream stream);
        Function("HcclAllReduce", hcclResult_t, [
            buffer_type,
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            hcclRedOp_t,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclBroadcast(
        #   void *buf, uint64_t count,
        #   HcclDataType dataType, uint32_t root,
        #   HcclComm comm, aclrtStream stream);
        Function("HcclBroadcast", hcclResult_t, [
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            ctypes.c_int,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclCommDestroy(HcclComm comm);
        Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary of bound functions
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load (or reuse) the HCCL shared library and bind its functions.

        Args:
            so_file: path or soname of the HCCL library. When falsy, the
                result of ``find_hccl_library()`` is used.

        Raises:
            Exception: whatever ``ctypes.CDLL`` raises when the library
                cannot be loaded (re-raised after logging), or
                ``AttributeError`` when an expected symbol is missing.
        """

        so_file = so_file or find_hccl_library()

        try:
            # Consult the library cache itself; checking the function-table
            # cache here would reload the library whenever a previous bind
            # attempt for this path had failed half-way.
            if so_file not in HCCLLibrary.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                HCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = HCCLLibrary.path_to_library_cache[so_file]
        except Exception:
            logger.error(
                "Failed to load HCCL library from %s. "
                "It is expected if you are not running on Ascend NPUs. "
                "Otherwise, the hccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable HCCL_SO_PATH"
                " to point to the correct hccl library path.", so_file,
                platform.platform())
            raise

        if so_file not in HCCLLibrary.path_to_dict_mapping:
            _funcs: Dict[str, Any] = {}
            for func in HCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]

    def hcclGetErrorString(self, result: hcclResult_t) -> str:
        """Return the library's description of the result code ``result``."""
        message = self._funcs["HcclGetErrorString"](result)
        if message is None:
            # a NULL `const char*` comes back from ctypes as None
            return f"unknown HCCL error code {result}"
        return message.decode("utf-8")

    def HCCL_CHECK(self, result: hcclResult_t) -> None:
        """Raise ``RuntimeError`` when ``result`` is a non-zero HCCL code."""
        if result != 0:
            error_str = self.hcclGetErrorString(result)
            raise RuntimeError(f"HCCL error: {error_str}")

    def hcclGetUniqueId(self) -> hcclUniqueId:
        """Create a fresh unique id via ``HcclGetRootInfo``."""
        unique_id = hcclUniqueId()
        self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
            ctypes.byref(unique_id)))
        return unique_id

    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
                         rank: int) -> hcclComm_t:
        """Create the communicator handle for ``rank`` of ``world_size``."""
        comm = hcclComm_t()
        self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
            world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
        return comm

    def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        """Call ``HcclAllReduce`` and raise on a non-zero result."""
        # `datatype` actually should be `hcclDataType_t`
        # and `op` should be `hcclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
                      root: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        """Call ``HcclBroadcast`` and raise on a non-zero result."""
        self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
                                                     root, comm, stream))

    def hcclCommDestroy(self, comm: hcclComm_t) -> None:
        """Call ``HcclCommDestroy`` and raise on a non-zero result."""
        self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))


__all__ = [
    "HCCLLibrary",
    "hcclDataTypeEnum",
    "hcclRedOpTypeEnum",
    "hcclUniqueId",
    "hcclComm_t",
    "aclrtStream_t",
    "buffer_type",
]
def find_hccl_library() -> str:
    """Return the path or soname of the HCCL shared library to load.

    We either use the library file specified by the `HCCL_SO_PATH`
    environment variable, or we fall back to the bare soname `libhccl.so`,
    leaving it to `ctypes` to locate it.
    NOTE(review): the fallback presumably relies on the library already being
    resolvable once `torch`/`torch_npu` are imported; confirm.

    Raises:
        ValueError: if `HCCL_SO_PATH` is unset and this torch build reports
            no CANN version.
    """
    so_file = envs.HCCL_SO_PATH

    # manually load the hccl library
    if so_file:
        logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
                    so_file)
    else:
        # `torch.version.cann` does not exist on every torch build; use
        # getattr so a missing attribute leads to the intended ValueError
        # rather than an AttributeError.
        if getattr(torch.version, "cann", None) is not None:
            so_file = "libhccl.so"
        else:
            raise ValueError("HCCL only supports Ascend NPU backends.")
        logger.info("Found hccl from library %s", so_file)
    return so_file