[Feature] Integrate quick allreduce and select the best allreduce implementation (#6619)

Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
li haoyang
2025-07-25 11:48:42 +08:00
committed by GitHub
parent f4674df646
commit 28d4d47280
14 changed files with 2031 additions and 109 deletions

View File

@@ -1,6 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
@@ -114,6 +114,34 @@ else:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
    """Return the IPC handle of the meta buffer associated with *inp* (thin
    wrapper over the sgl_kernel allreduce extension)."""
    return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
# ROCM custom quick allreduce
def init_custom_qr(
    rank: int, world_size: int, qr_max_size: Optional[int] = None
) -> int:
    """Create a quick-allreduce (QR) context and return its opaque handle.

    NOTE: the underlying kernel takes (world_size, rank, ...), so the
    swapped argument order in the call below is deliberate.
    """
    return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
def qr_get_handle(fa: int) -> torch.Tensor:
    """Return this rank's buffer handle tensor for QR context *fa*
    (exchanged across ranks and then opened via qr_open_handles)."""
    return sgl_kernel.allreduce.qr_get_handle(fa)
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
    """Open the peer buffer handles gathered from all ranks in QR context *fa*."""
    sgl_kernel.allreduce.qr_open_handles(fa, handles)
def qr_all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    quant_level: int,
    cast_bf2half: bool,
) -> None:
    """Run a quick allreduce on *inp*, writing the result to *out*.

    Args:
        fa: handle returned by init_custom_qr.
        inp: input tensor.
        out: output tensor (same shape/dtype as the input).
        quant_level: quantization regime index (see QuickReduceRegime).
        cast_bf2half: if True, bf16 inputs are processed with fp16 kernels.
    """
    sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
def qr_destroy(fa: int) -> None:
    """Destroy QR context *fa*."""
    sgl_kernel.allreduce.qr_destroy(fa)
def qr_max_size() -> int:
    """Return the default maximum message size (in bytes) supported by
    quick allreduce."""
    return sgl_kernel.allreduce.qr_max_size()
def mscclpp_generate_unique_id() -> bytes:
    """Generate a unique id (opaque bytes) for mscclpp initialization."""
    return sgl_kernel.allreduce.mscclpp_generate_unique_id()

View File

@@ -4,18 +4,18 @@ import ctypes
import logging
import os
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check,
is_full_nvlink,
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import is_cuda, is_hip
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce and not _is_hip:
@@ -57,70 +40,6 @@ except Exception:
logger = logging.getLogger(__name__)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that brackets *fn* with GPU management-library setup/teardown:
    amdsmi on ROCm, pynvml (NVML) otherwise."""

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if _is_hip:
            try:
                amdsmi_init()
                return fn(*args, **kwargs)
            finally:
                # NOTE(review): if amdsmi_init() itself raises, shut_down still
                # runs — confirm amdsmi tolerates shutdown without init.
                amdsmi_shut_down()
        else:
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()

    return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    """Return True iff every pair of the given GPUs is directly connected
    (1 hop): XGMI on ROCm, NVLink on NVIDIA.

    NOTE(review): *world_size* is unused; kept for caller compatibility.
    """
    if _is_hip:
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        # Check each unordered pair exactly once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
        return True
    else:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped."
                        )
                        return False
        return True
def _can_p2p(rank: int, world_size: int) -> bool:
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True
def is_weak_contiguous(inp: torch.Tensor):
    """A tensor is 'weakly contiguous' when it is contiguous, or when it
    occupies its underlying storage end-to-end (no unused bytes before or
    after the data)."""
    if inp.is_contiguous():
        return True
    payload_bytes = inp.numel() * inp.element_size()
    backing_bytes = (
        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
    )
    return backing_bytes == payload_bytes
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024

View File

@@ -8,17 +8,44 @@ import pickle
import subprocess
import sys
import tempfile
from functools import wraps
from itertools import product
from typing import Dict, List, Optional, Sequence
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing_extensions import ParamSpec
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that wraps *fn* with GPU management-library init/shutdown:
    amdsmi on ROCm, pynvml (NVML) on NVIDIA."""

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if _is_hip:
            try:
                amdsmi_init()
                return fn(*args, **kwargs)
            finally:
                # NOTE(review): amdsmi_shut_down() also runs when
                # amdsmi_init() raised — confirm that is safe.
                amdsmi_shut_down()
        else:
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()

    return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    """Return True iff all listed GPUs are pairwise directly connected
    (1 hop): XGMI on ROCm, NVLink on NVIDIA.

    NOTE(review): *world_size* is unused; kept for caller compatibility.
    """
    if _is_hip:
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        # Check each unordered pair exactly once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
        return True
    else:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped."
                        )
                        return False
        return True
def is_weak_contiguous(inp: torch.Tensor):
    """Return True when *inp* is contiguous, or when its elements span the
    whole backing storage from its offset to the end."""
    elem = inp.element_size()
    covers_storage = (
        inp.storage().nbytes() - inp.storage_offset() * elem
        == inp.numel() * elem
    )
    return inp.is_contiguous() or covers_storage
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":

View File

@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from enum import Enum
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
is_full_nvlink,
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
# Probe once at import time whether the quick-allreduce ops are present in
# the installed kernel library; quick_ar gates QuickAllReduce construction.
try:
    ops.qr_max_size()
    quick_ar = True
except Exception:
    # For CPUs and CUDA
    quick_ar = False
def qr_rocm_arch_available():
    """Whether device 0 is a ROCm architecture supported by quick allreduce
    (gfx94x / gfx95x)."""
    if not _is_hip:
        return False
    try:
        device_props = torch.cuda.get_device_properties(0)
        arch_name = getattr(device_props, "gcnArchName", "")
        return "gfx94" in arch_name or "gfx95" in arch_name
    except Exception as e:
        logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
        return False
class QuickReduceRegime(Enum):
    """Quantization regime for quick-allreduce payloads.

    FP means no quantization (full precision); NONE disables quick
    allreduce entirely.
    """

    FP = 0
    INT8 = 1
    INT6 = 2
    INT4 = 3
    NONE = 4


# One mebibyte in bytes, used for size thresholds.
MB = 1024 * 1024
class QuickAllReduce:
    """ROCm 'quick' allreduce: a static-IPC-buffer allreduce that can
    quantize payloads (FP/INT8/INT6/INT4) for extra bandwidth on the
    MI300 (gfx94x/gfx95x) series.

    Designed as a complement to CustomAllreduce; its initialization has
    even stricter requirements, and it stays disabled (``self.disabled``)
    unless every check passes.
    """

    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
    # Minimum input size (bytes) at which each quantization regime is
    # profitable, keyed by (dtype, world_size).
    # The following data is based on kernel tests.
    # In this order [FP, INT8, INT6, INT4].
    _QR_MIN_SIZE = {
        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
    }

    def __init__(
        self, group: ProcessGroup, device: Union[int, str, torch.device]
    ) -> None:
        """
        Custom allreduce provides non-destructive acceleration and is
        available for CUDA and ROCm MI300 series.

        Custom quick allreduce leverages quantization for further
        acceleration on ROCm. It currently supports Q8, Q6, and Q4
        quantization formats and FP(float16, bfloat16).

        Quick allreduce is designed as a complement to custom allreduce.
        Its initialization requires even stricter conditions. Only the
        ROCm MI300 series is supported for quick allreduce at this time.

        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the QuickAllReduce to. If None,
                it will be bound to f"cuda:{local_rank}".

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self.disabled = True
        if not qr_rocm_arch_available():
            logger.debug(
                "Custom quick allreduce is only supported on ROCm MI300 series."
            )
            return

        if not quick_ar:
            # disable because of missing quick reduce library
            # e.g. in a cuda environment
            logger.info(
                "Custom quick allreduce is disabled because "
                "of missing custom quick allreduce library"
            )
            return

        self.group = group
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "Custom quick allreduce should be attached to a non-NCCL group."
        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom quick allreduce for
            # multi-node case.
            logger.warning(
                "Custom quick allreduce is disabled because this "
                "process group spans across nodes."
            )
            return

        self.rank = dist.get_rank(group=self.group)
        self.world_size = dist.get_world_size(group=self.group)
        if self.world_size == 1:
            # No need to initialize QuickReduce for single GPU case.
            return

        if self.world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom quick allreduce is disabled due to an "
                "unsupported world size: %d. Supported world sizes: %s.",
                self.world_size,
                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        self.device = device

        # Map this rank's logical device index to its physical id so the
        # connectivity check below queries the right GPUs.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))
        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(self.world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom quick allreduce is not supported
        # this checks hardware and driver support for NVLink
        if _is_cuda or _is_hip:
            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
        if self.world_size > 2 and not self.fully_connected:
            logger.debug(
                "Custom quick allreduce is disabled because it's not supported "
                "on more than two PCIe-only GPUs. "
            )
            return

        self.init_quick_all_reduce()

    def init_quick_all_reduce(self):
        """Read env configuration, create the native QR context and the
        shared IPC buffer; sets ``self.disabled = False`` on success."""
        # On ROCm, bfloat16 kernels are slower than fp16 due to slower math
        # operations. If the env variable is set (default 1), bf16 inputs
        # are processed with fp16 kernels.
        self.use_fp16_kernels = int(
            os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
        )
        regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
        if regime_str not in QuickReduceRegime.__members__:
            # BUGFIX: previously the message and its details were passed as
            # two positional args to logger.warning, which treats the second
            # as a %-format argument for a message with no placeholder and
            # drops it (with a logging error). Log a single lazy-formatted
            # message instead.
            logger.warning(
                "Custom quick allreduce: invalid quantization level %s. "
                "Supported levels: %s",
                regime_str,
                list(QuickReduceRegime.__members__.keys()),
            )
            return

        if regime_str == "NONE":
            logger.debug(
                "Custom quick allreduce is disabled based "
                "on env variable "
                "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
            )
            return
        self.qr_quant_level = QuickReduceRegime[regime_str]

        # TODO: quick allreduce should not be created when the model dtype
        # is neither bfloat16 nor float16.

        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
        qr_max_size_mb = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
        if "ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB" in os.environ and qr_max_size_mb < 1:
            # BUGFIX: the original guard (`> 0` then `< 1` on an int) was
            # unreachable; warn whenever an explicit setting is below 1 MB.
            logger.info(
                "You should not set a max_size smaller than 1MB, which can "
                "lead to error or degradation to custom allreduce or rccl."
            )
        qr_max_size = qr_max_size_mb * MB if qr_max_size_mb > 0 else 0
        # NOTE(review): when unset/invalid we pass 0 rather than None, even
        # though init_custom_qr's qr_max_size defaults to None ("2GB by
        # default") — confirm the kernel treats 0 the same as None.
        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
        self.create_shared_buffer()
        self.disabled = False

    def create_shared_buffer(self):
        """
        Creates a shared buffer for quickreduce.
        Has to be called after init_custom_qr
        """
        handle = ops.qr_get_handle(self._ptr)
        world_size = dist.get_world_size(group=self.group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=self.group)
        ops.qr_open_handles(self._ptr, handles)

    def should_quick_allreduce(self, inp: torch.Tensor):
        """
        Check if quickreduce is available for this input (enabled, supported
        dtype, 16-byte-aligned size, weakly contiguous, and within the
        benchmarked size window for the active quantization regime).
        """
        if self.disabled:
            return False
        if inp.dtype not in self._SUPPORTED_DTYPES:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom quick allreduce requires input byte size to be
        # multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # When bf16 is cast to fp16 kernels, use the fp16 thresholds.
        dtype = inp.dtype
        if self.use_fp16_kernels:
            dtype = torch.float16
        return (
            inp_size <= self.qr_max_size
            and inp_size
            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
        )

    def quick_all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ):
        """Performs an out-of-place custom quick all reduce."""
        # quick allreduce doesn't require a separate graph mode,
        # as QR uses static IPC buffer.
        if out is None:
            out = torch.empty_like(inp)
        ops.qr_all_reduce(
            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
        )
        return out

    def close(self):
        """Destroy the native QR context exactly once; safe to call repeatedly."""
        if not self.disabled and getattr(self, "_ptr", None):
            if ops is not None:
                ops.qr_destroy(self._ptr)
            self._ptr = 0
            self.disabled = True

    def __del__(self):
        self.close()

View File

@@ -44,6 +44,7 @@ from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
is_cuda_alike,
is_hip,
is_npu,
is_shm_available,
supports_custom_op,
@@ -126,14 +127,18 @@ if supports_custom_op():
fake_impl=inplace_all_reduce_fake,
)
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce(
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(
    tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
    # Fake implementation for custom-op registration (fake_impl): only the
    # output shape/dtype matter, no communication is performed.
    return torch.empty_like(tensor)
direct_register_custom_op(
@@ -264,6 +269,12 @@ class GroupCoordinator:
PyNcclCommunicator,
)
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
qr_rocm_arch_available,
)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
@@ -283,6 +294,7 @@ class GroupCoordinator:
)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
try:
@@ -295,6 +307,18 @@ class GroupCoordinator:
f"Setup Custom allreduce failed with {e}. To silence this "
"warning, specify --disable-custom-all-reduce explicitly."
)
if is_hip():
try:
# Initialize a custom quick all-reduce implementation for AMD
# when rocm >= gfx942. Quick reduce is designed as a
# complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
if qr_rocm_arch_available():
self.qr_comm = QuickAllReduce(
group=self.cpu_group, device=self.device
)
except Exception as e:
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
from sglang.srt.distributed.device_communicators.hpu_communicator import (
HpuCommunicator,
@@ -373,7 +397,8 @@ class GroupCoordinator:
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# We don't need the context of custom quick allreduce because the ipc access
# is already collected in init() and we can capture the quick allreduce directly.
ca_comm = self.ca_comm
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
@@ -388,23 +413,24 @@ class GroupCoordinator:
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# quick allreduce | enabled | enabled |
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note: When custom quick allreduce is enabled, a runtime check
# will be performed. If the tensor size is too small, it will
# automatically fall back to the next available option.
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# Note that the PyMsccl needs to register the tensor in ahead,
# which will introduce large overhead in the eager case,
# therefore it is only supported in the graph case.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
# PyTorch NCCL. We always prioritize using custom all-reduce
# kernel but fall back to PyTorch or pynccl if it is
# disabled or not supported.
# In summary: We select the appropriate allreduce method for
# each mode based on the algorithm order in the table and
# their usage conditions.
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
@@ -464,27 +490,47 @@ class GroupCoordinator:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
outplace_all_reduce_method = None
if (
self.qr_comm is not None
and not self.qr_comm.disabled
and self.qr_comm.should_quick_allreduce(input_)
):
outplace_all_reduce_method = "qr"
elif (
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
) or (
):
outplace_all_reduce_method = "ca"
elif (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
outplace_all_reduce_method = "pymscclpp"
if outplace_all_reduce_method is not None:
return torch.ops.sglang.outplace_all_reduce(
input_, group_name=self.unique_name
input_,
group_name=self.unique_name,
outplace_all_reduce_method=outplace_all_reduce_method,
)
else:
torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
return input_
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
def _all_reduce_out_place(
self, input_: torch.Tensor, outplace_all_reduce_method: str
) -> torch.Tensor:
qr_comm = self.qr_comm
ca_comm = self.ca_comm
pymscclpp_comm = self.pymscclpp_comm
assert ca_comm is not None or pymscclpp_comm is not None
if ca_comm is not None and not ca_comm.disabled:
assert any([qr_comm, ca_comm, pymscclpp_comm])
if outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
elif outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
else:
assert not pymscclpp_comm.disabled