Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)
@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library

from sglang.srt.utils import is_hpu
from sglang.srt.utils import is_hip, is_hpu

logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

if not is_hpu():
    if use_vllm_custom_allreduce:
    # Remove vllm dependency for custom allreduce on ROCm
    if use_vllm_custom_allreduce and not is_hip():
        try:
            import vllm._C
        except ImportError as e:
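For orientation, the hunk above routes ROCm builds away from the vLLM custom-allreduce extension: the vllm._C path is now taken only when USE_VLLM_CUSTOM_ALLREDUCE is enabled and is_hip() is false. The sketch below is not part of the patch; it merely restates that selection, and parse_bool_env is a hypothetical helper shown because os.environ.get returns a raw string (or the default) rather than a parsed boolean.

import os

def parse_bool_env(name: str, default: bool = True) -> bool:
    # Illustrative helper (not in the patch): env vars are strings, so map
    # common "off" spellings to False instead of relying on truthiness.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("0", "false", "no", "off", "")

def select_backend(use_vllm_custom_allreduce: bool, hip: bool) -> str:
    # Mirrors the gating in the diff: ROCm always takes the sgl-kernel path.
    if use_vllm_custom_allreduce and not hip:
        return "vllm._C custom_ar ops"
    return "sgl_kernel.ops custom allreduce"

if __name__ == "__main__":
    flag = parse_bool_env("USE_VLLM_CUSTOM_ALLREDUCE")
    print(select_backend(flag, hip=True))   # sgl_kernel.ops custom allreduce
    print(select_backend(flag, hip=False))  # vllm._C custom_ar ops when the flag is on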
@@ -56,7 +56,7 @@ def hint_on_error(fn):
    return wrapper


if use_vllm_custom_allreduce:
if use_vllm_custom_allreduce and not is_hip():
    # custom ar
    def init_custom_ar(
        ipc_tensors: List[torch.Tensor],
@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

else:
    # custom ar
    def init_custom_ar(
        rank_id: int,
        world_size: int,
        rank_data_base: torch.Tensor,
        buffers: List[int],
        tmp_result_buffers: List[int],
        barrier_in: List[int],
        barrier_out: List[int],
    ) -> int:
        return sgl_kernel.ops.init_custom_reduce(
            rank_id,
            world_size,
            rank_data_base,
            buffers,
            tmp_result_buffers,
            barrier_in,
            barrier_out,
        )
    if is_hip():

    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        sgl_kernel.ops.custom_reduce(fa, inp, out)
        def init_custom_ar(
            meta: torch.Tensor,
            rank_data: torch.Tensor,
            handles: List[str],
            offsets: List[int],
            rank: int,
            full_nvlink: bool,
        ) -> int:
            return sgl_kernel.ops.init_custom_ar(
                meta, rank_data, handles, offsets, rank, full_nvlink
            )

    def dispose(fa: int) -> None:
        sgl_kernel.ops.custom_dispose(fa)
        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
            sgl_kernel.ops.all_reduce_reg(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
        def all_reduce_unreg(
            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
        ) -> None:
            sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
        def dispose(fa: int) -> None:
            sgl_kernel.ops.dispose(fa)

        def meta_size() -> int:
            return sgl_kernel.ops.meta_size()

        def register_buffer(
            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
        ) -> None:
            return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)

        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)

        def register_graph_buffers(
            fa: int, handles: List[str], offsets: List[List[int]]
        ) -> None:
            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)

        def allocate_meta_buffer(size: int) -> torch.Tensor:
            return sgl_kernel.ops.allocate_meta_buffer(size)

        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
            return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)

    else:
        # custom ar
        def init_custom_ar(
            rank_id: int,
            world_size: int,
            rank_data_base: torch.Tensor,
            buffers: List[int],
            tmp_result_buffers: List[int],
            barrier_in: List[int],
            barrier_out: List[int],
        ) -> int:
            return sgl_kernel.ops.init_custom_reduce(
                rank_id,
                world_size,
                rank_data_base,
                buffers,
                tmp_result_buffers,
                barrier_in,
                barrier_out,
            )

        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
            sgl_kernel.ops.custom_reduce(fa, inp, out)

        def dispose(fa: int) -> None:
            sgl_kernel.ops.custom_dispose(fa)

        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)

        def register_graph_buffers(
            fa: int, handles: List[List[int]], offsets: List[List[int]]
        ) -> None:
            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
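As a reading aid for the restructured branch above, the mapping below lists which sgl_kernel op each wrapper name resolves to on the two non-vLLM paths. It is a plain summary drawn from the diff, not code from the patch; the dictionary names are illustrative.

# Orientation sketch (not part of the patch): wrapper name -> sgl_kernel op per platform branch.
ROCM_CUSTOM_AR_OPS = {
    "init_custom_ar": "sgl_kernel.ops.init_custom_ar",            # IPC meta-buffer based init
    "all_reduce_reg": "sgl_kernel.ops.all_reduce_reg",             # input already IPC-registered
    "all_reduce_unreg": "sgl_kernel.ops.all_reduce_unreg",         # staged through a registered buffer
    "dispose": "sgl_kernel.ops.dispose",
    "meta_size": "sgl_kernel.ops.meta_size",
    "register_buffer": "sgl_kernel.ops.register_buffer",
    "get_graph_buffer_ipc_meta": "sgl_kernel.ops.get_graph_buffer_ipc_meta",
    "register_graph_buffers": "sgl_kernel.ops.register_graph_buffers",
    "allocate_meta_buffer": "sgl_kernel.ops.allocate_meta_buffer",
    "get_meta_buffer_ipc_handle": "sgl_kernel.ops.get_meta_buffer_ipc_handle",
}

CUDA_SGL_KERNEL_OPS = {
    "init_custom_ar": "sgl_kernel.ops.init_custom_reduce",         # pointer-list based init
    "all_reduce": "sgl_kernel.ops.custom_reduce",
    "dispose": "sgl_kernel.ops.custom_dispose",
    "get_graph_buffer_ipc_meta": "sgl_kernel.ops.get_graph_buffer_ipc_meta",
    "register_graph_buffers": "sgl_kernel.ops.register_graph_buffers",
}

if __name__ == "__main__":
    # Wrappers that exist only on the ROCm branch of this module.
    print(sorted(set(ROCM_CUSTOM_AR_OPS) - set(CUDA_SGL_KERNEL_OPS)))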
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
    gpu_p2p_access_check,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

logger = logging.getLogger(__name__)

@@ -28,14 +28,27 @@ if is_cuda():
    except ImportError as e:
        logger.warning("Failed to import pynvml with %r", e)

if is_hip():
    try:
        from amdsmi import (
            AmdSmiException,
            amdsmi_get_gpu_board_info,
            amdsmi_get_processor_handles,
            amdsmi_init,
            amdsmi_shut_down,
            amdsmi_topo_get_link_type,
        )
    except ImportError as e:
        logger.warning("Failed to import amdsmi with %r", e)

try:
    if ops.use_vllm_custom_allreduce:
    if ops.use_vllm_custom_allreduce and not is_hip():
        ops.meta_size()
    else:
        import sgl_kernel
    custom_ar = True
except Exception:
    # For AMD GPUs and CPUs
    # For CPUs
    custom_ar = False

logger = logging.getLogger(__name__)
@@ -47,37 +60,62 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        if torch.version.hip:
            try:
                amdsmi_init()
                return fn(*args, **kwargs)
            finally:
                amdsmi_shut_down()
        else:
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()

    return wrapper


@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
    """
    query if the set of gpus are fully connected by nvlink (1 hop)
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
    for i, handle in enumerate(handles):
        for j, peer_handle in enumerate(handles):
            if i < j:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    if is_hip():
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped."
                    )
                    return False
    return True
        return True
    else:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped."
                        )
                        return False
        return True


def _can_p2p(rank: int, world_size: int) -> bool:
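The new is_full_nvlink gains a ROCm branch that walks every GPU pair and requires a one-hop XGMI link (amdsmi reports type 2 for XGMI). Below is a self-contained sketch of that pairwise check with a made-up topology table standing in for amdsmi_topo_get_link_type; only the hops == 1 and type == 2 conditions are taken from the patch.

from itertools import combinations

# Hypothetical stand-in for amdsmi_topo_get_link_type: a symmetric table of
# {"hops": ..., "type": ...} entries; type 2 denotes XGMI, as in the diff.
FAKE_TOPOLOGY = {
    frozenset({0, 1}): {"hops": 1, "type": 2},
    frozenset({0, 2}): {"hops": 1, "type": 2},
    frozenset({1, 2}): {"hops": 2, "type": 2},  # one weak link breaks "fully connected"
}

def is_full_xgmi(device_ids):
    # True only if every GPU pair is exactly one XGMI hop apart.
    for a, b in combinations(device_ids, 2):
        link = FAKE_TOPOLOGY.get(frozenset({a, b}), {"hops": 99, "type": 0})
        if link["hops"] != 1 or link["type"] != 2:
            return False
    return True

if __name__ == "__main__":
    print(is_full_xgmi([0, 1]))     # True
    print(is_full_xgmi([0, 1, 2]))  # False: pair (1, 2) needs two hops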
@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):


class CustomAllreduce:

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
    _MAX_CAR_SIZE = 8192 * 1024
    if is_hip():
        # crossover is at 16MB buffer size for ROCm
        _MAX_CAR_SIZE = 2 * 8192 * 1024

    # max_size: max supported allreduce size
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_size=8192 * 1024,
        max_size=_MAX_CAR_SIZE,
    ) -> None:
        """
        Args:
@@ -185,12 +226,9 @@ class CustomAllreduce:
        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        if is_cuda():
            assert is_cuda()
        if is_cuda() or is_hip():
            full_nvlink = is_full_nvlink(physical_device_ids, world_size)

            full_nvlink = is_full_nvlink(physical_device_ids)
        else:
            full_nvlink = False
        if world_size > 2 and not full_nvlink:
            logger.warning(
                "Custom allreduce is disabled because it's not supported on"

@@ -201,7 +239,8 @@ class CustomAllreduce:
        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        if not _can_p2p(rank, world_size):
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
        if not is_hip() and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
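Taken together, the two hunks above change the enablement checks in __init__: the full-link probe now also runs on ROCm, and the P2P probe is skipped there because XGMI-connected GPUs always allow P2P. Below is a hedged condensation of that decision chain, reusing the _SUPPORTED_WORLD_SIZES constant shown earlier; the function name and argument names are illustrative, not the class's API.

def custom_allreduce_enabled(
    world_size: int,
    full_link: bool,   # NVLink on CUDA, XGMI on ROCm
    p2p_ok: bool,      # result of the (cached) _can_p2p probe
    hip: bool,         # True on ROCm builds
    supported_world_sizes=(2, 4, 6, 8),
) -> bool:
    # Illustrative condensation of the gating in CustomAllreduce.__init__ after this patch.
    if world_size not in supported_world_sizes:
        return False
    # More than two GPUs without a full NVLink/XGMI mesh: not worth it.
    if world_size > 2 and not full_link:
        return False
    # The P2P probe is skipped on ROCm, where XGMI-connected GPUs always allow P2P.
    if not hip and not p2p_ok:
        return False
    return True

if __name__ == "__main__":
    print(custom_allreduce_enabled(8, full_link=True, p2p_ok=False, hip=True))   # True
    print(custom_allreduce_enabled(8, full_link=False, p2p_ok=True, hip=False))  # False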
@@ -214,7 +253,7 @@ class CustomAllreduce:
        self.world_size = world_size
        self.full_nvlink = full_nvlink

        if ops.use_vllm_custom_allreduce:
        if ops.use_vllm_custom_allreduce and not is_hip():
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
@@ -237,35 +276,56 @@ class CustomAllreduce:
            )
            ops.register_buffer(self._ptr, self.buffer_ptrs)
        else:
            # From TensorRT-LLM getMaxRequiredWorkspaceSize
            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
            if is_hip():
                # meta data buffers need to be "uncached" for signal on MI200
                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                self.buffer = torch.empty(
                    max_size, dtype=torch.uint8, device=self.device
                )
                handle = ops.get_meta_buffer_ipc_handle(self.meta)
                shard_data = (
                    bytes(handle), # ipc handle to base ptr
                    0, # offset of base ptr
                )
                handles, offsets = self._gather_ipc_meta(shard_data)
                self.rank_data = torch.empty(
                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
                )
                self._ptr = ops.init_custom_ar(
                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
                )
                self.register_buffer(self.buffer)
                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
            else:
                # From TensorRT-LLM getMaxRequiredWorkspaceSize
                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]

            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
            self.barrier_max_size = 8 * (36 + 2) * 8
                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
                self.barrier_max_size = 8 * (36 + 2) * 8

            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
                max_size, group=group
            )
            self.rank_data_base = torch.empty(
                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
            )
            self.barrier_in_ptrs = self.create_shared_buffer(
                self.barrier_max_size, group=group
            )
            self.barrier_out_ptrs = self.create_shared_buffer(
                self.barrier_max_size, group=group
            )
                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
                    max_size, group=group
                )
                self.rank_data_base = torch.empty(
                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
                )
                self.barrier_in_ptrs = self.create_shared_buffer(
                    self.barrier_max_size, group=group
                )
                self.barrier_out_ptrs = self.create_shared_buffer(
                    self.barrier_max_size, group=group
                )

            self._ptr = ops.init_custom_ar(
                rank,
                world_size,
                self.rank_data_base,
                self.buffer_ptrs,
                self.tmp_result_buffer_ptrs,
                self.barrier_in_ptrs,
                self.barrier_out_ptrs,
            )
                self._ptr = ops.init_custom_ar(
                    rank,
                    world_size,
                    self.rank_data_base,
                    self.buffer_ptrs,
                    self.tmp_result_buffer_ptrs,
                    self.barrier_in_ptrs,
                    self.barrier_out_ptrs,
                )
        self.disabled = False

    @staticmethod
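The ROCm branch of __init__ above allocates an "uncached" meta buffer, exchanges its IPC handle across ranks, initializes the kernel handle, and registers a staging buffer. The sketch below replays only that order of calls with invented stand-ins (SimpleNamespace lambdas), so it runs without ROCm or sgl_kernel installed; sizes, handles, and return values are made up.

from types import SimpleNamespace

# Hypothetical stand-ins so the call sequence runs anywhere.
fake_ops = SimpleNamespace(
    meta_size=lambda: 1024,
    allocate_meta_buffer=lambda size: bytearray(size),  # an "uncached" HIP buffer in reality
    get_meta_buffer_ipc_handle=lambda buf: b"ipc-handle",
    init_custom_ar=lambda meta, rank_data, handles, offsets, rank, full_link: 0xBEEF,
    register_buffer=lambda ptr, buf, handles, offsets: None,
)

def init_rocm_custom_ar(rank: int, world_size: int, max_size: int, full_link: bool):
    # Sketch of the ROCm branch of CustomAllreduce.__init__ (order of calls only).
    meta = fake_ops.allocate_meta_buffer(fake_ops.meta_size() + max_size)  # signal + staging area
    buffer = bytearray(max_size)                                           # torch.empty(...) in reality
    handle = fake_ops.get_meta_buffer_ipc_handle(meta)
    # Every rank contributes (ipc_handle, offset); gathered across the group in the real code.
    handles = [bytes(handle)] * world_size
    offsets = [0] * world_size
    rank_data = bytearray(8 * 1024 * 1024)
    ptr = fake_ops.init_custom_ar(meta, rank_data, handles, offsets, rank, full_link)
    fake_ops.register_buffer(ptr, buffer, handles, offsets)
    return ptr

if __name__ == "__main__":
    print(hex(init_rocm_custom_ar(rank=0, world_size=8, max_size=2 * 8192 * 1024, full_link=True)))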
@@ -316,23 +376,69 @@ class CustomAllreduce:
            if not self.disabled:
                self.register_graph_buffers()

    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        logger.info("Registering %d cuda graph addresses", len(offset))
        # We cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
        all_data[self.rank] = [handle, offset]
        ranks = sorted(dist.get_process_group_ranks(group=self.group))
    def _get_ipc_meta(self, inp: torch.Tensor):
        # _share_cuda_() doesn't accept meta buffer not allocated from
        # PyTorch cache allocator, use direct HIP call to get IPC handle
        handle = ops.get_meta_buffer_ipc_handle(inp)
        shard_data = (
            bytes(handle), # ipc handle to base ptr
            0, # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data

        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(
                all_data[i], src=rank, group=self.group, device="cpu"
            )
        # Unpack list of tuples to tuple of lists.
        handles = [d[0] for d in all_data] # type: ignore
        offsets = [d[1] for d in all_data] # type: ignore
        ops.register_graph_buffers(self._ptr, handles, offsets)

        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0]) # type: ignore
            offsets.append(all_data[i][0][1]) # type: ignore
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        ops.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        if is_hip():
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
            logger.info("Registering %d cuda graph addresses", len(offset))
            ops.register_graph_buffers(self._ptr, handles, offsets)
        else:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            logger.info("Registering %d cuda graph addresses", len(offset))
            # We cannot directly use `dist.all_gather_object` here
            # because it is incompatible with `gloo` backend under inference mode.
            # see https://github.com/pytorch/pytorch/issues/126032 for details.
            all_data = [
                [None, None] for _ in range(dist.get_world_size(group=self.group))
            ]
            all_data[self.rank] = [handle, offset]
            ranks = sorted(dist.get_process_group_ranks(group=self.group))
            for i, rank in enumerate(ranks):
                dist.broadcast_object_list(
                    all_data[i], src=rank, group=self.group, device="cpu"
                )
            # Unpack list of tuples to tuple of lists.
            handles = [d[0] for d in all_data] # type: ignore
            offsets = [d[1] for d in all_data] # type: ignore
            ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        if self.disabled:
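_gather_ipc_meta above works around dist.all_gather_object by pre-filling each rank's slot and issuing one broadcast_object_list per rank. The single-process simulation below mimics that exchange with a fake broadcast so the unpacking step can be seen end to end; every name here is a stand-in, not the patch's API.

from typing import Any, List, Optional, Tuple

def fake_broadcast_object_list(bucket: List[Any], src_value: Any) -> None:
    # Stand-in for dist.broadcast_object_list: the source rank's value lands in bucket[0].
    bucket[0] = src_value

def gather_ipc_meta_sim(world_size: int, rank: int, shard_data: Tuple[bytes, int]):
    # Single-process simulation of the _gather_ipc_meta pattern from the patch.
    # Don't use [[None]] * world_size: that would alias one inner list world_size times.
    all_data: List[Optional[Any]] = [[None] for _ in range(world_size)]
    all_data[rank][0] = shard_data
    # One broadcast per rank; here every "remote" rank contributes a made-up shard.
    for i in range(world_size):
        src_shard = shard_data if i == rank else (f"ipc-handle-{i}".encode(), 0)
        fake_broadcast_object_list(all_data[i], src_shard)
    # Unpack the (handle, offset) tuples into parallel lists, as the real code does.
    handles = [entry[0][0] for entry in all_data]
    offsets = [entry[0][1] for entry in all_data]
    return handles, offsets

if __name__ == "__main__":
    print(gather_ipc_meta_sim(4, rank=0, shard_data=(b"ipc-handle-0", 0)))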
@@ -345,11 +451,22 @@ class CustomAllreduce:
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if ops.use_vllm_custom_allreduce:
        if ops.use_vllm_custom_allreduce and not is_hip():
            if self.world_size == 2 or self.full_nvlink:
                return inp_size < self.max_size
            return False

        if is_hip():
            if self.full_nvlink:
                if self.world_size == 8:
                    if self.MSCCL:
                        return False
                    else:
                        return inp_size < self.max_size
                else:
                    return inp_size < self.max_size
            return False

        if self.world_size == 2:
            return (
                inp_size < self.max_size
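The ROCm arm of should_custom_ar above requires a full XGMI mesh, defers to RCCL's MSCCL algorithms on a fully connected 8-GPU node, and otherwise only checks that the message fits in the registered buffer. The predicate below is an illustrative restatement of that branch with sample inputs; the 16 MiB figure reuses the ROCm _MAX_CAR_SIZE from earlier in the diff.

def rocm_should_custom_ar(inp_bytes: int, max_size: int, world_size: int,
                          full_xgmi: bool, msccl_enabled: bool) -> bool:
    # Illustrative restatement of the is_hip() branch of should_custom_ar.
    if not full_xgmi:
        return False
    if world_size == 8 and msccl_enabled:
        # Prefer RCCL's MSCCL algorithms on a fully connected 8-GPU node.
        return False
    return inp_bytes < max_size

if __name__ == "__main__":
    sixteen_mib = 2 * 8192 * 1024
    print(rocm_should_custom_ar(1 << 20, sixteen_mib, 8, full_xgmi=True, msccl_enabled=True))   # False
    print(rocm_should_custom_ar(1 << 20, sixteen_mib, 8, full_xgmi=True, msccl_enabled=False))  # True
    print(rocm_should_custom_ar(1 << 20, sixteen_mib, 4, full_xgmi=True, msccl_enabled=True))   # True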
@@ -364,6 +481,21 @@ class CustomAllreduce:

        return False

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def all_reduce(
        self,
        inp: torch.Tensor,
@@ -397,13 +529,23 @@ class CustomAllreduce:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                return self.all_reduce(input, registered=True)
                if is_hip():
                    return self.all_reduce_reg(input)
                else:
                    return self.all_reduce(input, registered=True)
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
            return self.all_reduce(input, registered=False)
            if is_hip():
                # note: outside of cuda graph context,
                # custom allreduce incurs a cost of cudaMemcpy, which should
                # be small(<=1% of overall latency) compared to the performance
                # gains of using custom kernels
                return self.all_reduce_unreg(input)
            else:
                return self.all_reduce(input, registered=False)

    def close(self):
        if not self.disabled and self._ptr:
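custom_all_reduce now distinguishes three situations on ROCm: graph capture in progress (buffers were IPC-registered up front, so all_reduce_reg), graph warm-up (only mimic the out-of-place allocation), and eager execution (all_reduce_unreg, which stages through the registered scratch buffer at the cost of one extra device copy). A small dispatch sketch, assuming exactly those three inputs, follows; the function is illustrative only.

def rocm_allreduce_path(capturing_graph: bool, stream_capturing: bool) -> str:
    # Illustrative summary of custom_all_reduce() dispatch on ROCm after this patch.
    if capturing_graph:
        if stream_capturing:
            # Inside HIP graph capture the input buffer is already IPC-registered.
            return "all_reduce_reg"
        # Warm-up pass: only mimic the out-of-place allocation.
        return "torch.empty_like"
    # Eager mode stages through the registered scratch buffer (one extra copy).
    return "all_reduce_unreg"

if __name__ == "__main__":
    for args in [(True, True), (True, False), (False, False)]:
        print(args, "->", rocm_allreduce_path(*args))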
@@ -411,7 +553,7 @@ class CustomAllreduce:
            if ops.use_vllm_custom_allreduce:
                self.free_shared_buffer(self.meta_ptrs)
                self.free_shared_buffer(self.buffer_ptrs)
            else:
            elif is_cuda():
                self.free_shared_buffer(self.buffer_ptrs)
                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
                self.free_shared_buffer(self.barrier_in_ptrs)