### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
@@ -107,69 +107,74 @@ class hcclRedOpTypeEnum:
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class HCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* HcclGetErrorString(HcclResult code);
|
||||
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
|
||||
|
||||
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
|
||||
Function("HcclGetRootInfo", hcclResult_t,
|
||||
[ctypes.POINTER(hcclUniqueId)]),
|
||||
|
||||
Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
|
||||
# HcclResult HcclCommInitRootInfo(
|
||||
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
|
||||
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
|
||||
Function("HcclCommInitRootInfo", hcclResult_t, [
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
]),
|
||||
|
||||
Function(
|
||||
"HcclCommInitRootInfo",
|
||||
hcclResult_t,
|
||||
[
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
],
|
||||
),
|
||||
# HcclResult HcclAllReduce(
|
||||
# void *sendBuf, void *recvBuf, uint64_t count,
|
||||
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
|
||||
# aclrtStream stream);
|
||||
Function("HcclAllReduce", hcclResult_t, [
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
Function(
|
||||
"HcclAllReduce",
|
||||
hcclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
],
|
||||
),
|
||||
# HcclResult HcclBroadcast(
|
||||
# void *buf, uint64_t count,
|
||||
# HcclDataType dataType, uint32_t root,
|
||||
# HcclComm comm, aclrtStream stream);
|
||||
Function("HcclBroadcast", hcclResult_t, [
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
Function(
|
||||
"HcclBroadcast",
|
||||
hcclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
],
|
||||
),
|
||||
# HcclResult HcclCommDestroy(HcclComm comm);
|
||||
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the correspongding directory
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: str | None = None):
|
||||
so_file = so_file or find_hccl_library()
|
||||
|
||||
try:
|
||||
@@ -185,12 +190,14 @@ class HCCLLibrary:
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable HCCL_SO_PATH"
|
||||
" to point to the correct hccl library path.", so_file,
|
||||
platform.platform())
|
||||
" to point to the correct hccl library path.",
|
||||
so_file,
|
||||
platform.platform(),
|
||||
)
|
||||
raise e
|
||||
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in HCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
@@ -209,34 +216,37 @@ class HCCLLibrary:
|
||||
|
||||
def hcclGetUniqueId(self) -> hcclUniqueId:
|
||||
unique_id = hcclUniqueId()
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
|
||||
ctypes.byref(unique_id)))
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
|
||||
rank: int) -> hcclComm_t:
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t:
|
||||
comm = hcclComm_t()
|
||||
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
|
||||
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
|
||||
self.HCCL_CHECK(
|
||||
self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))
|
||||
)
|
||||
return comm
|
||||
|
||||
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
def hcclAllReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: hcclComm_t,
|
||||
stream: aclrtStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `hcclDataType_t`
|
||||
# and `op` should be `hcclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
|
||||
|
||||
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
|
||||
root: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
|
||||
root, comm, stream))
|
||||
def hcclBroadcast(
|
||||
self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
|
||||
) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream))
|
||||
|
||||
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
Reference in New Issue
Block a user