[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #3) (#5978)

### What this PR does / why we need it?
This PR converts the files below to `ruff format` (batch #3 of the repo-wide migration); the diff is mechanical reformatting only (see the annotation sketch after the table).

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
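
The bulk of each diff is mechanical: ruff rewrites `typing.Optional`/`Union`/`List`/`Dict` annotations into built-in generics and PEP 604 unions, and reflows long signatures and call sites to one argument per line. A minimal before/after sketch of the annotation change (illustrative names, not from this PR; runs on Python 3.10+):

```python
from typing import Dict, List, Optional  # pre-ruff spelling


def lookup(table: Dict[str, List[int]], key: Optional[str] = None) -> Optional[List[int]]:
    return table.get(key) if key is not None else None


# post-ruff spelling: built-in generics (3.9+) and PEP 604 unions (3.10+)
def lookup_new(table: dict[str, list[int]], key: str | None = None) -> list[int] | None:
    return table.get(key) if key is not None else None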

### Does this PR introduce _any_ user-facing change?

No. This is a formatting-only change; runtime behavior is unchanged.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
Commit 7faa6878a6 (parent 4e53c1d900), authored by SILONG ZENG on 2026-01-24 22:10:18 +08:00 and committed via GitHub.
9 changed files with 953 additions and 1148 deletions; three of the diffs are reproduced below.

#### `vllm_ascend/distributed/device_communicators/npu_communicator.py`

```diff
@@ -14,61 +14,50 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import List, Optional
-
 import torch
 import torch.distributed as dist
-from vllm.distributed.device_communicators.base_device_communicator import \
-    DeviceCommunicatorBase
+from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
-
-    def __init__(self,
-                 cpu_group: dist.ProcessGroup,
-                 device: Optional[torch.device] = None,
-                 device_group: Optional[dist.ProcessGroup] = None,
-                 unique_name: str = ""):
+    def __init__(
+        self,
+        cpu_group: dist.ProcessGroup,
+        device: torch.device | None = None,
+        device_group: dist.ProcessGroup | None = None,
+        unique_name: str = "",
+    ):
         super().__init__(cpu_group, device, device_group, unique_name)
         # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
         # init device according to rank
         self.device = torch.npu.current_device()
 
-    def all_to_all(self,
-                   input_: torch.Tensor,
-                   scatter_dim: int = 0,
-                   gather_dim: int = -1,
-                   scatter_sizes: Optional[List[int]] = None,
-                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
+    def all_to_all(
+        self,
+        input_: torch.Tensor,
+        scatter_dim: int = 0,
+        gather_dim: int = -1,
+        scatter_sizes: list[int] | None = None,
+        gather_sizes: list[int] | None = None,
+    ) -> torch.Tensor:
         if scatter_dim < 0:
             scatter_dim += input_.dim()
         if gather_dim < 0:
             gather_dim += input_.dim()
 
         if scatter_sizes is not None and gather_sizes is not None:
-            input_list = [
-                t.contiguous()
-                for t in torch.split(input_, scatter_sizes, scatter_dim)
-            ]
+            input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
             output_list = []
             tensor_shape_base = input_list[self.rank].size()
             for i in range(self.world_size):
                 tensor_shape = list(tensor_shape_base)
                 tensor_shape[gather_dim] = gather_sizes[i]
-                output_list.append(
-                    torch.empty(tensor_shape,
-                                dtype=input_.dtype,
-                                device=input_.device))
+                output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
         else:
-            input_list = [
-                t.contiguous() for t in torch.tensor_split(
-                    input_, self.world_size, scatter_dim)
-            ]
-            output_list = [
-                torch.empty_like(input_list[i]) for i in range(self.world_size)
-            ]
+            input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
+            output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)]
 
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
```
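
The `all_to_all` reflow above does not change behavior. For reviewers, here is a single-process sketch of the same split/concat bookkeeping; `world_size`, `rank`, and the sizes are illustrative stand-ins, and the actual `dist.all_to_all` exchange is elided because it needs a process group:

```python
import torch

# Stand-ins for communicator state; no distributed group involved.
world_size, rank = 4, 0
x = torch.arange(12.0).reshape(12, 1)

# Even case (scatter_sizes/gather_sizes are None): tensor_split along
# scatter_dim yields one contiguous chunk per rank.
input_list = [t.contiguous() for t in torch.tensor_split(x, world_size, 0)]
assert len(input_list) == world_size

# Uneven case: explicit per-rank sizes, mirroring the scatter_sizes branch.
scatter_sizes = [5, 3, 2, 2]
input_list = [t.contiguous() for t in torch.split(x, scatter_sizes, 0)]
assert [t.shape[0] for t in input_list] == scatter_sizes

# After dist.all_to_all, received chunks are concatenated along gather_dim;
# locally we can only demonstrate the concat shape arithmetic.
out = torch.cat(input_list, dim=0)
assert out.shape == x.shape
```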

#### `vllm_ascend/distributed/device_communicators/pyhccl.py`

```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
-from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import logger
 
 from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
-    HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
-    hcclRedOpTypeEnum, hcclUniqueId)
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataTypeEnum,
+    hcclRedOpTypeEnum,
+    hcclUniqueId,
+)
 from vllm_ascend.utils import current_stream
 
 
 class PyHcclCommunicator:
-
     def __init__(
         self,
-        group: Union[ProcessGroup, StatelessProcessGroup],
-        device: Union[int, str, torch.device],
-        library_path: Optional[str] = None,
+        group: ProcessGroup | StatelessProcessGroup,
+        device: int | str | torch.device,
+        library_path: str | None = None,
     ):
         """
         Args:
@@ -52,7 +56,8 @@ class PyHcclCommunicator:
         if not isinstance(group, StatelessProcessGroup):
             assert dist.is_initialized()
             assert dist.get_backend(group) != dist.Backend.HCCL, (
-                "PyHcclCommunicator should be attached to a non-HCCL group.")
+                "PyHcclCommunicator should be attached to a non-HCCL group."
+            )
             # note: this rank is the rank in the group
             self.rank = dist.get_rank(group)
             self.world_size = dist.get_world_size(group)
@@ -113,8 +118,7 @@ class PyHcclCommunicator:
         # `torch.npu.device` is a context manager that changes the
         # current npu device to the specified one
         with torch.npu.device(device):
-            self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
-                self.world_size, self.unique_id, self.rank)
+            self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)
 
             stream = current_stream()
             # A small all_reduce for warmup.
@@ -123,43 +127,48 @@ class PyHcclCommunicator:
             stream.synchronize()
             del data
 
-    def all_reduce(self,
-                   in_tensor: torch.Tensor,
-                   op: ReduceOp = ReduceOp.SUM,
-                   stream=None) -> torch.Tensor:
+    def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
         if self.disabled:
             return None
         # hccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
         assert in_tensor.device == self.device, (
-            f"this hccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {in_tensor.device}")
+            f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
+        )
 
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
             stream = current_stream()
-        self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
-                                buffer_type(out_tensor.data_ptr()),
-                                in_tensor.numel(),
-                                hcclDataTypeEnum.from_torch(in_tensor.dtype),
-                                hcclRedOpTypeEnum.from_torch(op), self.comm,
-                                aclrtStream_t(stream.npu_stream))
+        self.hccl.hcclAllReduce(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            hcclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            aclrtStream_t(stream.npu_stream),
+        )
         return out_tensor
 
     def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
-            f"this hccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
+        )
        if stream is None:
             stream = current_stream()
         if src == self.rank:
             buffer = buffer_type(tensor.data_ptr())
         else:
             buffer = buffer_type(tensor.data_ptr())
-        self.hccl.hcclBroadcast(buffer, tensor.numel(),
-                                hcclDataTypeEnum.from_torch(tensor.dtype), src,
-                                self.comm, aclrtStream_t(stream.npu_stream))
+        self.hccl.hcclBroadcast(
+            buffer,
+            tensor.numel(),
+            hcclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            aclrtStream_t(stream.npu_stream),
+        )
```
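
The `buffer_type(tensor.data_ptr())` calls that ruff reflowed are the core trick in `pyhccl.py`: a raw pointer into tensor storage handed to a C library. A CPU-only sketch of the same pattern, using libc `memset` as a stand-in for an HCCL collective (assumes Linux; `buffer_type` is assumed to alias `ctypes.c_void_p` as in the wrapper):

```python
import ctypes

import torch

buffer_type = ctypes.c_void_p  # assumed to match the wrapper's alias

libc = ctypes.CDLL(None)  # on Linux, resolves libc symbols from the process
libc.memset.restype = ctypes.c_void_p
libc.memset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]

t = torch.ones(8, dtype=torch.uint8)  # CPU tensor; uint8 so numel == bytes
libc.memset(buffer_type(t.data_ptr()), 0, t.numel())
print(t)  # all zeros: the C call wrote directly through the tensor's storage
```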

#### `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py`

```diff
@@ -18,7 +18,7 @@
 import ctypes
 import platform
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import torch
 from torch.distributed import ReduceOp
@@ -107,69 +107,74 @@ class hcclRedOpTypeEnum:
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]
 
 
 class HCCLLibrary:
     exported_functions = [
         # const char* HcclGetErrorString(HcclResult code);
         Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
         # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
-        Function("HcclGetRootInfo", hcclResult_t,
-                 [ctypes.POINTER(hcclUniqueId)]),
+        Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
         # HcclResult HcclCommInitRootInfo(
         #   uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
         # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
-        Function("HcclCommInitRootInfo", hcclResult_t, [
-            ctypes.c_int,
-            ctypes.POINTER(hcclUniqueId),
-            ctypes.c_int,
-            ctypes.POINTER(hcclComm_t),
-        ]),
+        Function(
+            "HcclCommInitRootInfo",
+            hcclResult_t,
+            [
+                ctypes.c_int,
+                ctypes.POINTER(hcclUniqueId),
+                ctypes.c_int,
+                ctypes.POINTER(hcclComm_t),
+            ],
+        ),
         # HcclResult HcclAllReduce(
         #   void *sendBuf, void *recvBuf, uint64_t count,
         #   HcclDataType dataType, HcclReduceOp op, HcclComm comm,
         #   aclrtStream stream);
-        Function("HcclAllReduce", hcclResult_t, [
-            buffer_type,
-            buffer_type,
-            ctypes.c_size_t,
-            hcclDataType_t,
-            hcclRedOp_t,
-            hcclComm_t,
-            aclrtStream_t,
-        ]),
+        Function(
+            "HcclAllReduce",
+            hcclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                hcclDataType_t,
+                hcclRedOp_t,
+                hcclComm_t,
+                aclrtStream_t,
+            ],
+        ),
         # HcclResult HcclBroadcast(
         #   void *buf, uint64_t count,
         #   HcclDataType dataType, uint32_t root,
         #   HcclComm comm, aclrtStream stream);
-        Function("HcclBroadcast", hcclResult_t, [
-            buffer_type,
-            ctypes.c_size_t,
-            hcclDataType_t,
-            ctypes.c_int,
-            hcclComm_t,
-            aclrtStream_t,
-        ]),
+        Function(
+            "HcclBroadcast",
+            hcclResult_t,
+            [
+                buffer_type,
+                ctypes.c_size_t,
+                hcclDataType_t,
+                ctypes.c_int,
+                hcclComm_t,
+                aclrtStream_t,
+            ],
+        ),
         # HcclResult HcclCommDestroy(HcclComm comm);
         Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
     ]
 
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}
 
     # class attribute to store the mapping from library path
     # to the corresponding directory
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}
 
-    def __init__(self, so_file: Optional[str] = None):
+    def __init__(self, so_file: str | None = None):
         so_file = so_file or find_hccl_library()
 
         try:
@@ -185,12 +190,14 @@ class HCCLLibrary:
                 "or it does not support the current platform %s. "
                 "If you already have the library, please set the "
                 "environment variable HCCL_SO_PATH"
-                " to point to the correct hccl library path.", so_file,
-                platform.platform())
+                " to point to the correct hccl library path.",
+                so_file,
+                platform.platform(),
+            )
             raise e
 
         if so_file not in HCCLLibrary.path_to_dict_mapping:
-            _funcs: Dict[str, Any] = {}
+            _funcs: dict[str, Any] = {}
             for func in HCCLLibrary.exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
@@ -209,34 +216,37 @@ class HCCLLibrary:
     def hcclGetUniqueId(self) -> hcclUniqueId:
         unique_id = hcclUniqueId()
-        self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
-            ctypes.byref(unique_id)))
+        self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id)))
         return unique_id
 
-    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
-                         rank: int) -> hcclComm_t:
+    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t:
         comm = hcclComm_t()
-        self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
-            world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
+        self.HCCL_CHECK(
+            self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))
+        )
         return comm
 
-    def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
-                      count: int, datatype: int, op: int, comm: hcclComm_t,
-                      stream: aclrtStream_t) -> None:
+    def hcclAllReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        comm: hcclComm_t,
+        stream: aclrtStream_t,
+    ) -> None:
         # `datatype` actually should be `hcclDataType_t`
         # and `op` should be `hcclRedOp_t`
         # both are aliases of `ctypes.c_int`
         # when we pass int to a function, it will be converted to `ctypes.c_int`
         # by ctypes automatically
-        self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
-                                                     datatype, op, comm,
-                                                     stream))
+        self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
 
-    def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
-                      root: int, comm: hcclComm_t,
-                      stream: aclrtStream_t) -> None:
-        self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
-                                                     root, comm, stream))
+    def hcclBroadcast(
+        self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
+    ) -> None:
+        self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream))
 
     def hcclCommDestroy(self, comm: hcclComm_t) -> None:
         self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
```