[Feat] Support Torch Symm Mem AllReduce (#10571)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-10-06 04:55:19 +08:00
committed by GitHub
parent 148d8d485d
commit 590f2da052
8 changed files with 466 additions and 1 deletions

View File

@@ -0,0 +1,16 @@
# Number of bytes in one mebibyte.
MiB = 1024 * 1024

# Maximum payload size (in bytes) eligible for the symmetric-memory all-reduce
# fast path, keyed by compute-capability major version, then by world size.
# Combinations absent from this table are not supported and the communicator
# stays disabled for them.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    10: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    },
}

View File

@@ -0,0 +1,164 @@
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
import logging
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Platform gate: symmetric-memory all-reduce is CUDA-only; explicitly disabled
# on HIP/ROCm builds.
# NOTE(review): this flag is not referenced anywhere in the visible code of
# this file (the class checks `symm_mem_available` from the import guard
# instead) — confirm whether it is consumed by another module or is dead.
symm_mem_is_available = False
if _is_hip:
    symm_mem_is_available = False
if _is_cuda:
    symm_mem_is_available = True
class SymmMemCommunicator:
    """
    Thin wrapper around torch symmetric-memory collectives.

    This communicator:
        - Validates device capability and world size.
        - Allocates a shared symmetric buffer.
        - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
        - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled and will
    decline to perform symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> world sizes that should use the
    # 'multimem' kernel. Any (cc_major, world_size) combination not listed
    # here falls back to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).
        """
        self.disabled = True
        if not symm_mem_available:
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = torch.cuda.get_device_capability(device)[0]
        if self.device_capability < 9:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        # Use .get() so an unlisted capability major (e.g. a future arch > 10)
        # disables the communicator instead of raising KeyError.
        max_sizes = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(self.device_capability)
        if max_sizes is None:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in max_sizes:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = max_sizes[self.world_size]
        # Symmetric buffer shared across ranks; sized in elements of `dtype`.
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        # A zero multicast pointer means the hardware/driver cannot do
        # multicast operations on this buffer, so the fast path is unusable.
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return
        self.disabled = False

    def should_symm_mem_allreduce(self, inp: torch.Tensor) -> bool:
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
            - Communicator must be enabled.
            - dtype must be bfloat16 (matches kernel + buffer dtype).
            - Total byte size must be 4-byte aligned (hardware requirement).
            - Payload must be smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # enforce 4-byte alignment
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform a sum all-reduce via symmetric memory.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16).
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
            - Stages 'inp' into the symmetric buffer.
            - Selects 'multimem' or 'two_shot' kernel based on topology.
            - Writes the result into 'out' and returns it.
        """
        # Honor the documented contract: a disabled communicator declines the
        # operation instead of dereferencing a missing/None buffer.
        if self.disabled:
            return None
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[: inp.numel()].copy_(inp.view(-1))
        # Unlisted capabilities fall back to the two-shot kernel rather than
        # raising KeyError.
        if self.world_size in self._WORLD_SIZES_MULTIMEM.get(
            self.device_capability, []
        ):
            torch.ops.symm_mem.multimem_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        out.copy_(self.buffer[: inp.numel()].view(out.shape))
        return out

View File

@@ -215,12 +215,14 @@ class GroupCoordinator:
use_pynccl: bool # a hint of whether to use PyNccl
use_pymscclpp: bool # a hint of whether to use PyMsccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
)
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
symm_mem_comm: Optional[Any] # Symm mem communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
@@ -231,6 +233,7 @@ class GroupCoordinator:
use_pynccl: bool,
use_pymscclpp: bool,
use_custom_allreduce: bool,
use_torch_symm_mem: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_npu_communicator: bool,
@@ -279,6 +282,7 @@ class GroupCoordinator:
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_npu_communicator = use_npu_communicator
@@ -294,6 +298,9 @@ class GroupCoordinator:
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.symm_mem import (
SymmMemCommunicator,
)
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -342,6 +349,13 @@ class GroupCoordinator:
except Exception as e:
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if self.use_torch_symm_mem and self.world_size > 1:
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
# Create communicator for other hardware backends
from sglang.srt.distributed.device_communicators.hpu_communicator import (
HpuCommunicator,
@@ -446,6 +460,7 @@ class GroupCoordinator:
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# TorchSymmMem | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note: When custom quick allreduce is enabled, a runtime check
@@ -547,7 +562,12 @@ class GroupCoordinator:
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
outplace_all_reduce_method = "pymscclpp"
elif (
self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled
and self.symm_mem_comm.should_symm_mem_allreduce(input_)
):
outplace_all_reduce_method = "symm_mem"
if outplace_all_reduce_method is not None:
return torch.ops.sglang.outplace_all_reduce(
input_,
@@ -564,6 +584,7 @@ class GroupCoordinator:
ca_comm = self.ca_comm
qr_comm = self.qr_comm
pymscclpp_comm = self.pymscclpp_comm
symm_mem_comm = self.symm_mem_comm
assert any([qr_comm, ca_comm, pymscclpp_comm])
if outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
@@ -571,6 +592,9 @@ class GroupCoordinator:
elif outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
elif outplace_all_reduce_method == "symm_mem":
assert not symm_mem_comm.disabled
out = symm_mem_comm.all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
@@ -1219,6 +1243,7 @@ def init_world_group(
use_pynccl=False,
use_pymscclpp=False,
use_custom_allreduce=False,
use_torch_symm_mem=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
use_symm_mem_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
if use_mscclpp_allreduce is None:
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
if use_symm_mem_allreduce is None:
use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
@@ -1246,6 +1274,7 @@ def init_model_parallel_group(
use_pynccl=not _is_npu,
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem=use_symm_mem_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_npu_communicator=True,
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
_ENABLE_SYMM_MEM_ALL_REDUCE = False
def set_custom_all_reduce(enable: bool):
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
_ENABLE_MSCCLPP_ALL_REDUCE = enable
def set_symm_mem_all_reduce(enable: bool):
global _ENABLE_SYMM_MEM_ALL_REDUCE
_ENABLE_SYMM_MEM_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,

View File

@@ -263,6 +263,7 @@ def initialize_dp_attention(
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
use_pymscclpp=False,
use_custom_allreduce=False,
use_torch_symm_mem=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,

View File

@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
initialize_model_parallel,
set_custom_all_reduce,
set_mscclpp_all_reduce,
set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -646,6 +647,7 @@ class ModelRunner:
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
if not self.is_draft_worker:
if self.device == "cpu":

View File

@@ -382,6 +382,7 @@ class ServerArgs:
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
enable_torch_symm_mem: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
@@ -2443,6 +2444,11 @@ class ServerArgs:
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-torch-symm-mem",
action="store_true",
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",