[Feat] Support Torch Symm Mem AllReduce (#10571)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
benchmark/kernels/all_reduce/benchmark_symm_mem.py (new file, +234 lines)
@@ -0,0 +1,234 @@
"""For now, SYMM_MEM is only supported in the TP8 case.

export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345

torchrun --nproc_per_node gpu \
    --nnodes $WORLD_SIZE \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
"""

import os
from contextlib import nullcontext
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_symm_mem_all_reduce,
)

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    dist.all_reduce(torch_input, group=group)
    return torch_input


def symm_mem_allreduce(
    symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
) -> torch.Tensor:
    return symm_mem_comm.all_reduce(symm_mem_input)


def pynccl_allreduce(
    pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    pynccl_comm.all_reduce(pynccl_input)
    return pynccl_input


def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    graph_input = inp_randn.clone()
    with graph_capture() as graph_capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            for _ in range(graph_loop):
                graph_out = func(graph_input)

    graph.replay()
    func_output = graph_out.clone()

    for _ in range(warmup_loop):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for _ in range(test_loop):
        torch.cuda.synchronize()
        dist.barrier()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
    graph.reset()
    return func_output, func_cost_us


def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    eager_input = inp_randn.clone()
    eager_output = func(eager_input)
    func_output = eager_output.clone()

    for _ in range(warmup_loop):
        func(eager_input)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(test_loop):
        func(eager_input)
    end_event.record()
    torch.cuda.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000

    return func_output, func_cost_us


def get_torch_prof_ctx(do_prof: bool):
    ctx = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        )
        if do_prof
        else nullcontext()
    )
    return ctx


def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"


try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, skipping table printing")
    tabulate = None


def print_markdown_table(data):
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return
    headers = data[0].keys()
    header_row = "| " + " | ".join(headers) + " |"
    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
    rows = []
    for item in data:
        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
        rows.append(row)
    markdown_table = "\n".join([header_row, separator] + rows)
    print(markdown_table)


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world, world_size = dist.group.WORLD, dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % 8)
    device = torch.cuda.current_device()
    set_symm_mem_all_reduce(True)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=rank % 8,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)
    group = get_tensor_model_parallel_group().device_group
    cpu_group = get_tensor_model_parallel_group().cpu_group
    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
    symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm
    dist.barrier()
    profile = False
    dtype = torch.bfloat16
    ctx = get_torch_prof_ctx(profile)
    result = []

    with ctx:
        if IS_CI:
            i_range = range(10, 11)
        else:
            i_range = range(10, 20)
        for i in i_range:
            sz = 2**i
            if sz * dtype.itemsize > 2**24:
                break
            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)

            memory = torch.empty_like(inp_randn)
            memory_out = torch.empty_like(memory)
            torch_eager_output, torch_eager_time = _bench_eager_time(
                lambda inp: torch_allreduce(inp, group), inp_randn
            )
            symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
            )
            symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
            )
            # Since the pynccl all-reduce is an in-place op, the returned
            # result is not correct when graph_loop > 1.
            _, pynccl_graph_time = _bench_graph_time(
                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
            )
            torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)
            torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)
            result.append(
                {
                    "msg_size": human_readable_size(inp_randn.nbytes),
                    "torch eager time": torch_eager_time,
                    "symm mem eager time": symm_mem_eager_time,
                    "symm mem graph time": symm_mem_graph_time,
                    "pynccl graph time": pynccl_graph_time,
                }
            )
            if rank == 0:
                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
    if rank == 0:
        print_markdown_table(result)
    if profile:
        prof_dir = "prof/symm_mem"
        os.makedirs(prof_dir, exist_ok=True)
        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
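
To sanity-check the numbers this benchmark prints, it helps to convert a measured latency into ring all-reduce bus bandwidth. A minimal sketch, following the usual NCCL accounting; the helper name and the 2*(n-1)/n convention are mine, not part of this commit:

def bus_bandwidth_gbps(msg_bytes: int, time_us: float, world_size: int) -> float:
    # algbw = bytes reduced per second; busbw applies the ring all-reduce
    # correction factor 2 * (n - 1) / n so results are comparable across
    # world sizes.
    algbw = msg_bytes / (time_us * 1e-6)
    return algbw * 2 * (world_size - 1) / world_size / 1e9

# Example: 1 MiB reduced in 20 us across 8 ranks -> ~91.8 GB/s bus bandwidth.
print(bus_bandwidth_gbps(2**20, 20.0, 8))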
@@ -28,6 +28,8 @@ def launch_server(args):
        cmd += "--disable-custom-all-reduce"
    if args.enable_mscclpp:
        cmd += "--enable-mscclpp"
    if args.enable_torch_symm_mem:
        cmd += "--enable-torch-symm-mem"
    print(cmd)
    os.system(cmd)

@@ -70,6 +72,11 @@ if __name__ == "__main__":
        action="store_true",
        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
    )
    parser.add_argument(
        "--enable-torch-symm-mem",
        action="store_true",
        help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
    )
    args = parser.parse_args()

    launch_server(args)
@@ -0,0 +1,16 @@
MiB = 1024 * 1024

SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    10: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    },
}
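
For orientation, the outer keys of this table are compute-capability majors (9 = Hopper, 10 = Blackwell) and the inner keys are world sizes; SymmMemCommunicator below indexes it the same way. A small illustrative lookup (the helper is hypothetical, not part of this commit):

from typing import Optional

def symm_mem_max_bytes(cc_major: int, world_size: int) -> Optional[int]:
    # Returns the max eligible message size in bytes, or None when the
    # (capability, world size) pair is not covered by the table.
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(cc_major, {}).get(world_size)

assert symm_mem_max_bytes(9, 8) == 64 * MiB  # Hopper, TP8
assert symm_mem_max_bytes(8, 8) is None      # pre-Hopper: unsupported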
python/sglang/srt/distributed/device_communicators/symm_mem.py (new file, +164 lines)
@@ -0,0 +1,164 @@
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
import logging
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from sglang.srt.utils import get_device_capability, is_cuda, is_hip

try:
    import torch.distributed._symmetric_memory as torch_symm_mem

    symm_mem_available = True
except ImportError:
    symm_mem_available = False


logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

symm_mem_is_available = False
if _is_hip:
    symm_mem_is_available = False
if _is_cuda:
    symm_mem_is_available = True


class SymmMemCommunicator:
    """
    Thin wrapper around symmetric-memory collectives.

    This communicator:
      - Validates device capability and world size.
      - Allocates a shared symmetric buffer.
      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled and will
    decline to perform symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> supported world sizes for multimem.
    # If the current (cc_major, world_size) is not listed, we fall back
    # to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).
        """

        self.disabled = True

        if not symm_mem_available:
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = torch.cuda.get_device_capability(device)[0]
        if self.device_capability < 9:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size
        ]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return
        self.disabled = False

    def should_symm_mem_allreduce(self, inp: torch.Tensor) -> bool:
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
          - Communicator must be enabled.
          - dtype must be bfloat16 (matches kernel + buffer dtype).
          - Total byte size must be 4-byte aligned (hardware requirement).
          - Payload must be smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # enforce 4-byte alignment
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform a sum all-reduce via symmetric memory, staged through the
        shared buffer.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16).
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
          - Stages 'inp' into the symmetric buffer.
          - Selects 'multimem' or 'two_shot' kernel based on topology.
          - Writes the result into 'out' and returns it.
        """
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[: inp.numel()].copy_(inp.view(-1))
        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        out.copy_(self.buffer[: inp.numel()].view(out.shape))
        return out
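
A minimal usage sketch of the communicator, assuming torchrun has already spawned one process per GPU and dist.init_process_group("nccl") has run. The GroupCoordinator changes below wire this up automatically (passing its CPU group); here the default world group stands in for it:

import torch
import torch.distributed as dist

rank = dist.get_rank()
comm = SymmMemCommunicator(dist.group.WORLD, device=rank)

x = torch.ones(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
if comm.should_symm_mem_allreduce(x):
    y = comm.all_reduce(x)  # out-of-place: x is staged, result lands in y
else:
    dist.all_reduce(x)  # fall back to NCCL for ineligible shapes/dtypes
    y = x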
@@ -215,12 +215,14 @@ class GroupCoordinator:
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_pymscclpp: bool  # a hint of whether to use PyMsccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
    use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster
    )
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    symm_mem_comm: Optional[Any]  # Symm mem communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
@@ -231,6 +233,7 @@ class GroupCoordinator:
        use_pynccl: bool,
        use_pymscclpp: bool,
        use_custom_allreduce: bool,
        use_torch_symm_mem: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_npu_communicator: bool,
@@ -279,6 +282,7 @@ class GroupCoordinator:
        self.use_pynccl = use_pynccl
        self.use_pymscclpp = use_pymscclpp
        self.use_custom_allreduce = use_custom_allreduce
        self.use_torch_symm_mem = use_torch_symm_mem
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator
        self.use_npu_communicator = use_npu_communicator
@@ -294,6 +298,9 @@ class GroupCoordinator:
        from sglang.srt.distributed.device_communicators.pynccl import (
            PyNcclCommunicator,
        )
        from sglang.srt.distributed.device_communicators.symm_mem import (
            SymmMemCommunicator,
        )

        if is_hip():
            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -342,6 +349,13 @@ class GroupCoordinator:
        except Exception as e:
            logger.warning(f"Failed to initialize QuickAllReduce: {e}")

        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
        if self.use_torch_symm_mem and self.world_size > 1:
            self.symm_mem_comm = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        # Create communicator for other hardware backends
        from sglang.srt.distributed.device_communicators.hpu_communicator import (
            HpuCommunicator,
@@ -446,6 +460,7 @@ class GroupCoordinator:
        # custom allreduce  | enabled  | enabled  |
        # PyNccl            | disabled | enabled  |
        # PyMscclpp         | disabled | enabled  |
        # TorchSymmMem      | disabled | enabled  |
        # torch.distributed | enabled  | disabled |
        #
        # Note: When custom quick allreduce is enabled, a runtime check
@@ -547,7 +562,12 @@ class GroupCoordinator:
            and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
        ):
            outplace_all_reduce_method = "pymscclpp"

        elif (
            self.symm_mem_comm is not None
            and not self.symm_mem_comm.disabled
            and self.symm_mem_comm.should_symm_mem_allreduce(input_)
        ):
            outplace_all_reduce_method = "symm_mem"
        if outplace_all_reduce_method is not None:
            return torch.ops.sglang.outplace_all_reduce(
                input_,
@@ -564,6 +584,7 @@ class GroupCoordinator:
        ca_comm = self.ca_comm
        qr_comm = self.qr_comm
        pymscclpp_comm = self.pymscclpp_comm
        symm_mem_comm = self.symm_mem_comm
        assert any([qr_comm, ca_comm, pymscclpp_comm, symm_mem_comm])
        if outplace_all_reduce_method == "ca":
            assert not ca_comm.disabled
@@ -571,6 +592,9 @@ class GroupCoordinator:
        elif outplace_all_reduce_method == "qr":
            assert not qr_comm.disabled
            out = qr_comm.quick_all_reduce(input_)
        elif outplace_all_reduce_method == "symm_mem":
            assert not symm_mem_comm.disabled
            out = symm_mem_comm.all_reduce(input_)
        else:
            assert not pymscclpp_comm.disabled
            out = pymscclpp_comm.all_reduce(input_)
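
Taken together, the two hunks above form a priority chain over the out-of-place backends. A condensed, hypothetical restatement of that order (the real code also consults per-communicator size thresholds and capture state not shown in these hunks):

def pick_outplace_method(coord, t):
    # Try each communicator in priority order; a candidate is used only if
    # it exists, is enabled, and its eligibility check accepts the tensor.
    for name, comm, eligible in (
        ("qr", coord.qr_comm, lambda c: not c.disabled),
        ("ca", coord.ca_comm, lambda c: not c.disabled),
        ("pymscclpp", coord.pymscclpp_comm,
         lambda c: not c.disabled and c.should_mscclpp_allreduce(t)),
        ("symm_mem", coord.symm_mem_comm,
         lambda c: not c.disabled and c.should_symm_mem_allreduce(t)),
    ):
        if comm is not None and eligible(comm):
            return name
    return None  # caller falls back to PyNccl / torch.distributed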
@@ -1219,6 +1243,7 @@ def init_world_group(
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
    use_mscclpp_allreduce: Optional[bool] = None,
    use_symm_mem_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    if use_mscclpp_allreduce is None:
        use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
    if use_symm_mem_allreduce is None:
        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
@@ -1246,6 +1274,7 @@ def init_model_parallel_group(
        use_pynccl=not _is_npu,
        use_pymscclpp=use_mscclpp_allreduce,
        use_custom_allreduce=use_custom_allreduce,
        use_torch_symm_mem=use_symm_mem_allreduce,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_npu_communicator=True,
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)

_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
_ENABLE_SYMM_MEM_ALL_REDUCE = False


def set_custom_all_reduce(enable: bool):
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
    _ENABLE_MSCCLPP_ALL_REDUCE = enable


def set_symm_mem_all_reduce(enable: bool):
    global _ENABLE_SYMM_MEM_ALL_REDUCE
    _ENABLE_SYMM_MEM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
@@ -263,6 +263,7 @@ def initialize_dp_attention(
        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -646,6 +647,7 @@ class ModelRunner:
        dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if self.device == "cpu":
@@ -382,6 +382,7 @@ class ServerArgs:
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mscclpp: bool = False
    enable_torch_symm_mem: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
@@ -2443,6 +2444,11 @@ class ServerArgs:
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-torch-symm-mem",
            action="store_true",
            help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA devices SM90 and above. SM90 supports world sizes 4, 6, 8; SM100 supports world sizes 6, 8.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",