[Feat] Support Torch Symm Mem AllReduce (#10571)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
234
benchmark/kernels/all_reduce/benchmark_symm_mem.py
Normal file
234
benchmark/kernels/all_reduce/benchmark_symm_mem.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""For Now, SYMM_MEM is only supported on TP8 case
|
||||||
|
|
||||||
|
export WORLD_SIZE=1
|
||||||
|
export RANK=0
|
||||||
|
export MASTER_ADDR=127.0.0.1
|
||||||
|
export MASTER_PORT=12345
|
||||||
|
|
||||||
|
torchrun --nproc_per_node gpu \
|
||||||
|
--nnodes $WORLD_SIZE \
|
||||||
|
--node_rank $RANK \
|
||||||
|
--master_addr $MASTER_ADDR \
|
||||||
|
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.distributed import init_distributed_environment
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
|
from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||||
|
from sglang.srt.distributed.parallel_state import (
|
||||||
|
get_tensor_model_parallel_group,
|
||||||
|
graph_capture,
|
||||||
|
initialize_model_parallel,
|
||||||
|
set_symm_mem_all_reduce,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CI environment detection: treat the run as CI when either the generic
# ``CI`` variable or the ``GITHUB_ACTIONS`` variable is set to "true"
# (case-insensitive). Unset variables count as "false".
IS_CI = any(
    os.getenv(env_name, "false").lower() == "true"
    for env_name in ("CI", "GITHUB_ACTIONS")
)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    """All-reduce ``torch_input`` through ``torch.distributed`` and return it.

    ``dist.all_reduce`` writes the reduced values back into the tensor it is
    given, so the returned object is the same tensor that was passed in.

    Args:
        torch_input: Tensor to reduce across ``group``.
        group: Process group over which the reduction runs.

    Returns:
        ``torch_input`` after the collective has been issued.
    """
    dist.all_reduce(tensor=torch_input, group=group)
    return torch_input
|
||||||
|
|
||||||
|
|
||||||
|
def symm_mem_allreduce(
    symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
) -> torch.Tensor:
    """Run an all-reduce on ``symm_mem_input`` via the symmetric-memory communicator.

    Args:
        symm_mem_input: Tensor to reduce.
        symm_mem_comm: Communicator whose ``all_reduce`` performs the reduction.

    Returns:
        Whatever ``symm_mem_comm.all_reduce`` returns for this input.
    """
    reduced = symm_mem_comm.all_reduce(symm_mem_input)
    return reduced
|
||||||
|
|
||||||
|
|
||||||
|
def pynccl_allreduce(
    pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    """Issue an all-reduce on ``pynccl_input`` through the PyNccl communicator.

    The communicator's return value is ignored; the input tensor itself is
    handed back to the caller (the benchmark treats this as an in-place op).

    Args:
        pynccl_input: Tensor to reduce.
        pynccl_comm: Communicator whose ``all_reduce`` performs the reduction.

    Returns:
        ``pynccl_input``.
    """
    pynccl_comm.all_reduce(pynccl_input)
    return pynccl_input
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    """Time ``func`` when replayed from a CUDA graph.

    ``func`` is captured ``graph_loop`` times back to back into a single CUDA
    graph. The graph is replayed once to obtain an output snapshot, replayed
    ``warmup_loop`` more times, and then timed over ``test_loop`` replays with
    CUDA events.

    Args:
        func: Callable taking one tensor and returning a tensor.
        inp_randn: Input tensor; a clone is used so the caller's data is not
            touched.
        warmup_loop: Number of untimed graph replays before measuring.
        graph_loop: Number of ``func`` calls captured inside the graph.
        test_loop: Number of timed graph replays.

    Returns:
        ``(func_output, func_cost_us)``: a clone of the tensor returned by the
        last captured call after the first replay, and the mean time per
        captured call in microseconds.
    """
    captured_input = inp_randn.clone()
    with graph_capture() as capture_ctx:
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph, stream=capture_ctx.stream):
            for _ in range(graph_loop):
                captured_out = func(captured_input)

    # First replay produces the values that are snapshotted for the caller.
    cuda_graph.replay()
    func_output = captured_out.clone()

    for _ in range(warmup_loop):
        cuda_graph.replay()
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    elapsed_ms: List[float] = []
    for _ in range(test_loop):
        # Drain the device and line all ranks up before each timed replay.
        torch.cuda.synchronize()
        dist.barrier()
        tic.record()
        cuda_graph.replay()
        toc.record()
        toc.synchronize()
        elapsed_ms.append(tic.elapsed_time(toc))

    # Event timings are in milliseconds; each replay runs graph_loop calls.
    mean_replay_ms = sum(elapsed_ms) / len(elapsed_ms)
    func_cost_us = mean_replay_ms / graph_loop * 1000
    cuda_graph.reset()
    return func_output, func_cost_us
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    """Time ``func`` in eager mode with CUDA events.

    ``func`` is called once to capture its output, then ``warmup_loop`` times
    untimed, and finally ``test_loop`` times between a pair of CUDA events.

    Args:
        func: Callable taking one tensor and returning a tensor.
        inp_randn: Input tensor; a clone is used for every call.
        warmup_loop: Number of untimed calls before measuring.
        test_loop: Number of timed calls.

    Returns:
        ``(func_output, func_cost_us)``: a clone of the result of the very
        first call, and the mean time per call in microseconds.
    """
    working_input = inp_randn.clone()
    # Snapshot the result of the first call before any further calls reuse
    # the same input tensor.
    func_output = func(working_input).clone()

    for _ in range(warmup_loop):
        func(working_input)
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()

    tic.record()
    for _ in range(test_loop):
        func(working_input)
    toc.record()
    torch.cuda.synchronize()

    # elapsed_time is in milliseconds; convert the per-call mean to microseconds.
    total_ms = tic.elapsed_time(toc)
    func_cost_us = total_ms / test_loop * 1000

    return func_output, func_cost_us
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_prof_ctx(do_prof: bool):
    """Return the context manager to wrap the benchmark in.

    Args:
        do_prof: When truthy, build a ``torch.profiler.profile`` that records
            CPU and CUDA activity with shapes and stacks; otherwise return a
            no-op context.

    Returns:
        A ``torch.profiler.profile`` instance or ``contextlib.nullcontext()``.
    """
    if not do_prof:
        return nullcontext()

    traced_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    return torch.profiler.profile(
        activities=traced_activities,
        record_shapes=True,
        with_stack=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def human_readable_size(size, decimal_places=1):
    """Format a byte count using binary (1024-based) units.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit ("PiB") is reached.

    Args:
        size: Number of bytes.
        decimal_places: Digits to show after the decimal point.

    Returns:
        A string such as ``"1.5 KiB"``.
    """
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB")
    last_index = len(units) - 1
    unit_index = 0
    # Written as ``not size < 1024.0`` so comparisons behave exactly like the
    # loop-and-break form for every input.
    while unit_index < last_index and not size < 1024.0:
        size /= 1024.0
        unit_index += 1
    return f"{size:.{decimal_places}f} {units[unit_index]}"
|
||||||
|
|
||||||
|
|
||||||
|
# ``tabulate`` is optional: when it is missing, ``print_markdown_table`` falls
# back to a hand-built markdown table, so the notice says so instead of
# claiming that table printing is skipped.
try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, falling back to plain markdown table printing")
    tabulate = None


def print_markdown_table(data):
    """Print a list of row dicts as a GitHub-flavoured markdown table.

    When ``tabulate`` is importable it renders the table. Otherwise the table
    is built by hand, taking the column order from the keys of the first row.

    Args:
        data: Sequence of dicts, one per row. In the fallback path every row
            must contain all keys of the first row.

    Returns:
        None. The table is written to stdout. In the fallback path an empty
        ``data`` prints nothing instead of failing on ``data[0]``.
    """
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return

    # Nothing to render; avoid indexing into an empty sequence.
    if not data:
        return

    columns = list(data[0].keys())
    # str() lets non-string keys be used as column titles.
    header_row = "| " + " | ".join(str(column) for column in columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"
    body_rows = [
        "| " + " | ".join(str(item[column]) for column in columns) + " |"
        for item in data
    ]
    print("\n".join([header_row, separator, *body_rows]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Benchmark driver: compares torch.distributed, symmetric-memory and PyNccl
    # all-reduce latencies over a range of message sizes and prints a table on
    # rank 0. Meant to be launched with torchrun (see the module docstring).
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Pick the local GPU from LOCAL_RANK (exported by torchrun) and only fall
    # back to a modulo over the GPUs actually present, instead of assuming
    # exactly 8 GPUs per node.
    local_rank_env = os.getenv("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()

    # The flag must be set before the model-parallel groups are created so the
    # tensor-parallel group builds its symmetric-memory communicator.
    set_symm_mem_all_reduce(True)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=local_rank,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)
    tp_group = get_tensor_model_parallel_group()
    group = tp_group.device_group
    pynccl_comm = tp_group.pynccl_comm
    symm_mem_comm = tp_group.symm_mem_comm
    # Fail with a clear message instead of an AttributeError deep inside the
    # timing helpers when the communicator could not be set up.
    if symm_mem_comm is None or symm_mem_comm.disabled:
        raise RuntimeError(
            "SymmMemCommunicator is unavailable for this device/world size; "
            "cannot run the symmetric-memory all-reduce benchmark."
        )
    dist.barrier()
    profile = False
    dtype = torch.bfloat16
    ctx = get_torch_prof_ctx(profile)
    result = []

    with ctx:
        # CI only runs the smallest message size to keep the job short.
        if IS_CI:
            i_range = range(10, 11)
        else:
            i_range = range(10, 20)
        for i in i_range:
            sz = 2**i
            # Stop once the payload exceeds 16 MiB.
            if sz * dtype.itemsize > 2**24:
                break
            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)

            torch_eager_output, torch_eager_time = _bench_eager_time(
                lambda inp: torch_allreduce(inp, group), inp_randn
            )
            symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
            )
            symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
                lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
            )
            # since pynccl is inplace op, this return result is not correct if graph loop > 1
            _, pynccl_graph_time = _bench_graph_time(
                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
            )
            torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)
            torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)
            result.append(
                {
                    "msg_size": human_readable_size(inp_randn.nbytes),
                    "torch eager time": torch_eager_time,
                    "symm mem eager time": symm_mem_eager_time,
                    "symm mem graph time": symm_mem_graph_time,
                    "pynccl graph time": pynccl_graph_time,
                }
            )
            if rank == 0:
                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
    if rank == 0:
        print_markdown_table(result)
    if profile:
        prof_dir = "prof/symm_mem"
        os.makedirs(prof_dir, exist_ok=True)
        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
|
||||||
@@ -28,6 +28,8 @@ def launch_server(args):
|
|||||||
cmd += "--disable-custom-all-reduce"
|
cmd += "--disable-custom-all-reduce"
|
||||||
if args.enable_mscclpp:
|
if args.enable_mscclpp:
|
||||||
cmd += "--enable-mscclpp"
|
cmd += "--enable-mscclpp"
|
||||||
|
if args.enable_torch_symm_mem:
|
||||||
|
cmd += "--enable-torch-symm-mem"
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
@@ -70,6 +72,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-torch-symm-mem",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
launch_server(args)
|
launch_server(args)
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
# One mebibyte in bytes.
MiB = 1 << 20

# Maximum all-reduce payload in bytes for the symmetric-memory path.
# Outer key: the value the communicator looks up as its device capability
# (major version); inner key: world size.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {
        2: 64 * MiB,
        4: 32 * MiB,
        6: 64 * MiB,
        8: 64 * MiB,
    },
    10: {
        2: 64 * MiB,
        4: 32 * MiB,
        6: 128 * MiB,
        8: 128 * MiB,
    },
}
|
||||||
164
python/sglang/srt/distributed/device_communicators/symm_mem.py
Normal file
164
python/sglang/srt/distributed/device_communicators/symm_mem.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
||||||
|
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
|
||||||
|
|
||||||
|
# Probe for torch's symmetric-memory module; ``symm_mem_available`` records
# whether the import succeeded. ``torch_symm_mem`` is only bound on success.
try:
    import torch.distributed._symmetric_memory as torch_symm_mem
except ImportError:
    symm_mem_available = False
else:
    symm_mem_available = True
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Platform probes, evaluated once at import time.
_is_cuda = is_cuda()
_is_hip = is_hip()

# True only when is_cuda() reports a CUDA platform; False on HIP and on any
# other platform.
symm_mem_is_available = True if _is_cuda else False
|
||||||
|
|
||||||
|
|
||||||
|
class SymmMemCommunicator:
    """
    Thin wrapper around symmetric-memory collectives.

    This communicator:
      - Validates device capability and world size.
      - Allocates a shared symmetric buffer.
      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled and will
    decline to perform symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> supported world sizes for multimem
    # If the current (cc_major, world_size) is not listed, we fall back
    # to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).
        """

        self.disabled = True
        # Defaults so a disabled instance still has well-defined attributes.
        self.dtype = torch.bfloat16
        self.buffer = None
        self.max_size = 0

        if not symm_mem_available:
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = torch.cuda.get_device_capability(device)[0]

        # Look the capability up with .get(): a major version that is >= 9 but
        # absent from the size table must disable the communicator rather
        # than raise KeyError during construction.
        max_sizes_by_world = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(self.device_capability)
        if self.device_capability < 9 or max_sizes_by_world is None:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in max_sizes_by_world:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return

        self.max_size = max_sizes_by_world[self.world_size]
        # The staging buffer holds max_size bytes worth of bfloat16 elements.
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return
        self.disabled = False

    def should_symm_mem_allreduce(self, inp: torch.Tensor):
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
          - Communicator must be enabled.
          - dtype must be bfloat16 (matches kernel + buffer dtype).
          - Total byte size must be 4-byte aligned (hardware requirement).
          - Payload must be smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # enforce 4-byte alignment
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform an out-of-place sum all-reduce via symmetric memory.

        ``inp`` itself is not modified; the reduced values are written to
        ``out``. Callers are expected to check should_symm_mem_allreduce()
        first; this method does not re-validate dtype or size.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16).
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
            - Stages 'inp' into the symmetric buffer.
            - Selects 'multimem' or 'two_shot' kernel based on topology.
            - Writes the result into 'out' and returns it.
        """
        # A disabled communicator has no staging buffer; honour the documented
        # contract instead of failing on the buffer access below.
        if self.disabled:
            return None
        if out is None:
            out = torch.empty_like(inp)

        staged = self.buffer[: inp.numel()]
        # reshape(-1) also accepts non-contiguous inputs, which view(-1) rejects.
        staged.copy_(inp.reshape(-1))

        multimem_world_sizes = self._WORLD_SIZES_MULTIMEM.get(
            self.device_capability, ()
        )
        if self.world_size in multimem_world_sizes:
            torch.ops.symm_mem.multimem_all_reduce_(
                staged, "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                staged, "sum", self.group.group_name
            )
        out.copy_(staged.view(out.shape))
        return out
|
||||||
@@ -215,12 +215,14 @@ class GroupCoordinator:
|
|||||||
use_pynccl: bool # a hint of whether to use PyNccl
|
use_pynccl: bool # a hint of whether to use PyNccl
|
||||||
use_pymscclpp: bool # a hint of whether to use PyMsccl
|
use_pymscclpp: bool # a hint of whether to use PyMsccl
|
||||||
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
||||||
|
use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
|
||||||
use_message_queue_broadcaster: (
|
use_message_queue_broadcaster: (
|
||||||
bool # a hint of whether to use message queue broadcaster
|
bool # a hint of whether to use message queue broadcaster
|
||||||
)
|
)
|
||||||
# communicators are only created for world size > 1
|
# communicators are only created for world size > 1
|
||||||
pynccl_comm: Optional[Any] # PyNccl communicator
|
pynccl_comm: Optional[Any] # PyNccl communicator
|
||||||
ca_comm: Optional[Any] # Custom allreduce communicator
|
ca_comm: Optional[Any] # Custom allreduce communicator
|
||||||
|
symm_mem_comm: Optional[Any] # Symm mem communicator
|
||||||
mq_broadcaster: Optional[Any] # shared memory broadcaster
|
mq_broadcaster: Optional[Any] # shared memory broadcaster
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -231,6 +233,7 @@ class GroupCoordinator:
|
|||||||
use_pynccl: bool,
|
use_pynccl: bool,
|
||||||
use_pymscclpp: bool,
|
use_pymscclpp: bool,
|
||||||
use_custom_allreduce: bool,
|
use_custom_allreduce: bool,
|
||||||
|
use_torch_symm_mem: bool,
|
||||||
use_hpu_communicator: bool,
|
use_hpu_communicator: bool,
|
||||||
use_xpu_communicator: bool,
|
use_xpu_communicator: bool,
|
||||||
use_npu_communicator: bool,
|
use_npu_communicator: bool,
|
||||||
@@ -279,6 +282,7 @@ class GroupCoordinator:
|
|||||||
self.use_pynccl = use_pynccl
|
self.use_pynccl = use_pynccl
|
||||||
self.use_pymscclpp = use_pymscclpp
|
self.use_pymscclpp = use_pymscclpp
|
||||||
self.use_custom_allreduce = use_custom_allreduce
|
self.use_custom_allreduce = use_custom_allreduce
|
||||||
|
self.use_torch_symm_mem = use_torch_symm_mem
|
||||||
self.use_hpu_communicator = use_hpu_communicator
|
self.use_hpu_communicator = use_hpu_communicator
|
||||||
self.use_xpu_communicator = use_xpu_communicator
|
self.use_xpu_communicator = use_xpu_communicator
|
||||||
self.use_npu_communicator = use_npu_communicator
|
self.use_npu_communicator = use_npu_communicator
|
||||||
@@ -294,6 +298,9 @@ class GroupCoordinator:
|
|||||||
from sglang.srt.distributed.device_communicators.pynccl import (
|
from sglang.srt.distributed.device_communicators.pynccl import (
|
||||||
PyNcclCommunicator,
|
PyNcclCommunicator,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.device_communicators.symm_mem import (
|
||||||
|
SymmMemCommunicator,
|
||||||
|
)
|
||||||
|
|
||||||
if is_hip():
|
if is_hip():
|
||||||
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
||||||
@@ -342,6 +349,13 @@ class GroupCoordinator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
||||||
|
|
||||||
|
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||||
|
if self.use_torch_symm_mem and self.world_size > 1:
|
||||||
|
self.symm_mem_comm = SymmMemCommunicator(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
# Create communicator for other hardware backends
|
# Create communicator for other hardware backends
|
||||||
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
||||||
HpuCommunicator,
|
HpuCommunicator,
|
||||||
@@ -446,6 +460,7 @@ class GroupCoordinator:
|
|||||||
# custom allreduce | enabled | enabled |
|
# custom allreduce | enabled | enabled |
|
||||||
# PyNccl | disabled| enabled |
|
# PyNccl | disabled| enabled |
|
||||||
# PyMscclpp | disabled| enabled |
|
# PyMscclpp | disabled| enabled |
|
||||||
|
# TorchSymmMem | disabled| enabled |
|
||||||
# torch.distributed | enabled | disabled|
|
# torch.distributed | enabled | disabled|
|
||||||
#
|
#
|
||||||
# Note: When custom quick allreduce is enabled, a runtime check
|
# Note: When custom quick allreduce is enabled, a runtime check
|
||||||
@@ -547,7 +562,12 @@ class GroupCoordinator:
|
|||||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||||
):
|
):
|
||||||
outplace_all_reduce_method = "pymscclpp"
|
outplace_all_reduce_method = "pymscclpp"
|
||||||
|
elif (
|
||||||
|
self.symm_mem_comm is not None
|
||||||
|
and not self.symm_mem_comm.disabled
|
||||||
|
and self.symm_mem_comm.should_symm_mem_allreduce(input_)
|
||||||
|
):
|
||||||
|
outplace_all_reduce_method = "symm_mem"
|
||||||
if outplace_all_reduce_method is not None:
|
if outplace_all_reduce_method is not None:
|
||||||
return torch.ops.sglang.outplace_all_reduce(
|
return torch.ops.sglang.outplace_all_reduce(
|
||||||
input_,
|
input_,
|
||||||
@@ -564,6 +584,7 @@ class GroupCoordinator:
|
|||||||
ca_comm = self.ca_comm
|
ca_comm = self.ca_comm
|
||||||
qr_comm = self.qr_comm
|
qr_comm = self.qr_comm
|
||||||
pymscclpp_comm = self.pymscclpp_comm
|
pymscclpp_comm = self.pymscclpp_comm
|
||||||
|
symm_mem_comm = self.symm_mem_comm
|
||||||
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
||||||
if outplace_all_reduce_method == "ca":
|
if outplace_all_reduce_method == "ca":
|
||||||
assert not ca_comm.disabled
|
assert not ca_comm.disabled
|
||||||
@@ -571,6 +592,9 @@ class GroupCoordinator:
|
|||||||
elif outplace_all_reduce_method == "qr":
|
elif outplace_all_reduce_method == "qr":
|
||||||
assert not qr_comm.disabled
|
assert not qr_comm.disabled
|
||||||
out = qr_comm.quick_all_reduce(input_)
|
out = qr_comm.quick_all_reduce(input_)
|
||||||
|
elif outplace_all_reduce_method == "symm_mem":
|
||||||
|
assert not symm_mem_comm.disabled
|
||||||
|
out = symm_mem_comm.all_reduce(input_)
|
||||||
else:
|
else:
|
||||||
assert not pymscclpp_comm.disabled
|
assert not pymscclpp_comm.disabled
|
||||||
out = pymscclpp_comm.all_reduce(input_)
|
out = pymscclpp_comm.all_reduce(input_)
|
||||||
@@ -1219,6 +1243,7 @@ def init_world_group(
|
|||||||
use_pynccl=False,
|
use_pynccl=False,
|
||||||
use_pymscclpp=False,
|
use_pymscclpp=False,
|
||||||
use_custom_allreduce=False,
|
use_custom_allreduce=False,
|
||||||
|
use_torch_symm_mem=False,
|
||||||
use_hpu_communicator=False,
|
use_hpu_communicator=False,
|
||||||
use_xpu_communicator=False,
|
use_xpu_communicator=False,
|
||||||
use_npu_communicator=False,
|
use_npu_communicator=False,
|
||||||
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
|
|||||||
use_message_queue_broadcaster: bool = False,
|
use_message_queue_broadcaster: bool = False,
|
||||||
group_name: Optional[str] = None,
|
group_name: Optional[str] = None,
|
||||||
use_mscclpp_allreduce: Optional[bool] = None,
|
use_mscclpp_allreduce: Optional[bool] = None,
|
||||||
|
use_symm_mem_allreduce: Optional[bool] = None,
|
||||||
) -> GroupCoordinator:
|
) -> GroupCoordinator:
|
||||||
if use_custom_allreduce is None:
|
if use_custom_allreduce is None:
|
||||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||||
if use_mscclpp_allreduce is None:
|
if use_mscclpp_allreduce is None:
|
||||||
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
|
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
|
||||||
|
if use_symm_mem_allreduce is None:
|
||||||
|
use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
|
||||||
return GroupCoordinator(
|
return GroupCoordinator(
|
||||||
group_ranks=group_ranks,
|
group_ranks=group_ranks,
|
||||||
local_rank=local_rank,
|
local_rank=local_rank,
|
||||||
@@ -1246,6 +1274,7 @@ def init_model_parallel_group(
|
|||||||
use_pynccl=not _is_npu,
|
use_pynccl=not _is_npu,
|
||||||
use_pymscclpp=use_mscclpp_allreduce,
|
use_pymscclpp=use_mscclpp_allreduce,
|
||||||
use_custom_allreduce=use_custom_allreduce,
|
use_custom_allreduce=use_custom_allreduce,
|
||||||
|
use_torch_symm_mem=use_symm_mem_allreduce,
|
||||||
use_hpu_communicator=True,
|
use_hpu_communicator=True,
|
||||||
use_xpu_communicator=True,
|
use_xpu_communicator=True,
|
||||||
use_npu_communicator=True,
|
use_npu_communicator=True,
|
||||||
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_ENABLE_CUSTOM_ALL_REDUCE = True
|
_ENABLE_CUSTOM_ALL_REDUCE = True
|
||||||
_ENABLE_MSCCLPP_ALL_REDUCE = False
|
_ENABLE_MSCCLPP_ALL_REDUCE = False
|
||||||
|
_ENABLE_SYMM_MEM_ALL_REDUCE = False
|
||||||
|
|
||||||
|
|
||||||
def set_custom_all_reduce(enable: bool):
|
def set_custom_all_reduce(enable: bool):
|
||||||
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
|
|||||||
_ENABLE_MSCCLPP_ALL_REDUCE = enable
|
_ENABLE_MSCCLPP_ALL_REDUCE = enable
|
||||||
|
|
||||||
|
|
||||||
|
def set_symm_mem_all_reduce(enable: bool):
|
||||||
|
global _ENABLE_SYMM_MEM_ALL_REDUCE
|
||||||
|
_ENABLE_SYMM_MEM_ALL_REDUCE = enable
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_environment(
|
def init_distributed_environment(
|
||||||
world_size: int = -1,
|
world_size: int = -1,
|
||||||
rank: int = -1,
|
rank: int = -1,
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ def initialize_dp_attention(
|
|||||||
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
||||||
use_pymscclpp=False,
|
use_pymscclpp=False,
|
||||||
use_custom_allreduce=False,
|
use_custom_allreduce=False,
|
||||||
|
use_torch_symm_mem=False,
|
||||||
use_hpu_communicator=False,
|
use_hpu_communicator=False,
|
||||||
use_xpu_communicator=False,
|
use_xpu_communicator=False,
|
||||||
use_npu_communicator=False,
|
use_npu_communicator=False,
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
|
|||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
set_custom_all_reduce,
|
set_custom_all_reduce,
|
||||||
set_mscclpp_all_reduce,
|
set_mscclpp_all_reduce,
|
||||||
|
set_symm_mem_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||||
from sglang.srt.eplb.eplb_manager import EPLBManager
|
from sglang.srt.eplb.eplb_manager import EPLBManager
|
||||||
@@ -646,6 +647,7 @@ class ModelRunner:
|
|||||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||||
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
|
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
|
||||||
|
set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
|
||||||
|
|
||||||
if not self.is_draft_worker:
|
if not self.is_draft_worker:
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
|
|||||||
@@ -382,6 +382,7 @@ class ServerArgs:
|
|||||||
disable_outlines_disk_cache: bool = False
|
disable_outlines_disk_cache: bool = False
|
||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
enable_mscclpp: bool = False
|
enable_mscclpp: bool = False
|
||||||
|
enable_torch_symm_mem: bool = False
|
||||||
disable_overlap_schedule: bool = False
|
disable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
@@ -2443,6 +2444,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-torch-symm-mem",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-overlap-schedule",
|
"--disable-overlap-schedule",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user