Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)
This commit is contained in:
@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.library
|
||||
|
||||
from sglang.srt.utils import is_hpu
|
||||
from sglang.srt.utils import is_hip, is_hpu
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
||||
|
||||
if not is_hpu():
|
||||
if use_vllm_custom_allreduce:
|
||||
# Remove vllm dependency for custom allreduce on ROCm
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
try:
|
||||
import vllm._C
|
||||
except ImportError as e:
|
||||
@@ -56,7 +56,7 @@ def hint_on_error(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
if use_vllm_custom_allreduce:
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
# custom ar
|
||||
def init_custom_ar(
|
||||
ipc_tensors: List[torch.Tensor],
|
||||
@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
|
||||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
else:
|
||||
# custom ar
|
||||
def init_custom_ar(
|
||||
rank_id: int,
|
||||
world_size: int,
|
||||
rank_data_base: torch.Tensor,
|
||||
buffers: List[int],
|
||||
tmp_result_buffers: List[int],
|
||||
barrier_in: List[int],
|
||||
barrier_out: List[int],
|
||||
) -> int:
|
||||
return sgl_kernel.ops.init_custom_reduce(
|
||||
rank_id,
|
||||
world_size,
|
||||
rank_data_base,
|
||||
buffers,
|
||||
tmp_result_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
if is_hip():
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.custom_reduce(fa, inp, out)
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return sgl_kernel.ops.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.ops.custom_dispose(fa)
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.ops.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return sgl_kernel.ops.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return sgl_kernel.ops.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# custom ar
|
||||
def init_custom_ar(
|
||||
rank_id: int,
|
||||
world_size: int,
|
||||
rank_data_base: torch.Tensor,
|
||||
buffers: List[int],
|
||||
tmp_result_buffers: List[int],
|
||||
barrier_in: List[int],
|
||||
barrier_out: List[int],
|
||||
) -> int:
|
||||
return sgl_kernel.ops.init_custom_reduce(
|
||||
rank_id,
|
||||
world_size,
|
||||
rank_data_base,
|
||||
buffers,
|
||||
tmp_result_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.custom_reduce(fa, inp, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.ops.custom_dispose(fa)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
|
||||
|
||||
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
|
||||
gpu_p2p_access_check,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,14 +28,27 @@ if is_cuda():
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if is_hip():
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_gpu_board_info,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
ops.meta_size()
|
||||
else:
|
||||
import sgl_kernel
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
# For CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -47,37 +60,62 @@ _R = TypeVar("_R")
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
if torch.version.hip:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
amdsmi_shut_down()
|
||||
else:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if is_hip():
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
return True
|
||||
else:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
_MAX_CAR_SIZE = 8192 * 1024
|
||||
if is_hip():
|
||||
# crossover is at 16MB buffer size for ROCm
|
||||
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024,
|
||||
max_size=_MAX_CAR_SIZE,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@@ -185,12 +226,9 @@ class CustomAllreduce:
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
if is_cuda():
|
||||
assert is_cuda()
|
||||
if is_cuda() or is_hip():
|
||||
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
||||
|
||||
full_nvlink = is_full_nvlink(physical_device_ids)
|
||||
else:
|
||||
full_nvlink = False
|
||||
if world_size > 2 and not full_nvlink:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
@@ -201,7 +239,8 @@ class CustomAllreduce:
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
if not _can_p2p(rank, world_size):
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not is_hip() and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
@@ -214,7 +253,7 @@ class CustomAllreduce:
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
@@ -237,35 +276,56 @@ class CustomAllreduce:
|
||||
)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
else:
|
||||
# From TensorRT-LLM getMaxRequiredWorkspaceSize
|
||||
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
|
||||
if is_hip():
|
||||
# meta data buffers need to be "uncached" for signal on MI200
|
||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||
self.buffer = torch.empty(
|
||||
max_size, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
handle = ops.get_meta_buffer_ipc_handle(self.meta)
|
||||
shard_data = (
|
||||
bytes(handle), # ipc handle to base ptr
|
||||
0, # offset of base ptr
|
||||
)
|
||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
||||
)
|
||||
self.register_buffer(self.buffer)
|
||||
self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
|
||||
else:
|
||||
# From TensorRT-LLM getMaxRequiredWorkspaceSize
|
||||
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
|
||||
|
||||
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
|
||||
self.barrier_max_size = 8 * (36 + 2) * 8
|
||||
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
|
||||
self.barrier_max_size = 8 * (36 + 2) * 8
|
||||
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
|
||||
max_size, group=group
|
||||
)
|
||||
self.rank_data_base = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
|
||||
max_size, group=group
|
||||
)
|
||||
self.rank_data_base = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(
|
||||
self.barrier_max_size, group=group
|
||||
)
|
||||
|
||||
self._ptr = ops.init_custom_ar(
|
||||
rank,
|
||||
world_size,
|
||||
self.rank_data_base,
|
||||
self.buffer_ptrs,
|
||||
self.tmp_result_buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
rank,
|
||||
world_size,
|
||||
self.rank_data_base,
|
||||
self.buffer_ptrs,
|
||||
self.tmp_result_buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
self.disabled = False
|
||||
|
||||
@staticmethod
|
||||
@@ -316,23 +376,69 @@ class CustomAllreduce:
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
def _get_ipc_meta(self, inp: torch.Tensor):
|
||||
# _share_cuda_() doesn't accept meta buffer not allocated from
|
||||
# PyTorch cache allocator, use direct HIP call to get IPC handle
|
||||
handle = ops.get_meta_buffer_ipc_handle(inp)
|
||||
shard_data = (
|
||||
bytes(handle), # ipc handle to base ptr
|
||||
0, # offset of base ptr
|
||||
)
|
||||
return self._gather_ipc_meta(shard_data)
|
||||
|
||||
def _gather_ipc_meta(self, shard_data):
|
||||
# Note: don't use `[[None]] * self.world_size` here
|
||||
# because it will create a list of the same reference
|
||||
all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
|
||||
all_data[self.rank][0] = shard_data
|
||||
|
||||
ranks = dist.get_process_group_ranks(group=self.group)
|
||||
ranks.sort()
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
# we cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
|
||||
handles = []
|
||||
offsets = []
|
||||
for i in range(len(all_data)):
|
||||
handles.append(all_data[i][0][0]) # type: ignore
|
||||
offsets.append(all_data[i][0][1]) # type: ignore
|
||||
return handles, offsets
|
||||
|
||||
def register_buffer(self, inp: torch.Tensor):
|
||||
handles, offsets = self._get_ipc_meta(inp)
|
||||
ops.register_buffer(self._ptr, inp, handles, offsets)
|
||||
|
||||
def register_graph_buffers(self):
|
||||
if is_hip():
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
else:
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [
|
||||
[None, None] for _ in range(dist.get_world_size(group=self.group))
|
||||
]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
@@ -345,11 +451,22 @@ class CustomAllreduce:
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
if ops.use_vllm_custom_allreduce and not is_hip():
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
if is_hip():
|
||||
if self.full_nvlink:
|
||||
if self.world_size == 8:
|
||||
if self.MSCCL:
|
||||
return False
|
||||
else:
|
||||
return inp_size < self.max_size
|
||||
else:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
if self.world_size == 2:
|
||||
return (
|
||||
inp_size < self.max_size
|
||||
@@ -364,6 +481,21 @@ class CustomAllreduce:
|
||||
|
||||
return False
|
||||
|
||||
# all reduce, assuming inp tensor is IPC registered with register_buffer,
|
||||
# or, in the context of cuda graphs, register_graph_buffers
|
||||
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.all_reduce_reg(self._ptr, inp, out)
|
||||
return out
|
||||
|
||||
# all reduce, assuming inp tensor is NOT IPC registered
|
||||
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
||||
return out
|
||||
|
||||
def all_reduce(
|
||||
self,
|
||||
inp: torch.Tensor,
|
||||
@@ -397,13 +529,23 @@ class CustomAllreduce:
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=True)
|
||||
if is_hip():
|
||||
return self.all_reduce_reg(input)
|
||||
else:
|
||||
return self.all_reduce(input, registered=True)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
return self.all_reduce(input, registered=False)
|
||||
if is_hip():
|
||||
# note: outside of cuda graph context,
|
||||
# custom allreduce incurs a cost of cudaMemcpy, which should
|
||||
# be small(<=1% of overall latency) compared to the performance
|
||||
# gains of using custom kernels
|
||||
return self.all_reduce_unreg(input)
|
||||
else:
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
@@ -411,7 +553,7 @@ class CustomAllreduce:
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
self.free_shared_buffer(self.meta_ptrs)
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
else:
|
||||
elif is_cuda():
|
||||
self.free_shared_buffer(self.buffer_ptrs)
|
||||
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
|
||||
self.free_shared_buffer(self.barrier_in_ptrs)
|
||||
|
||||
@@ -44,6 +44,7 @@ include_dirs = [
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension_rocm.cc",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/custom_all_reduce.hip",
|
||||
]
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
|
||||
@@ -1,74 +1,138 @@
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
ctypes.CDLL(
|
||||
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
|
||||
mode=ctypes.RTLD_GLOBAL,
|
||||
)
|
||||
|
||||
from sgl_kernel.ops import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
cublas_grouped_gemm,
|
||||
custom_dispose,
|
||||
custom_reduce,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_reduce,
|
||||
int8_scaled_mm,
|
||||
lightning_attention_decode,
|
||||
min_p_sampling_from_probs,
|
||||
moe_align_block_size,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
sampling_scaling_penalties,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
|
||||
from .version import __version__
|
||||
|
||||
__all__ = [
|
||||
"apply_rope_with_cos_sin_cache_inplace",
|
||||
"bmm_fp8",
|
||||
"cublas_grouped_gemm",
|
||||
"custom_dispose",
|
||||
"custom_reduce",
|
||||
"fp8_blockwise_scaled_mm",
|
||||
"fp8_scaled_mm",
|
||||
"fused_add_rmsnorm",
|
||||
"gelu_and_mul",
|
||||
"gelu_tanh_and_mul",
|
||||
"gemma_fused_add_rmsnorm",
|
||||
"gemma_rmsnorm",
|
||||
"get_graph_buffer_ipc_meta",
|
||||
"init_custom_reduce",
|
||||
"int8_scaled_mm",
|
||||
"lightning_attention_decode",
|
||||
"min_p_sampling_from_probs",
|
||||
"moe_align_block_size",
|
||||
"register_graph_buffers",
|
||||
"rmsnorm",
|
||||
"sampling_scaling_penalties",
|
||||
"silu_and_mul",
|
||||
"top_k_renorm_prob",
|
||||
"top_k_top_p_sampling_from_probs",
|
||||
"top_p_renorm_prob",
|
||||
"tree_speculative_sampling_target_only",
|
||||
"build_tree_kernel_efficient",
|
||||
"build_tree_kernel",
|
||||
"sgl_per_token_group_quant_fp8",
|
||||
]
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.ops import (
|
||||
all_reduce_reg,
|
||||
all_reduce_unreg,
|
||||
allocate_meta_buffer,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
dispose,
|
||||
fp8_scaled_mm,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
get_meta_buffer_ipc_handle,
|
||||
init_custom_ar,
|
||||
int8_scaled_mm,
|
||||
lightning_attention_decode,
|
||||
meta_size,
|
||||
min_p_sampling_from_probs,
|
||||
moe_align_block_size,
|
||||
register_buffer,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
sampling_scaling_penalties,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"all_reduce_reg",
|
||||
"all_reduce_unreg",
|
||||
"allocate_meta_buffer",
|
||||
"apply_rope_with_cos_sin_cache_inplace",
|
||||
"bmm_fp8",
|
||||
"dispose",
|
||||
"fp8_scaled_mm",
|
||||
"fused_add_rmsnorm",
|
||||
"gelu_and_mul",
|
||||
"gelu_tanh_and_mul",
|
||||
"gemma_fused_add_rmsnorm",
|
||||
"gemma_rmsnorm",
|
||||
"get_graph_buffer_ipc_meta",
|
||||
"get_meta_buffer_ipc_handle",
|
||||
"init_custom_ar",
|
||||
"int8_scaled_mm",
|
||||
"lightning_attention_decode",
|
||||
"meta_size",
|
||||
"min_p_sampling_from_probs",
|
||||
"moe_align_block_size",
|
||||
"register_buffer",
|
||||
"register_graph_buffers",
|
||||
"rmsnorm",
|
||||
"sampling_scaling_penalties",
|
||||
"silu_and_mul",
|
||||
"top_k_renorm_prob",
|
||||
"top_k_top_p_sampling_from_probs",
|
||||
"top_p_renorm_prob",
|
||||
]
|
||||
else:
|
||||
from sgl_kernel.ops import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
cublas_grouped_gemm,
|
||||
custom_dispose,
|
||||
custom_reduce,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_reduce,
|
||||
int8_scaled_mm,
|
||||
lightning_attention_decode,
|
||||
min_p_sampling_from_probs,
|
||||
moe_align_block_size,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
sampling_scaling_penalties,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"apply_rope_with_cos_sin_cache_inplace",
|
||||
"bmm_fp8",
|
||||
"cublas_grouped_gemm",
|
||||
"custom_dispose",
|
||||
"custom_reduce",
|
||||
"fp8_blockwise_scaled_mm",
|
||||
"fp8_scaled_mm",
|
||||
"fused_add_rmsnorm",
|
||||
"gelu_and_mul",
|
||||
"gelu_tanh_and_mul",
|
||||
"gemma_fused_add_rmsnorm",
|
||||
"gemma_rmsnorm",
|
||||
"get_graph_buffer_ipc_meta",
|
||||
"init_custom_reduce",
|
||||
"int8_scaled_mm",
|
||||
"lightning_attention_decode",
|
||||
"min_p_sampling_from_probs",
|
||||
"moe_align_block_size",
|
||||
"register_graph_buffers",
|
||||
"rmsnorm",
|
||||
"sampling_scaling_penalties",
|
||||
"silu_and_mul",
|
||||
"top_k_renorm_prob",
|
||||
"top_k_top_p_sampling_from_probs",
|
||||
"top_p_renorm_prob",
|
||||
"tree_speculative_sampling_target_only",
|
||||
"build_tree_kernel_efficient",
|
||||
"build_tree_kernel",
|
||||
"sgl_per_token_group_quant_fp8",
|
||||
]
|
||||
|
||||
180
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip
Normal file
180
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip
Normal file
@@ -0,0 +1,180 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/hip/Exceptions.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
|
||||
// fake pointer type, must match fptr_t type in ops.h
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
hipIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
hipStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, hipMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handles =
|
||||
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
|
||||
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
|
||||
return {handles, std::move(offsets)};
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
|
||||
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
|
||||
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
|
||||
inp.data_ptr()));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
torch::Tensor allocate_meta_buffer(int64_t size) {
|
||||
auto device_index = c10::hip::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
void* buffer;
|
||||
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
AT_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(hipStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(torch::kI8)
|
||||
.device(torch::kCUDA, device_index);
|
||||
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> get_device_bdf(int dev) {
|
||||
char busIdStr[] = "0000:00:00.0";
|
||||
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
|
||||
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
|
||||
bdf.resize(bdf.size() - 1); // remove trailing NULL
|
||||
return bdf;
|
||||
}
|
||||
554
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh
Normal file
554
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh
Normal file
@@ -0,0 +1,554 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
hipError_t e = cmd; \
|
||||
if (e != hipSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
|
||||
};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
struct __align__(16) RankData {
|
||||
const void* ptrs[8];
|
||||
};
|
||||
#else
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
#endif
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) {
|
||||
return a += b;
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
#ifdef USE_ROCM
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
#else
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
#endif
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles,
|
||||
const std::vector<int64_t>& offsets, int rank, bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal* rank_sg;
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal*)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle),
|
||||
hipIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(hipIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (hipPointerGetAttribute(&base_ptr,
|
||||
#ifdef USE_ROCM
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#else
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#endif
|
||||
(hipDeviceptr_t)ptr) != hipSuccess)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error("Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(hipStream_t stream, T* input, T* output, int size,
|
||||
#ifndef USE_ROCM
|
||||
int threads = 512, int block_limit = 36){
|
||||
#else
|
||||
int threads = 512, int block_limit = 16) {
|
||||
#endif
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
hipStreamCaptureStatus status;
|
||||
CUDACHECK(hipStreamIsCapturing(stream, &status));
|
||||
if (status == hipStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = ::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
hipLaunchKernelGGL((name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \
|
||||
size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
}; // namespace vllm
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
||||
@@ -34,8 +34,23 @@ limitations under the License.
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
|
||||
// trt_reduce
|
||||
using fptr_t = int64_t;
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets);
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
torch::Tensor allocate_meta_buffer(int64_t size);
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
|
||||
#else
|
||||
// trt_reduce
|
||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
|
||||
const std::vector<fptr_t>& barrier_out);
|
||||
@@ -44,6 +59,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
|
||||
// moe_align_block_size
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
|
||||
|
||||
@@ -64,28 +64,79 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
)
|
||||
|
||||
|
||||
def init_custom_reduce(
|
||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||
):
|
||||
return torch.ops.sgl_kernels.init_custom_ar(
|
||||
if torch.version.hip is not None:
|
||||
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernels.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops.sgl_kernels.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernels.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernels.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# trt_reduce
|
||||
def init_custom_reduce(
|
||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||
)
|
||||
):
|
||||
return torch.ops.sgl_kernels.init_custom_ar(
|
||||
rank_id,
|
||||
num_devices,
|
||||
rank_data,
|
||||
buffers,
|
||||
tmp_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
|
||||
def custom_dispose(fa):
|
||||
torch.ops.sgl_kernels.dispose(fa)
|
||||
|
||||
def custom_dispose(fa):
|
||||
torch.ops.sgl_kernels.dispose(fa)
|
||||
def custom_reduce(fa, inp, out):
|
||||
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa):
|
||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def custom_reduce(fa, inp, out):
|
||||
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
||||
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa):
|
||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
|
||||
def register_graph_buffers(fa, handles, offsets):
|
||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
def register_graph_buffers(fa, handles, offsets):
|
||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
|
||||
@@ -19,6 +19,37 @@ limitations under the License.
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
// Custom all-reduce kernels
|
||||
m.def(
|
||||
"init_custom_ar(Tensor meta, Tensor rank_data, "
|
||||
"str[] handles, int[] offsets, int rank, "
|
||||
"bool full_nvlink) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
|
||||
|
||||
m.def(
|
||||
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
|
||||
"()");
|
||||
m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
|
||||
|
||||
m.def("dispose", &dispose);
|
||||
|
||||
m.def("meta_size", &meta_size);
|
||||
|
||||
m.def(
|
||||
"register_buffer(int fa, Tensor t, str[] handles, "
|
||||
"int[] offsets) -> ()");
|
||||
m.impl("register_buffer", torch::kCUDA, ®ister_buffer);
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
m.def("allocate_meta_buffer", &allocate_meta_buffer);
|
||||
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
|
||||
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
|
||||
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
|
||||
|
||||
// moe_align_block_size
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
|
||||
Reference in New Issue
Block a user