[Feat] Support Torch Symm Mem AllReduce (#10571)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
# Number of bytes in one mebibyte.
MiB = 1 << 20

# Upper bound (in bytes) on the payload handled by the symmetric-memory
# all-reduce path. Outer key: CUDA compute-capability major version.
# Inner key: world size of the process group.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    10: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    },
}
|
||||
164
python/sglang/srt/distributed/device_communicators/symm_mem.py
Normal file
164
python/sglang/srt/distributed/device_communicators/symm_mem.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||
)
|
||||
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
|
||||
|
||||
# The symmetric-memory module is a private torch API and may be missing from
# the installed torch build; record whether the import succeeded so the
# communicator can stay disabled instead of failing at import time.
try:
    import torch.distributed._symmetric_memory as torch_symm_mem
except ImportError:
    symm_mem_available = False
else:
    symm_mem_available = True

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Platform flag: True exactly when is_cuda() reports a CUDA platform, False
# otherwise (including HIP-only platforms).
symm_mem_is_available = True if _is_cuda else False
|
||||
|
||||
|
||||
class SymmMemCommunicator:
    """
    Thin wrapper around symmetric-memory collectives.

    This communicator:
      - Validates device capability and world size.
      - Allocates a shared symmetric buffer.
      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled
    (``self.disabled is True``) and will decline to perform
    symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> supported world sizes for multimem.
    # If the current (cc_major, world_size) is not listed, we fall back
    # to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Set up the communicator, leaving it disabled on any unmet prerequisite.

        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).
        """
        # Stay disabled until every check below has passed.
        self.disabled = True

        if not symm_mem_available:
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)

        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        # Only the major compute-capability number is used for table lookups.
        self.device_capability = torch.cuda.get_device_capability(device)[0]

        # Look the capability up instead of only rejecting "< 9": a capability
        # major that is absent from the size table (older *or* newer than the
        # listed ones) must disable the communicator rather than raise KeyError.
        max_sizes = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(self.device_capability)
        if max_sizes is None:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return

        if self.world_size not in max_sizes:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return

        # max_size is a byte count; the buffer is sized in elements of dtype.
        self.max_size = max_sizes[self.world_size]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return

        self.disabled = False

    def should_symm_mem_allreduce(self, inp: torch.Tensor):
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
          - Communicator must be enabled.
          - dtype must be bfloat16 (matches kernel + buffer dtype).
          - Total byte size must be 4-byte aligned (hardware requirement).
          - Payload must be strictly smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # enforce 4-byte alignment
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform an out-of-place sum all-reduce via symmetric memory.

        Callers are expected to check should_symm_mem_allreduce(inp) first;
        this method itself only guards against a disabled communicator.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16). It is not
                modified.
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
          - Stages 'inp' into the symmetric buffer.
          - Selects 'multimem' or 'two_shot' kernel based on topology.
          - Writes the result into 'out' and returns it.
        """
        # A disabled instance may have no buffer at all; honour the documented
        # "None if disabled" contract instead of failing on the buffer access.
        if self.disabled:
            return None

        if out is None:
            out = torch.empty_like(inp)

        numel = inp.numel()
        staged = self.buffer[:numel]
        # reshape(-1) behaves like view(-1) for contiguous inputs and also
        # accepts non-contiguous ones.
        staged.copy_(inp.reshape(-1))

        # Capabilities without a multimem entry fall back to two-shot.
        multimem_world_sizes = self._WORLD_SIZES_MULTIMEM.get(
            self.device_capability, ()
        )
        if self.world_size in multimem_world_sizes:
            torch.ops.symm_mem.multimem_all_reduce_(
                staged, "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                staged, "sum", self.group.group_name
            )

        out.copy_(staged.view(out.shape))
        return out
|
||||
@@ -215,12 +215,14 @@ class GroupCoordinator:
|
||||
use_pynccl: bool # a hint of whether to use PyNccl
|
||||
use_pymscclpp: bool # a hint of whether to use PyMsccl
|
||||
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
||||
use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
|
||||
use_message_queue_broadcaster: (
|
||||
bool # a hint of whether to use message queue broadcaster
|
||||
)
|
||||
# communicators are only created for world size > 1
|
||||
pynccl_comm: Optional[Any] # PyNccl communicator
|
||||
ca_comm: Optional[Any] # Custom allreduce communicator
|
||||
symm_mem_comm: Optional[Any] # Symm mem communicator
|
||||
mq_broadcaster: Optional[Any] # shared memory broadcaster
|
||||
|
||||
def __init__(
|
||||
@@ -231,6 +233,7 @@ class GroupCoordinator:
|
||||
use_pynccl: bool,
|
||||
use_pymscclpp: bool,
|
||||
use_custom_allreduce: bool,
|
||||
use_torch_symm_mem: bool,
|
||||
use_hpu_communicator: bool,
|
||||
use_xpu_communicator: bool,
|
||||
use_npu_communicator: bool,
|
||||
@@ -279,6 +282,7 @@ class GroupCoordinator:
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
self.use_hpu_communicator = use_hpu_communicator
|
||||
self.use_xpu_communicator = use_xpu_communicator
|
||||
self.use_npu_communicator = use_npu_communicator
|
||||
@@ -294,6 +298,9 @@ class GroupCoordinator:
|
||||
from sglang.srt.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.symm_mem import (
|
||||
SymmMemCommunicator,
|
||||
)
|
||||
|
||||
if is_hip():
|
||||
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
||||
@@ -342,6 +349,13 @@ class GroupCoordinator:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
||||
|
||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||
if self.use_torch_symm_mem and self.world_size > 1:
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Create communicator for other hardware backends
|
||||
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
||||
HpuCommunicator,
|
||||
@@ -446,6 +460,7 @@ class GroupCoordinator:
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# PyMscclpp | disabled| enabled |
|
||||
# TorchSymmMem | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note: When custom quick allreduce is enabled, a runtime check
|
||||
@@ -547,7 +562,12 @@ class GroupCoordinator:
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "pymscclpp"
|
||||
|
||||
elif (
|
||||
self.symm_mem_comm is not None
|
||||
and not self.symm_mem_comm.disabled
|
||||
and self.symm_mem_comm.should_symm_mem_allreduce(input_)
|
||||
):
|
||||
outplace_all_reduce_method = "symm_mem"
|
||||
if outplace_all_reduce_method is not None:
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_,
|
||||
@@ -564,6 +584,7 @@ class GroupCoordinator:
|
||||
ca_comm = self.ca_comm
|
||||
qr_comm = self.qr_comm
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
symm_mem_comm = self.symm_mem_comm
|
||||
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
||||
if outplace_all_reduce_method == "ca":
|
||||
assert not ca_comm.disabled
|
||||
@@ -571,6 +592,9 @@ class GroupCoordinator:
|
||||
elif outplace_all_reduce_method == "qr":
|
||||
assert not qr_comm.disabled
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "symm_mem":
|
||||
assert not symm_mem_comm.disabled
|
||||
out = symm_mem_comm.all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
out = pymscclpp_comm.all_reduce(input_)
|
||||
@@ -1219,6 +1243,7 @@ def init_world_group(
|
||||
use_pynccl=False,
|
||||
use_pymscclpp=False,
|
||||
use_custom_allreduce=False,
|
||||
use_torch_symm_mem=False,
|
||||
use_hpu_communicator=False,
|
||||
use_xpu_communicator=False,
|
||||
use_npu_communicator=False,
|
||||
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
use_symm_mem_allreduce: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
if use_custom_allreduce is None:
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
if use_mscclpp_allreduce is None:
|
||||
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
|
||||
if use_symm_mem_allreduce is None:
|
||||
use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
|
||||
return GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
@@ -1246,6 +1274,7 @@ def init_model_parallel_group(
|
||||
use_pynccl=not _is_npu,
|
||||
use_pymscclpp=use_mscclpp_allreduce,
|
||||
use_custom_allreduce=use_custom_allreduce,
|
||||
use_torch_symm_mem=use_symm_mem_allreduce,
|
||||
use_hpu_communicator=True,
|
||||
use_xpu_communicator=True,
|
||||
use_npu_communicator=True,
|
||||
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = True
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = False
|
||||
_ENABLE_SYMM_MEM_ALL_REDUCE = False
|
||||
|
||||
|
||||
def set_custom_all_reduce(enable: bool):
|
||||
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def set_symm_mem_all_reduce(enable: bool):
|
||||
global _ENABLE_SYMM_MEM_ALL_REDUCE
|
||||
_ENABLE_SYMM_MEM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
|
||||
@@ -263,6 +263,7 @@ def initialize_dp_attention(
|
||||
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
||||
use_pymscclpp=False,
|
||||
use_custom_allreduce=False,
|
||||
use_torch_symm_mem=False,
|
||||
use_hpu_communicator=False,
|
||||
use_xpu_communicator=False,
|
||||
use_npu_communicator=False,
|
||||
|
||||
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
|
||||
initialize_model_parallel,
|
||||
set_custom_all_reduce,
|
||||
set_mscclpp_all_reduce,
|
||||
set_symm_mem_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.eplb.eplb_manager import EPLBManager
|
||||
@@ -646,6 +647,7 @@ class ModelRunner:
|
||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
|
||||
set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
|
||||
|
||||
if not self.is_draft_worker:
|
||||
if self.device == "cpu":
|
||||
|
||||
@@ -382,6 +382,7 @@ class ServerArgs:
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
enable_mscclpp: bool = False
|
||||
enable_torch_symm_mem: bool = False
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
@@ -2443,6 +2444,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-torch-symm-mem",
|
||||
action="store_true",
|
||||
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-overlap-schedule",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user