Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)

This commit is contained in:
Hubert Lu
2025-03-02 15:19:06 -08:00
committed by GitHub
parent d3fe9bae56
commit 9cf4077294
9 changed files with 1282 additions and 195 deletions

View File

@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library
from sglang.srt.utils import is_hpu
from sglang.srt.utils import is_hip, is_hpu
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
if not is_hpu():
if use_vllm_custom_allreduce:
# Remove vllm dependency for custom allreduce on ROCm
if use_vllm_custom_allreduce and not is_hip():
try:
import vllm._C
except ImportError as e:
@@ -56,7 +56,7 @@ def hint_on_error(fn):
return wrapper
if use_vllm_custom_allreduce:
if use_vllm_custom_allreduce and not is_hip():
# custom ar
def init_custom_ar(
ipc_tensors: List[torch.Tensor],
@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
else:
# custom ar
def init_custom_ar(
rank_id: int,
world_size: int,
rank_data_base: torch.Tensor,
buffers: List[int],
tmp_result_buffers: List[int],
barrier_in: List[int],
barrier_out: List[int],
) -> int:
return sgl_kernel.ops.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
buffers,
tmp_result_buffers,
barrier_in,
barrier_out,
)
if is_hip():
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out)
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return sgl_kernel.ops.init_custom_ar(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def dispose(fa: int) -> None:
sgl_kernel.ops.custom_dispose(fa)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.all_reduce_reg(fa, inp, out)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
def dispose(fa: int) -> None:
sgl_kernel.ops.dispose(fa)
def meta_size() -> int:
return sgl_kernel.ops.meta_size()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return sgl_kernel.ops.allocate_meta_buffer(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
else:
# custom ar
def init_custom_ar(
rank_id: int,
world_size: int,
rank_data_base: torch.Tensor,
buffers: List[int],
tmp_result_buffers: List[int],
barrier_in: List[int],
barrier_out: List[int],
) -> int:
return sgl_kernel.ops.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
buffers,
tmp_result_buffers,
barrier_in,
barrier_out,
)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out)
def dispose(fa: int) -> None:
sgl_kernel.ops.custom_dispose(fa)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456

View File

@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
gpu_p2p_access_check,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
logger = logging.getLogger(__name__)
@@ -28,14 +28,27 @@ if is_cuda():
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if is_hip():
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_gpu_board_info,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
ops.meta_size()
else:
import sgl_kernel
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
# For CPUs
custom_ar = False
logger = logging.getLogger(__name__)
@@ -47,37 +60,62 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
if torch.version.hip:
try:
amdsmi_init()
return fn(*args, **kwargs)
finally:
amdsmi_shut_down()
else:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if is_hip():
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
return True
else:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024
if is_hip():
# crossover is at 16MB buffer size for ROCm
_MAX_CAR_SIZE = 2 * 8192 * 1024
# max_size: max supported allreduce size
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
max_size=_MAX_CAR_SIZE,
) -> None:
"""
Args:
@@ -185,12 +226,9 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
if is_cuda():
assert is_cuda()
if is_cuda() or is_hip():
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
full_nvlink = is_full_nvlink(physical_device_ids)
else:
full_nvlink = False
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
@@ -201,7 +239,8 @@ class CustomAllreduce:
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not is_hip() and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
@@ -214,7 +253,7 @@ class CustomAllreduce:
self.world_size = world_size
self.full_nvlink = full_nvlink
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
@@ -237,35 +276,56 @@ class CustomAllreduce:
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
else:
# From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
if is_hip():
# meta data buffers need to be "uncached" for signal on MI200
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
self.buffer = torch.empty(
max_size, dtype=torch.uint8, device=self.device
)
handle = ops.get_meta_buffer_ipc_handle(self.meta)
shard_data = (
bytes(handle), # ipc handle to base ptr
0, # offset of base ptr
)
handles, offsets = self._gather_ipc_meta(shard_data)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self._ptr = ops.init_custom_ar(
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
)
self.register_buffer(self.buffer)
self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
else:
# From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
self.barrier_max_size = 8 * (36 + 2) * 8
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
self.barrier_max_size = 8 * (36 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
max_size, group=group
)
self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.barrier_in_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.barrier_out_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
max_size, group=group
)
self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.barrier_in_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.barrier_out_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self._ptr = ops.init_custom_ar(
rank,
world_size,
self.rank_data_base,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
self._ptr = ops.init_custom_ar(
rank,
world_size,
self.rank_data_base,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
self.disabled = False
@staticmethod
@@ -316,23 +376,69 @@ class CustomAllreduce:
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
def _get_ipc_meta(self, inp: torch.Tensor):
# _share_cuda_() doesn't accept meta buffer not allocated from
# PyTorch cache allocator, use direct HIP call to get IPC handle
handle = ops.get_meta_buffer_ipc_handle(inp)
shard_data = (
bytes(handle), # ipc handle to base ptr
0, # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
# Note: don't use `[[None]] * self.world_size` here
# because it will create a list of the same reference
all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
all_data[self.rank][0] = shard_data
ranks = dist.get_process_group_ranks(group=self.group)
ranks.sort()
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0][0]) # type: ignore
offsets.append(all_data[i][0][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
if is_hip():
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
ops.register_graph_buffers(self._ptr, handles, offsets)
else:
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [
[None, None] for _ in range(dist.get_world_size(group=self.group))
]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
@@ -345,11 +451,22 @@ class CustomAllreduce:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
if is_hip():
if self.full_nvlink:
if self.world_size == 8:
if self.MSCCL:
return False
else:
return inp_size < self.max_size
else:
return inp_size < self.max_size
return False
if self.world_size == 2:
return (
inp_size < self.max_size
@@ -364,6 +481,21 @@ class CustomAllreduce:
return False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
ops.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def all_reduce(
self,
inp: torch.Tensor,
@@ -397,13 +529,23 @@ class CustomAllreduce:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True)
if is_hip():
return self.all_reduce_reg(input)
else:
return self.all_reduce(input, registered=True)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
return self.all_reduce(input, registered=False)
if is_hip():
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
return self.all_reduce_unreg(input)
else:
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
@@ -411,7 +553,7 @@ class CustomAllreduce:
if ops.use_vllm_custom_allreduce:
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
else:
elif is_cuda():
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
self.free_shared_buffer(self.barrier_in_ptrs)