### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren <user@SorendeMac-mini.local>
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
HCCLLibrary,
|
||||
aclrtStream_t,
|
||||
buffer_type,
|
||||
hcclComm_t,
|
||||
hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum,
|
||||
hcclUniqueId,
|
||||
)
|
||||
from vllm_ascend.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
group: ProcessGroup | StatelessProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
library_path: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -52,7 +56,8 @@ class PyHcclCommunicator:
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.HCCL, (
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group.")
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group."
|
||||
)
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
@@ -113,8 +118,7 @@ class PyHcclCommunicator:
|
||||
# `torch.npu.device` is a context manager that changes the
|
||||
# current npu device to the specified one
|
||||
with torch.npu.device(device):
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
@@ -123,43 +127,48 @@ class PyHcclCommunicator:
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# hccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
|
||||
)
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
aclrtStream_t(stream.npu_stream))
|
||||
self.hccl.hcclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
aclrtStream_t(stream.npu_stream),
|
||||
)
|
||||
return out_tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
self.hccl.hcclBroadcast(buffer, tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, aclrtStream_t(stream.npu_stream))
|
||||
self.hccl.hcclBroadcast(
|
||||
buffer,
|
||||
tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
aclrtStream_t(stream.npu_stream),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user