[Feature] Integrate quick allreduce and select the best allreduce implementation (#6619)
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -114,6 +114,34 @@ else:
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
# ROCM custom quick allreduce
|
||||
|
||||
def init_custom_qr(
|
||||
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
||||
) -> int:
|
||||
return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
|
||||
|
||||
def qr_get_handle(fa: int) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.qr_get_handle(fa)
|
||||
|
||||
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
||||
sgl_kernel.allreduce.qr_open_handles(fa, handles)
|
||||
|
||||
def qr_all_reduce(
|
||||
fa: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
quant_level: int,
|
||||
cast_bf2half: bool,
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
|
||||
|
||||
def qr_destroy(fa: int) -> None:
|
||||
sgl_kernel.allreduce.qr_destroy(fa)
|
||||
|
||||
def qr_max_size() -> int:
|
||||
return sgl_kernel.allreduce.qr_max_size()
|
||||
|
||||
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
|
||||
|
||||
@@ -4,18 +4,18 @@ import ctypes
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check,
|
||||
is_full_nvlink,
|
||||
is_weak_contiguous,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
@@ -57,70 +40,6 @@ except Exception:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if _is_hip:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
amdsmi_shut_down()
|
||||
else:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if _is_hip:
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
||||
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
_MAX_CAR_SIZE = 8192 * 1024
|
||||
|
||||
@@ -8,17 +8,44 @@ import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def update_environment_variables(envs: Dict[str, str]):
|
||||
for k, v in envs.items():
|
||||
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if _is_hip:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
amdsmi_shut_down()
|
||||
else:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if _is_hip:
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
is_full_nvlink,
|
||||
is_weak_contiguous,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
try:
|
||||
ops.qr_max_size()
|
||||
quick_ar = True
|
||||
except Exception:
|
||||
# For CPUs and CUDA
|
||||
quick_ar = False
|
||||
|
||||
|
||||
def qr_rocm_arch_available():
|
||||
if not _is_hip:
|
||||
return False
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
gcn_arch = getattr(props, "gcnArchName", "")
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
FP = 0
|
||||
INT8 = 1
|
||||
INT6 = 2
|
||||
INT4 = 3
|
||||
NONE = 4
|
||||
|
||||
|
||||
MB = 1024 * 1024
|
||||
|
||||
|
||||
class QuickAllReduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
|
||||
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
# The following data is based on kernel tests.
|
||||
# In this order [FP, INT8, INT6, INT4].
|
||||
_QR_MIN_SIZE = {
|
||||
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
|
||||
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
|
||||
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
|
||||
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
|
||||
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
|
||||
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, group: ProcessGroup, device: Union[int, str, torch.device]
|
||||
) -> None:
|
||||
"""
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
available for CUDA and ROCm MI300 series.
|
||||
Custom quick allreduce leverages quantization for further
|
||||
acceleration on ROCm. It currently supports Q8, Q6, and Q4
|
||||
quantization formats and FP(float16, bfloat16).
|
||||
Quick allreduce is designed as a complement to custom allreduce.
|
||||
Its initialization requires even stricter conditions.
|
||||
Only the ROCm MI300 series is supported for quick allreduce at
|
||||
this time.
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self.disabled = True
|
||||
if not qr_rocm_arch_available():
|
||||
logger.debug(
|
||||
"Custom quick allreduce is only supported on ROCm MI300 series."
|
||||
)
|
||||
return
|
||||
|
||||
if not quick_ar:
|
||||
# disable because of missing quick reduce library
|
||||
# e.g. in a cuda environment
|
||||
logger.info(
|
||||
"Custom quick allreduce is disabled because "
|
||||
"of missing custom quick allreduce library"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "Custom quick allreduce should be attached to a non-NCCL group."
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom quick allreduce for
|
||||
# multi-node case.
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled because this "
|
||||
"process group spans across nodes."
|
||||
)
|
||||
return
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
if world_size == 1:
|
||||
# No need to initialize QuickReduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled due to an "
|
||||
"unsupported world size: %d. Supported world sizes: %s.",
|
||||
world_size,
|
||||
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom quick allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
if _is_cuda or _is_hip:
|
||||
self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
|
||||
if self.world_size > 2 and not self.fully_connected:
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled because it's not supported "
|
||||
"on more than two PCIe-only GPUs. "
|
||||
)
|
||||
return
|
||||
|
||||
self.init_quick_all_reduce()
|
||||
|
||||
def init_quick_all_reduce(self):
|
||||
# On RocM, bfloat16 kernels are slower than fp16
|
||||
# due to slower match operations
|
||||
# If environment variable is set to 1, we convert input to fp16
|
||||
self.use_fp16_kernels = int(
|
||||
os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
|
||||
)
|
||||
regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
|
||||
if regime_str not in QuickReduceRegime.__members__:
|
||||
logger.warning(
|
||||
"Custom quick allreduce:",
|
||||
f"Invalid quantization level: {regime_str}. "
|
||||
"Supported levels: "
|
||||
f"{list(QuickReduceRegime.__members__.keys())}",
|
||||
)
|
||||
return
|
||||
|
||||
if regime_str == "NONE":
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled based "
|
||||
"on env variable "
|
||||
"ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
|
||||
)
|
||||
return
|
||||
self.qr_quant_level = QuickReduceRegime[regime_str]
|
||||
|
||||
# TODO: If the dtype is not bfloat16 or then float16,
|
||||
# quickallreduce should not be created.
|
||||
|
||||
# ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
|
||||
qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
|
||||
if qr_max_size > 0:
|
||||
if qr_max_size < 1:
|
||||
logger.info(
|
||||
"You should not set a max_size smaller than 1MB, which can "
|
||||
"lead to error or degradation to custom allreduce or rccl."
|
||||
)
|
||||
qr_max_size = qr_max_size * MB
|
||||
# If qr_max_size is None, then 2GB is used by default.
|
||||
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
|
||||
self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
|
||||
self.create_shared_buffer()
|
||||
self.disabled = False
|
||||
|
||||
def create_shared_buffer(self):
|
||||
"""
|
||||
Creates a shared buffer for quickreduce.
|
||||
Has to be called after init_custom_qr
|
||||
"""
|
||||
handle = ops.qr_get_handle(self._ptr)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=self.group)
|
||||
ops.qr_open_handles(self._ptr, handles)
|
||||
|
||||
def should_quick_allreduce(self, inp: torch.Tensor):
|
||||
"""
|
||||
Check if quickreduce is available
|
||||
"""
|
||||
if self.disabled:
|
||||
return False
|
||||
if inp.dtype not in self._SUPPORTED_DTYPES:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom quick allreduce requires input byte size to be
|
||||
# multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
dtype = inp.dtype
|
||||
if self.use_fp16_kernels:
|
||||
dtype = torch.float16
|
||||
return (
|
||||
inp_size <= self.qr_max_size
|
||||
and inp_size
|
||||
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
|
||||
)
|
||||
|
||||
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
|
||||
"""Performs an out-of-place custom quick all reduce."""
|
||||
# quick allreduce doesn't require a separate graph mode,
|
||||
# as QR uses static IPC buffer.
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.qr_all_reduce(
|
||||
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
|
||||
)
|
||||
return out
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and getattr(self, "_ptr", None):
|
||||
if ops is not None:
|
||||
ops.qr_destroy(self._ptr)
|
||||
self._ptr = 0
|
||||
self.disabled = True
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda_alike,
|
||||
is_hip,
|
||||
is_npu,
|
||||
is_shm_available,
|
||||
supports_custom_op,
|
||||
@@ -126,14 +127,18 @@ if supports_custom_op():
|
||||
fake_impl=inplace_all_reduce_fake,
|
||||
)
|
||||
|
||||
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
def outplace_all_reduce(
|
||||
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
return group._all_reduce_out_place(tensor)
|
||||
return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
|
||||
|
||||
def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
def outplace_all_reduce_fake(
|
||||
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -264,6 +269,12 @@ class GroupCoordinator:
|
||||
PyNcclCommunicator,
|
||||
)
|
||||
|
||||
if is_hip():
|
||||
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce,
|
||||
qr_rocm_arch_available,
|
||||
)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
@@ -283,6 +294,7 @@ class GroupCoordinator:
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
try:
|
||||
@@ -295,6 +307,18 @@ class GroupCoordinator:
|
||||
f"Setup Custom allreduce failed with {e}. To silence this "
|
||||
"warning, specify --disable-custom-all-reduce explicitly."
|
||||
)
|
||||
if is_hip():
|
||||
try:
|
||||
# Initialize a custom quick all-reduce implementation for AMD
|
||||
# when rocm >= gfx942. Quick reduce is designed as a
|
||||
# complement to custom allreduce.
|
||||
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
|
||||
if qr_rocm_arch_available():
|
||||
self.qr_comm = QuickAllReduce(
|
||||
group=self.cpu_group, device=self.device
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
||||
|
||||
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
||||
HpuCommunicator,
|
||||
@@ -373,7 +397,8 @@ class GroupCoordinator:
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
|
||||
# We don't need the context of custom quick allreduce because the ipc access
|
||||
# is already collected in init() and we can capture the quick allreduce directly.
|
||||
ca_comm = self.ca_comm
|
||||
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
|
||||
|
||||
@@ -388,23 +413,24 @@ class GroupCoordinator:
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
# --------------------------------------------
|
||||
# quick allreduce | enabled | enabled |
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# PyMscclpp | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note: When custom quick allreduce is enabled, a runtime check
|
||||
# will be performed. If the tensor size is too small, it will
|
||||
# automatically fall back to the next available option.
|
||||
# Note that custom allreduce will have a runtime check, if the
|
||||
# tensor size is too large, it will fallback to the next
|
||||
# available option.
|
||||
# Note that the PyMsccl needs to register the tensor in ahead,
|
||||
# which will introduce large overhead in the eager case,
|
||||
# therefore it is only supported in the graph case.
|
||||
# In summary: When using CUDA graph, we use
|
||||
# either custom all-reduce kernel or pynccl. When not using
|
||||
# CUDA graph, we use either custom all-reduce kernel or
|
||||
# PyTorch NCCL. We always prioritize using custom all-reduce
|
||||
# kernel but fall back to PyTorch or pynccl if it is
|
||||
# disabled or not supported.
|
||||
# In summary: We select the appropriate allreduce method for
|
||||
# each mode based on the algorithm order in the table and
|
||||
# their usage conditions.
|
||||
pynccl_comm = self.pynccl_comm
|
||||
maybe_pynccl_context: Any
|
||||
if not pynccl_comm:
|
||||
@@ -464,27 +490,47 @@ class GroupCoordinator:
|
||||
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
||||
return self.npu_communicator.all_reduce(input_)
|
||||
|
||||
outplace_all_reduce_method = None
|
||||
if (
|
||||
self.qr_comm is not None
|
||||
and not self.qr_comm.disabled
|
||||
and self.qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "qr"
|
||||
elif (
|
||||
self.ca_comm is not None
|
||||
and not self.ca_comm.disabled
|
||||
and self.ca_comm.should_custom_ar(input_)
|
||||
) or (
|
||||
):
|
||||
outplace_all_reduce_method = "ca"
|
||||
elif (
|
||||
self.pymscclpp_comm is not None
|
||||
and not self.pymscclpp_comm.disabled
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "pymscclpp"
|
||||
if outplace_all_reduce_method is not None:
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_, group_name=self.unique_name
|
||||
input_,
|
||||
group_name=self.unique_name,
|
||||
outplace_all_reduce_method=outplace_all_reduce_method,
|
||||
)
|
||||
else:
|
||||
torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
|
||||
return input_
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
def _all_reduce_out_place(
|
||||
self, input_: torch.Tensor, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
qr_comm = self.qr_comm
|
||||
ca_comm = self.ca_comm
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
assert ca_comm is not None or pymscclpp_comm is not None
|
||||
if ca_comm is not None and not ca_comm.disabled:
|
||||
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
||||
if outplace_all_reduce_method == "qr":
|
||||
assert not qr_comm.disabled
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "ca":
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
|
||||
Reference in New Issue
Block a user