[Feature] Integrate quick allreduce and select the best allreduce implementation (#6619)
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -114,6 +114,34 @@ else:
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
# ROCM custom quick allreduce
|
||||
|
||||
def init_custom_qr(
|
||||
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
||||
) -> int:
|
||||
return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
|
||||
|
||||
def qr_get_handle(fa: int) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.qr_get_handle(fa)
|
||||
|
||||
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
||||
sgl_kernel.allreduce.qr_open_handles(fa, handles)
|
||||
|
||||
def qr_all_reduce(
|
||||
fa: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
quant_level: int,
|
||||
cast_bf2half: bool,
|
||||
) -> None:
|
||||
sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
|
||||
|
||||
def qr_destroy(fa: int) -> None:
|
||||
sgl_kernel.allreduce.qr_destroy(fa)
|
||||
|
||||
def qr_max_size() -> int:
|
||||
return sgl_kernel.allreduce.qr_max_size()
|
||||
|
||||
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
|
||||
|
||||
@@ -4,18 +4,18 @@ import ctypes
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check,
|
||||
is_full_nvlink,
|
||||
is_weak_contiguous,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
@@ -57,70 +40,6 @@ except Exception:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if _is_hip:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
amdsmi_shut_down()
|
||||
else:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if _is_hip:
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
||||
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
_MAX_CAR_SIZE = 8192 * 1024
|
||||
|
||||
@@ -8,17 +8,44 @@ import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def update_environment_variables(envs: Dict[str, str]):
|
||||
for k, v in envs.items():
|
||||
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if _is_hip:
|
||||
try:
|
||||
amdsmi_init()
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
amdsmi_shut_down()
|
||||
else:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||
if _is_hip:
|
||||
"""
|
||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
|
||||
)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError:
|
||||
logger.exception(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
is_full_nvlink,
|
||||
is_weak_contiguous,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
try:
|
||||
ops.qr_max_size()
|
||||
quick_ar = True
|
||||
except Exception:
|
||||
# For CPUs and CUDA
|
||||
quick_ar = False
|
||||
|
||||
|
||||
def qr_rocm_arch_available():
|
||||
if not _is_hip:
|
||||
return False
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
gcn_arch = getattr(props, "gcnArchName", "")
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
FP = 0
|
||||
INT8 = 1
|
||||
INT6 = 2
|
||||
INT4 = 3
|
||||
NONE = 4
|
||||
|
||||
|
||||
MB = 1024 * 1024
|
||||
|
||||
|
||||
class QuickAllReduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
|
||||
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
# The following data is based on kernel tests.
|
||||
# In this order [FP, INT8, INT6, INT4].
|
||||
_QR_MIN_SIZE = {
|
||||
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
|
||||
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
|
||||
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
|
||||
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
|
||||
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
|
||||
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, group: ProcessGroup, device: Union[int, str, torch.device]
|
||||
) -> None:
|
||||
"""
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
available for CUDA and ROCm MI300 series.
|
||||
Custom quick allreduce leverages quantization for further
|
||||
acceleration on ROCm. It currently supports Q8, Q6, and Q4
|
||||
quantization formats and FP(float16, bfloat16).
|
||||
Quick allreduce is designed as a complement to custom allreduce.
|
||||
Its initialization requires even stricter conditions.
|
||||
Only the ROCm MI300 series is supported for quick allreduce at
|
||||
this time.
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self.disabled = True
|
||||
if not qr_rocm_arch_available():
|
||||
logger.debug(
|
||||
"Custom quick allreduce is only supported on ROCm MI300 series."
|
||||
)
|
||||
return
|
||||
|
||||
if not quick_ar:
|
||||
# disable because of missing quick reduce library
|
||||
# e.g. in a cuda environment
|
||||
logger.info(
|
||||
"Custom quick allreduce is disabled because "
|
||||
"of missing custom quick allreduce library"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "Custom quick allreduce should be attached to a non-NCCL group."
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom quick allreduce for
|
||||
# multi-node case.
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled because this "
|
||||
"process group spans across nodes."
|
||||
)
|
||||
return
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
if world_size == 1:
|
||||
# No need to initialize QuickReduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled due to an "
|
||||
"unsupported world size: %d. Supported world sizes: %s.",
|
||||
world_size,
|
||||
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom quick allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
if _is_cuda or _is_hip:
|
||||
self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
|
||||
if self.world_size > 2 and not self.fully_connected:
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled because it's not supported "
|
||||
"on more than two PCIe-only GPUs. "
|
||||
)
|
||||
return
|
||||
|
||||
self.init_quick_all_reduce()
|
||||
|
||||
def init_quick_all_reduce(self):
|
||||
# On RocM, bfloat16 kernels are slower than fp16
|
||||
# due to slower match operations
|
||||
# If environment variable is set to 1, we convert input to fp16
|
||||
self.use_fp16_kernels = int(
|
||||
os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
|
||||
)
|
||||
regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
|
||||
if regime_str not in QuickReduceRegime.__members__:
|
||||
logger.warning(
|
||||
"Custom quick allreduce:",
|
||||
f"Invalid quantization level: {regime_str}. "
|
||||
"Supported levels: "
|
||||
f"{list(QuickReduceRegime.__members__.keys())}",
|
||||
)
|
||||
return
|
||||
|
||||
if regime_str == "NONE":
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled based "
|
||||
"on env variable "
|
||||
"ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
|
||||
)
|
||||
return
|
||||
self.qr_quant_level = QuickReduceRegime[regime_str]
|
||||
|
||||
# TODO: If the dtype is not bfloat16 or then float16,
|
||||
# quickallreduce should not be created.
|
||||
|
||||
# ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
|
||||
qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
|
||||
if qr_max_size > 0:
|
||||
if qr_max_size < 1:
|
||||
logger.info(
|
||||
"You should not set a max_size smaller than 1MB, which can "
|
||||
"lead to error or degradation to custom allreduce or rccl."
|
||||
)
|
||||
qr_max_size = qr_max_size * MB
|
||||
# If qr_max_size is None, then 2GB is used by default.
|
||||
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
|
||||
self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
|
||||
self.create_shared_buffer()
|
||||
self.disabled = False
|
||||
|
||||
def create_shared_buffer(self):
|
||||
"""
|
||||
Creates a shared buffer for quickreduce.
|
||||
Has to be called after init_custom_qr
|
||||
"""
|
||||
handle = ops.qr_get_handle(self._ptr)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=self.group)
|
||||
ops.qr_open_handles(self._ptr, handles)
|
||||
|
||||
def should_quick_allreduce(self, inp: torch.Tensor):
|
||||
"""
|
||||
Check if quickreduce is available
|
||||
"""
|
||||
if self.disabled:
|
||||
return False
|
||||
if inp.dtype not in self._SUPPORTED_DTYPES:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom quick allreduce requires input byte size to be
|
||||
# multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
dtype = inp.dtype
|
||||
if self.use_fp16_kernels:
|
||||
dtype = torch.float16
|
||||
return (
|
||||
inp_size <= self.qr_max_size
|
||||
and inp_size
|
||||
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
|
||||
)
|
||||
|
||||
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
|
||||
"""Performs an out-of-place custom quick all reduce."""
|
||||
# quick allreduce doesn't require a separate graph mode,
|
||||
# as QR uses static IPC buffer.
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.qr_all_reduce(
|
||||
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
|
||||
)
|
||||
return out
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and getattr(self, "_ptr", None):
|
||||
if ops is not None:
|
||||
ops.qr_destroy(self._ptr)
|
||||
self._ptr = 0
|
||||
self.disabled = True
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda_alike,
|
||||
is_hip,
|
||||
is_npu,
|
||||
is_shm_available,
|
||||
supports_custom_op,
|
||||
@@ -126,14 +127,18 @@ if supports_custom_op():
|
||||
fake_impl=inplace_all_reduce_fake,
|
||||
)
|
||||
|
||||
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
def outplace_all_reduce(
|
||||
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
return group._all_reduce_out_place(tensor)
|
||||
return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
|
||||
|
||||
def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
def outplace_all_reduce_fake(
|
||||
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -264,6 +269,12 @@ class GroupCoordinator:
|
||||
PyNcclCommunicator,
|
||||
)
|
||||
|
||||
if is_hip():
|
||||
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce,
|
||||
qr_rocm_arch_available,
|
||||
)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
@@ -283,6 +294,7 @@ class GroupCoordinator:
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
try:
|
||||
@@ -295,6 +307,18 @@ class GroupCoordinator:
|
||||
f"Setup Custom allreduce failed with {e}. To silence this "
|
||||
"warning, specify --disable-custom-all-reduce explicitly."
|
||||
)
|
||||
if is_hip():
|
||||
try:
|
||||
# Initialize a custom quick all-reduce implementation for AMD
|
||||
# when rocm >= gfx942. Quick reduce is designed as a
|
||||
# complement to custom allreduce.
|
||||
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
|
||||
if qr_rocm_arch_available():
|
||||
self.qr_comm = QuickAllReduce(
|
||||
group=self.cpu_group, device=self.device
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
||||
|
||||
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
||||
HpuCommunicator,
|
||||
@@ -373,7 +397,8 @@ class GroupCoordinator:
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
|
||||
# We don't need the context of custom quick allreduce because the ipc access
|
||||
# is already collected in init() and we can capture the quick allreduce directly.
|
||||
ca_comm = self.ca_comm
|
||||
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
|
||||
|
||||
@@ -388,23 +413,24 @@ class GroupCoordinator:
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
# --------------------------------------------
|
||||
# quick allreduce | enabled | enabled |
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# PyMscclpp | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note: When custom quick allreduce is enabled, a runtime check
|
||||
# will be performed. If the tensor size is too small, it will
|
||||
# automatically fall back to the next available option.
|
||||
# Note that custom allreduce will have a runtime check, if the
|
||||
# tensor size is too large, it will fallback to the next
|
||||
# available option.
|
||||
# Note that the PyMsccl needs to register the tensor in ahead,
|
||||
# which will introduce large overhead in the eager case,
|
||||
# therefore it is only supported in the graph case.
|
||||
# In summary: When using CUDA graph, we use
|
||||
# either custom all-reduce kernel or pynccl. When not using
|
||||
# CUDA graph, we use either custom all-reduce kernel or
|
||||
# PyTorch NCCL. We always prioritize using custom all-reduce
|
||||
# kernel but fall back to PyTorch or pynccl if it is
|
||||
# disabled or not supported.
|
||||
# In summary: We select the appropriate allreduce method for
|
||||
# each mode based on the algorithm order in the table and
|
||||
# their usage conditions.
|
||||
pynccl_comm = self.pynccl_comm
|
||||
maybe_pynccl_context: Any
|
||||
if not pynccl_comm:
|
||||
@@ -464,27 +490,47 @@ class GroupCoordinator:
|
||||
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
||||
return self.npu_communicator.all_reduce(input_)
|
||||
|
||||
outplace_all_reduce_method = None
|
||||
if (
|
||||
self.qr_comm is not None
|
||||
and not self.qr_comm.disabled
|
||||
and self.qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "qr"
|
||||
elif (
|
||||
self.ca_comm is not None
|
||||
and not self.ca_comm.disabled
|
||||
and self.ca_comm.should_custom_ar(input_)
|
||||
) or (
|
||||
):
|
||||
outplace_all_reduce_method = "ca"
|
||||
elif (
|
||||
self.pymscclpp_comm is not None
|
||||
and not self.pymscclpp_comm.disabled
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "pymscclpp"
|
||||
if outplace_all_reduce_method is not None:
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_, group_name=self.unique_name
|
||||
input_,
|
||||
group_name=self.unique_name,
|
||||
outplace_all_reduce_method=outplace_all_reduce_method,
|
||||
)
|
||||
else:
|
||||
torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
|
||||
return input_
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
def _all_reduce_out_place(
|
||||
self, input_: torch.Tensor, outplace_all_reduce_method: str
|
||||
) -> torch.Tensor:
|
||||
qr_comm = self.qr_comm
|
||||
ca_comm = self.ca_comm
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
assert ca_comm is not None or pymscclpp_comm is not None
|
||||
if ca_comm is not None and not ca_comm.disabled:
|
||||
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
||||
if outplace_all_reduce_method == "qr":
|
||||
assert not qr_comm.disabled
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "ca":
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
|
||||
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quick_all_reduce.h"
|
||||
|
||||
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
|
||||
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
|
||||
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
||||
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
|
||||
fptr->init(world_size, rank, qr_max_size);
|
||||
return (quickreduce::fptr_t)fptr;
|
||||
}
|
||||
|
||||
void qr_destroy(quickreduce::fptr_t _fa) {
|
||||
if (_fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
fa->destroy();
|
||||
delete fa;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
hipIpcMemHandle_t handle = fa->get_handle();
|
||||
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
std::vector<hipIpcMemHandle_t> ipc_handles;
|
||||
ipc_handles.reserve(handles.size());
|
||||
for (auto& handle : handles) {
|
||||
// Ensure the tensor is on the same device as the current device.
|
||||
hipIpcMemHandle_t ipc_handle;
|
||||
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
|
||||
ipc_handles.push_back(ipc_handle);
|
||||
}
|
||||
fa->open_ipc_handles(ipc_handles);
|
||||
}
|
||||
|
||||
void qr_all_reduce(
|
||||
quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
|
||||
if (out.scalar_type() == at::ScalarType::Half) {
|
||||
fa->allreduce<half, false>(
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
|
||||
if (cast_bf2half) {
|
||||
fa->allreduce<half, true>(
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
} else {
|
||||
fa->allreduce<quickreduce::nv_bfloat16, false>(
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t qr_max_size() {
|
||||
// The default is 2GB (2,147,483,648 bytes)
|
||||
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
||||
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
@@ -0,0 +1,633 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "quick_all_reduce_base.h"
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
struct CodecBase {
|
||||
const int thread;
|
||||
const int rank;
|
||||
const int group_leader;
|
||||
__quickreduce_device_inline__ CodecBase(int thread, int rank)
|
||||
: thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
|
||||
set_fp16_ovfl(true);
|
||||
}
|
||||
};
|
||||
|
||||
// Default full precision codec.
|
||||
template <typename T, int world_size>
|
||||
struct CodecFP : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each thread processes atoms of f16x8_t (16B).
|
||||
static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t);
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
__quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
__builtin_nontemporal_store(data[i], send_buffer + thread);
|
||||
send_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
|
||||
*recv_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int4 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ4 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1152;
|
||||
static constexpr int kRankTileScaleOffset = 1024;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/8.0h, -1/8.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-8, -8}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
|
||||
|
||||
// {+7, +7}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
|
||||
|
||||
// {+8, +8}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00080008;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q4 into int32_t
|
||||
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q4 into f16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static constexpr uint kMask000F = 0x000F000F;
|
||||
static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q4, kHalf2_1032);
|
||||
} else {
|
||||
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int6 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ6 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1664;
|
||||
static constexpr int kRankTileQ2Offset = 1024;
|
||||
static constexpr int kRankTileScaleOffset = 1536;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/32.0h, -1/32.0h}, fp16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
|
||||
|
||||
// {1e-7, 1e-7}, fp16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-32, -32}, fp16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
|
||||
|
||||
// {+31, +31}, fp16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
|
||||
|
||||
// {+32, +32}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00200020;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q6 into int32_t + int16_t
|
||||
uint32_t q4w;
|
||||
uint16_t q2w = 0;
|
||||
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
|
||||
{
|
||||
int16_t* tw = reinterpret_cast<int16_t*>(&q);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q2w |= (tw[i] >> 4) << (i * 2);
|
||||
}
|
||||
}
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(q4w, q4w_ptr);
|
||||
__builtin_nontemporal_store(q2w, q2w_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
|
||||
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q6 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask000F = 0x000F000F;
|
||||
static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int32_t q4 = q4w & kMask000F;
|
||||
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
|
||||
q4w >>= 4;
|
||||
q2w >>= 4;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056));
|
||||
} else {
|
||||
int32_t int16_2 = q4 | (q2 << 4);
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
// That's pretty much it...
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int8 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ8 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of f16x8_t (16B),
|
||||
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 2176;
|
||||
static constexpr int kRankTileScaleOffset = 2048;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/128.0h, -1/128.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-128, -128}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
|
||||
// {+127, +127}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
|
||||
|
||||
// {+128, +128}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00800080;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q8 into int32x2_t
|
||||
int32x2_t qw;
|
||||
qw[0] = q[0] | (q[1] << 8);
|
||||
qw[1] = q[2] | (q[3] << 8);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q8 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask00FF = 0x00FF00FF;
|
||||
|
||||
// {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1024 = 0x64006400;
|
||||
|
||||
// {-1152.0, -1152.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1152 = 0xE480E480;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q8, kHalf2_1152);
|
||||
} else {
|
||||
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Twoshot All Reduce
|
||||
template <typename T, class Codec, bool cast_bf2half>
|
||||
struct AllReduceTwoshot {
|
||||
static_assert(sizeof(T) == 2);
|
||||
|
||||
static constexpr int kWorldSize = Codec::kWorldSize;
|
||||
|
||||
__device__ static void
|
||||
run(T const* __restrict__ input,
|
||||
T* __restrict__ output,
|
||||
uint32_t const N, // number of elements
|
||||
int const block, // block index
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
|
||||
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
|
||||
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
|
||||
src_offset += kAtomStride * sizeof(int32x4_t);
|
||||
if constexpr (cast_bf2half) {
|
||||
const nv_bfloat162* bf_buf = reinterpret_cast<const nv_bfloat162*>(&tA[i]);
|
||||
half2 half_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __bfloat1622float2(bf_buf[j]);
|
||||
half_buf[j] = __float22half2_rn(f);
|
||||
}
|
||||
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Phase-1A: Write segment data into the communication buffer of the target
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
// --------------------------------------------------------
|
||||
// Phase-1B: Reduce the segment data from the communication buffers.
|
||||
int32x4_t tR[Codec::kRankAtoms] = {};
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// note: we reuse tA as temp buffer here
|
||||
codec.recv(&recv_buffer, tA);
|
||||
|
||||
for (int i = 0; i < Codec::kRankAtoms; i++) {
|
||||
packed_assign_add<T>(&tR[i], &tA[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase-2: Write the reduced segment to every other rank
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, tR);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
|
||||
// Phase-2: Read the gather segments from the rank's communication buffer.
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Gather all reduced and final rank segments into tA.
|
||||
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Write the result to output.
|
||||
BufferResource dst_buffer(output, N * sizeof(T));
|
||||
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
if constexpr (cast_bf2half) {
|
||||
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
|
||||
nv_bfloat162 bf16_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __half22float2(half_buf[j]);
|
||||
bf16_buf[j] = __float22bfloat162_rn(f);
|
||||
}
|
||||
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
} else {
|
||||
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
}
|
||||
dst_offset += kAtomStride * sizeof(int32x4_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
@@ -0,0 +1,233 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "quick_all_reduce.cuh"
|
||||
|
||||
#define HIP_CHECK(err) \
|
||||
do { \
|
||||
hipError_t err_ = (err); \
|
||||
if (err_ != hipSuccess) { \
|
||||
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \
|
||||
throw std::runtime_error("HIP error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
|
||||
T const* A,
|
||||
T* B,
|
||||
uint32_t N,
|
||||
uint32_t num_blocks,
|
||||
int rank,
|
||||
uint8_t** dbuffer_list,
|
||||
uint32_t data_offset,
|
||||
uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
#define TWOSHOT_DISPATCH(__codec) \
|
||||
if (world_size == 2) { \
|
||||
using LineCodec = __codec<T, 2>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 4) { \
|
||||
using LineCodec = __codec<T, 4>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 8) { \
|
||||
using LineCodec = __codec<T, 8>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
}
|
||||
|
||||
enum QuickReduceQuantLevel {
|
||||
F16 = 0,
|
||||
INT8 = 1,
|
||||
INT6 = 2,
|
||||
INT4 = 3,
|
||||
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() {
|
||||
return world_size;
|
||||
}
|
||||
int get_rank() {
|
||||
return rank;
|
||||
}
|
||||
bool status() {
|
||||
return initialized;
|
||||
}
|
||||
hipIpcMemHandle_t const get_handle() {
|
||||
return buffer_ipc_handle;
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(
|
||||
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template <typename T, bool cast_bf2half>
|
||||
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
|
||||
if (world_size != 2 && world_size != 4 && world_size != 8) {
|
||||
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
|
||||
}
|
||||
|
||||
// Configuration.
|
||||
uint32_t msg_size = N * sizeof(T);
|
||||
uint32_t num_blocks = divceil(msg_size, kTileSize);
|
||||
uint32_t grid = min(kMaxNumBlocks, num_blocks);
|
||||
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
|
||||
switch (quant_level_) {
|
||||
case QuickReduceQuantLevel::INT8:
|
||||
TWOSHOT_DISPATCH(CodecQ8)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT6:
|
||||
TWOSHOT_DISPATCH(CodecQ6)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT4:
|
||||
TWOSHOT_DISPATCH(CodecQ4)
|
||||
break;
|
||||
default:
|
||||
TWOSHOT_DISPATCH(CodecFP)
|
||||
break;
|
||||
}
|
||||
HIP_CHECK(cudaGetLastError());
|
||||
// Rotate the flag color.
|
||||
flag_color += divceil(N, grid);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
@@ -0,0 +1,318 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define __quickreduce_device_inline__ __device__ __forceinline__
|
||||
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
|
||||
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
|
||||
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
|
||||
// as per architecture.
|
||||
#if defined(__gfx942__)
|
||||
// CDNA3: Scope bits sc0, sc1
|
||||
#define MUBUF_ACQUIRE 16
|
||||
#define MUBUF_RELEASE 16
|
||||
#elif (defined(__gfx908__) || defined(__gfx90a__))
|
||||
// CDNA1 and CDNA2 - glc bit
|
||||
#define MUBUF_ACQUIRE 1
|
||||
#define MUBUF_RELEASE 0
|
||||
#endif
|
||||
|
||||
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
|
||||
|
||||
// Number of atoms (4xf16x2_t) processed by a single thread
|
||||
static constexpr int kAtoms = 8;
|
||||
|
||||
// We use a workgroup of 256 threads
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kAtomStride = kBlockSize;
|
||||
|
||||
// Size and atom stride of source/destination data that the block will
|
||||
// process.
|
||||
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
|
||||
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
|
||||
|
||||
// Max number of blocks. 304 CUs on MI300
|
||||
static constexpr int kMaxNumBlocks = 304 * 4;
|
||||
|
||||
// Standard CDNA wavefront size.
|
||||
static constexpr int kWavefront = 64;
|
||||
|
||||
// 256 thread, 4 wavefronts.
|
||||
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
|
||||
|
||||
// Number of threads in a group for quantization
|
||||
// It corresponds to 32 F16 elements in quantization block
|
||||
static constexpr int kThreadGroupSize = 8;
|
||||
|
||||
// Methods
|
||||
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
|
||||
return ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
union BufferResource {
|
||||
__quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
|
||||
|
||||
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
|
||||
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
|
||||
|
||||
int32x4_t descriptor;
|
||||
struct {
|
||||
void* address; // 8B, out of which first 48b is address, and 16b is stride
|
||||
// (unused)
|
||||
uint32_t range; // Byte range for the buffer resource
|
||||
uint32_t config; // Constant, DFMT=32b
|
||||
};
|
||||
};
|
||||
|
||||
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
|
||||
int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void
|
||||
buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
|
||||
"llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
|
||||
#if defined(__gfx942__)
|
||||
if (value) {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
|
||||
} else {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
union bf162_int_union {
|
||||
int i;
|
||||
nv_bfloat162 bf2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
|
||||
int32x4_t& tR_fragment = A[0];
|
||||
int32x4_t& tA_fragment = B[0];
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tA[i] = __hadd2(tA[i], tB[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmax2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_min(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmin2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
|
||||
half2 wmaxh2 = __builtin_bit_cast(half2, a);
|
||||
half2 wminh2 = __builtin_bit_cast(half2, b);
|
||||
half2 wblockmaxh2;
|
||||
|
||||
wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
|
||||
wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
|
||||
return __builtin_bit_cast(int, wblockmaxh2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
|
||||
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_add(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hadd2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_sub(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
|
||||
int result;
|
||||
|
||||
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hsub2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_mul(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
|
||||
nv_bfloat162 tR = __hmul2(*tA, *tB);
|
||||
return *(reinterpret_cast<int*>(&tR));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_rcp(int a);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
|
||||
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
|
||||
bf162_int_union A, R;
|
||||
A.i = a;
|
||||
R.bf2 = h2rcp(A.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
// changes dtype
|
||||
__quickreduce_device_inline__ float T2float_cast(half a) {
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
|
||||
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
|
||||
|
||||
int wmax, wmin, wblockmax;
|
||||
int a, b;
|
||||
a = packed_max<T>(atom[0], atom[1]);
|
||||
b = packed_max<T>(atom[2], atom[3]);
|
||||
|
||||
wmax = packed_max<T>(a, b);
|
||||
|
||||
a = packed_min<T>(atom[0], atom[1]);
|
||||
b = packed_min<T>(atom[2], atom[3]);
|
||||
|
||||
wmin = packed_min<T>(a, b);
|
||||
|
||||
// Reduce the max among a group of threads
|
||||
// Note: This is basically 2 blocks of values setup as the
|
||||
// upper/lower halves of the f16x2_t
|
||||
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
|
||||
int x = __shfl_down(wmax, i);
|
||||
wmax = packed_max<T>(wmax, x);
|
||||
|
||||
int y = __shfl_down(wmin, i);
|
||||
wmin = packed_min<T>(wmin, y);
|
||||
}
|
||||
wblockmax = packed_abs_max<T>(wmax, wmin);
|
||||
// Share with the cohort
|
||||
wblockmax = __shfl(wblockmax, group_leader);
|
||||
return wblockmax;
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
|
||||
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
|
||||
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace quickreduce
|
||||
@@ -54,6 +54,25 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
|
||||
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
|
||||
|
||||
// quick allreduce
|
||||
#ifdef USE_ROCM
|
||||
m.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
m.def("init_custom_qr", &init_custom_qr);
|
||||
m.def("qr_destroy", &qr_destroy);
|
||||
|
||||
m.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
m.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
m.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
|
||||
@@ -66,6 +66,13 @@ void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
torch::Tensor allocate_meta_buffer(int64_t size);
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
|
||||
// quick allreduce
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
void qr_destroy(fptr_t _fa);
|
||||
torch::Tensor qr_get_handle(fptr_t _fa);
|
||||
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#else
|
||||
// custom allreduce
|
||||
fptr_t
|
||||
@@ -77,6 +84,8 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
|
||||
// mscclpp
|
||||
torch::Tensor mscclpp_generate_unique_id();
|
||||
fptr_t mscclpp_init_context(
|
||||
const torch::Tensor& unique_id,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -49,6 +49,38 @@ if torch.version.hip is not None:
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
|
||||
|
||||
# ROCM quick allreduce
|
||||
def init_custom_qr(
|
||||
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.init_custom_qr.default(
|
||||
world_size, rank, qr_max_size
|
||||
)
|
||||
|
||||
def qr_get_handle(fa: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.qr_get_handle.default(fa)
|
||||
|
||||
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
||||
torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)
|
||||
|
||||
def qr_all_reduce(
|
||||
fa: int,
|
||||
profile: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
cast_bf162half: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.qr_all_reduce.default(
|
||||
fa, profile, inp, out, cast_bf162half
|
||||
)
|
||||
|
||||
def qr_destroy(fa: int) -> None:
|
||||
torch.ops.sgl_kernel.qr_destroy.default(fa)
|
||||
|
||||
def qr_max_size() -> int:
|
||||
return torch.ops.sgl_kernel.qr_max_size.default()
|
||||
|
||||
# mscclpp
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ include_dirs = [
|
||||
|
||||
sources = [
|
||||
"csrc/allreduce/custom_all_reduce.hip",
|
||||
"csrc/allreduce/quick_all_reduce.cu",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||
"csrc/torch_extension_rocm.cc",
|
||||
|
||||
212
test/srt/test_quick_allreduce.py
Normal file
212
test/srt/test_quick_allreduce.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.distributed import init_distributed_environment
|
||||
from sglang.srt.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
||||
qr_rocm_arch_available,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
graph_capture,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44) # keep the deterministic seed
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
world_size: int, cls: Any, test_target: Any, quant_mode: str
|
||||
) -> None:
|
||||
|
||||
# Using ray helps debugging the error when it failed
|
||||
# as compared to multiprocessing.
|
||||
# NOTE: We need to set working_dir for distributed tests,
|
||||
# otherwise we may get import errors on ray workers
|
||||
|
||||
ray.init(log_to_driver=True)
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
for rank in range(world_size):
|
||||
refs.append(
|
||||
test_target.remote(cls, world_size, rank, distributed_init_port, quant_mode)
|
||||
)
|
||||
ray.get(refs)
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
class TestQuickAllReduce(CustomTestCase):
|
||||
TEST_SIZES = [
|
||||
2 * 1024 * 1024,
|
||||
4 * 1024 * 1024,
|
||||
8 * 1024 * 1024,
|
||||
16 * 1024 * 1024,
|
||||
32 * 1024 * 1024,
|
||||
]
|
||||
TEST_LOOP = 5
|
||||
# Too many configurations can lead to a test grid that is too large
|
||||
# The tp takes too long to boot,let's just choose 4 out of 12 configurations
|
||||
# WORLD_SIZES = [2, 4, 8]
|
||||
# QUANT_MODE = ["FP", "INT8", "INT6", "INT4"]
|
||||
QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]]
|
||||
|
||||
@unittest.skipIf(
|
||||
not qr_rocm_arch_available(),
|
||||
"Only test Quick AllReduce on ROCm architectures >= gfx94*",
|
||||
)
|
||||
def test_graph_allreduce(self):
|
||||
for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART:
|
||||
quant_mode = quant_mode_world_size_part[0]
|
||||
world_size = quant_mode_world_size_part[1]
|
||||
if world_size > torch.cuda.device_count():
|
||||
continue
|
||||
multi_process_parallel(world_size, self, self.graph_allreduce, quant_mode)
|
||||
|
||||
@unittest.skipIf(
|
||||
not qr_rocm_arch_available(),
|
||||
"Only test Quick AllReduce on ROCm architectures >= gfx94*",
|
||||
)
|
||||
def test_eager_allreduce(self):
|
||||
for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART:
|
||||
quant_mode = quant_mode_world_size_part[0]
|
||||
world_size = quant_mode_world_size_part[1]
|
||||
if world_size > torch.cuda.device_count():
|
||||
continue
|
||||
multi_process_parallel(world_size, self, self.eager_allreduce, quant_mode)
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
|
||||
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=rank,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
# (e.g. NCCL). This will ensure that the communicator is initialized
|
||||
# before any communication happens, so that this group can be used for
|
||||
# graph capture immediately.
|
||||
data = torch.zeros(1)
|
||||
data = data.to(device=device)
|
||||
torch.distributed.all_reduce(data, group=group)
|
||||
torch.cuda.synchronize()
|
||||
del data
|
||||
|
||||
for sz in self.TEST_SIZES:
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
for _ in range(self.TEST_LOOP):
|
||||
with graph_capture() as graph_capture_context:
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(
|
||||
1,
|
||||
23,
|
||||
(sz,),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
inp2 = torch.randint(
|
||||
-23,
|
||||
1,
|
||||
(sz,),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(
|
||||
graph, stream=graph_capture_context.stream
|
||||
):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
# the input buffer is immediately modified to test
|
||||
# synchronization
|
||||
dist.all_reduce(inp1, group=group)
|
||||
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||
dist.all_reduce(inp2, group=group)
|
||||
graph.replay()
|
||||
atol = 1.25 * world_size
|
||||
rtol = 0.5 * world_size
|
||||
for inp, out in [[inp1, out1], [inp2, out2]]:
|
||||
torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
|
||||
# try:
|
||||
# torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
|
||||
# except AssertionError as e:
|
||||
# print("Max abs diff:", (out - inp).abs().max())
|
||||
# print("Max rel diff:", ((out - inp).abs() / inp.abs().clamp(min=1e-5)).max())
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
|
||||
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=rank,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
for sz in self.TEST_SIZES:
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
for _ in range(self.TEST_LOOP):
|
||||
inp1 = torch.randint(
|
||||
1,
|
||||
23,
|
||||
(sz,),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
dist.all_reduce(inp1, group=group)
|
||||
atol = 1.25 * world_size
|
||||
rtol = 0.5 * world_size
|
||||
torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
|
||||
# try:
|
||||
# torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
|
||||
# except AssertionError as e:
|
||||
# print("Max abs diff:", (out1 - inp1).abs().max())
|
||||
# print("Max rel diff:", ((out1 - inp1).abs() / inp1.abs().clamp(min=1e-5)).max())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user