support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

benchmark/kernels/all_reduce/benchmark_mscclpp.py (new file, 224 lines)
@@ -0,0 +1,224 @@
"""For Now, MSCCL is only supported on TP16 and TP8 case
|
||||
|
||||
export WORLD_SIZE=1
|
||||
export RANK=0
|
||||
export MASTER_ADDR=127.0.0.1
|
||||
export MASTER_PORT=12345
|
||||
|
||||
torchrun --nproc_per_node gpu \
|
||||
--nnodes $WORLD_SIZE \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
|
||||
"""
|
||||
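# Illustrative note (not part of the original script): for the 2-node / TP16 case,
# the same torchrun command can be launched on each node with, e.g.,
#   node 0: WORLD_SIZE=2 RANK=0 MASTER_ADDR=<node0-ip> MASTER_PORT=12345
#   node 1: WORLD_SIZE=2 RANK=1 MASTER_ADDR=<node0-ip> MASTER_PORT=12345
# where WORLD_SIZE is used as the number of nodes and RANK as the node rank,
# matching the flags in the docstring above.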
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.distributed import init_distributed_environment
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
|
||||
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
graph_capture,
|
||||
initialize_model_parallel,
|
||||
set_mscclpp_all_reduce,
|
||||
)
|
||||
|
||||
|
||||
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
|
||||
dist.all_reduce(torch_input, group=group)
|
||||
return torch_input
|
||||
|
||||
|
||||
def msccl_allreduce(
|
||||
msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
|
||||
) -> torch.Tensor:
|
||||
return msccl_comm.all_reduce(msccl_input)
|
||||
|
||||
|
||||
def pynccl_allreduce(
|
||||
msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
|
||||
) -> torch.Tensor:
|
||||
pynccl_comm.all_reduce(msccl_input)
|
||||
return msccl_input
|
||||
|
||||
|
||||
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
|
||||
graph_input = inp_randn.clone()
|
||||
with graph_capture() as graph_capture_context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||
for _ in range(graph_loop):
|
||||
graph_out = func(graph_input)
|
||||
|
||||
graph.replay()
|
||||
func_output = graph_out.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for _ in range(test_loop):
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier()
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
|
||||
graph.reset()
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
|
||||
eager_input = inp_randn.clone()
|
||||
eager_output = func(eager_input)
|
||||
func_output = eager_output.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
func(eager_input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
for _ in range(test_loop):
|
||||
func(eager_input)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
|
||||
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def get_torch_prof_ctx(do_prof: bool):
|
||||
ctx = (
|
||||
torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
)
|
||||
if do_prof
|
||||
else nullcontext()
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=1):
|
||||
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
|
||||
if size < 1024.0 or unit == "PiB":
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f} {unit}"
|
||||
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("tabulate not installed, skipping table printing")
|
||||
tabulate = None
|
||||
|
||||
|
||||
def print_markdown_table(data):
|
||||
if tabulate is not None:
|
||||
print(tabulate(data, headers="keys", tablefmt="github"))
|
||||
return
|
||||
headers = data[0].keys()
|
||||
header_row = "| " + " | ".join(headers) + " |"
|
||||
separator = "| " + " | ".join(["---"] * len(headers)) + " |"
|
||||
rows = []
|
||||
for item in data:
|
||||
row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
|
||||
rows.append(row)
|
||||
markdown_table = "\n".join([header_row, separator] + rows)
|
||||
print(markdown_table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
force=True,
|
||||
)
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="nccl")
|
||||
world, world_size = dist.group.WORLD, dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.set_device(rank % 8)
|
||||
device = torch.cuda.current_device()
|
||||
set_mscclpp_all_reduce(True)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
local_rank=rank % 8,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
cpu_group = get_tensor_model_parallel_group().cpu_group
|
||||
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
|
||||
pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
|
||||
dist.barrier()
|
||||
profile = False
|
||||
dtype = torch.bfloat16
|
||||
ctx = get_torch_prof_ctx(profile)
|
||||
result = []
|
||||
|
||||
with ctx:
|
||||
for i in range(10, 20):
|
||||
sz = 2**i
|
||||
if sz * dtype.itemsize > 2**20:
|
||||
break
|
||||
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
|
||||
|
||||
memory = torch.empty_like(inp_randn)
|
||||
memory_out = torch.empty_like(memory)
|
||||
torch_eager_output, torch_eager_time = _bench_eager_time(
|
||||
lambda inp: torch_allreduce(inp, group), inp_randn
|
||||
)
|
||||
msccl_eager_output, msccl_eager_time = _bench_eager_time(
|
||||
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
|
||||
)
|
||||
msccl_graph_output, msccl_graph_time = _bench_graph_time(
|
||||
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
|
||||
)
|
||||
# Since pynccl all_reduce is an in-place op, the returned result is not correct when graph_loop > 1.
|
||||
_, pynccl_graph_time = _bench_graph_time(
|
||||
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
|
||||
)
|
||||
torch.testing.assert_close(torch_eager_output, msccl_graph_output)
|
||||
torch.testing.assert_close(torch_eager_output, msccl_eager_output)
|
||||
result.append(
|
||||
{
|
||||
"msg_size": human_readable_size(inp_randn.nbytes),
|
||||
"torch eager time": torch_eager_time,
|
||||
"msccl eager time": msccl_eager_time,
|
||||
"msccl graph time": msccl_graph_time,
|
||||
"pynccl graph time": pynccl_graph_time,
|
||||
}
|
||||
)
|
||||
if rank == 0:
|
||||
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
|
||||
if rank == 0:
|
||||
print_markdown_table(result)
|
||||
if profile:
|
||||
prof_dir = f"prof/msccl"
|
||||
os.makedirs(prof_dir, exist_ok=True)
|
||||
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
|
||||
@@ -26,6 +26,8 @@ def launch_server(args):
|
||||
cmd += f"--tp-size {args.tp_size} "
|
||||
if args.disable_custom_all_reduce:
|
||||
cmd += "--disable-custom-all-reduce"
|
||||
if args.enable_mscclpp:
|
||||
cmd += "--enable-mscclpp"
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
|
||||
@@ -63,6 +65,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Disable custom all reduce when device does not support p2p communication",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mscclpp",
|
||||
action="store_true",
|
||||
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
launch_server(args)
|
||||
|
||||
@@ -201,6 +201,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `disable_cuda_graph_padding` | Disable CUDA Graph when padding is needed; otherwise, still use CUDA Graph. | `False` |
| `disable_outlines_disk_cache` | Disable disk cache for outlines grammar backend. | `False` |
| `disable_custom_all_reduce` | Disable usage of custom all-reduce kernel. | `False` |
| `enable_mscclpp` | Enable usage of mscclpp kernel for small message all-reduce. | `False` |
| `disable_overlap_schedule` | Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). | `False` |
| `enable_nan_detection` | Enable warning if the logits contain `NaN`. | `False` |
| `enable_p2p_check` | Turns off the default of always allowing P2P checks when accessing GPU. | `False` |
@@ -113,3 +113,37 @@ else:
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
|
||||
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: bytes,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
return sgl_kernel.allreduce.mscclpp_init_context(
|
||||
unique_id,
|
||||
rank,
|
||||
world_size,
|
||||
scratch,
|
||||
put_buffer,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib,
|
||||
context_selection,
|
||||
)
|
||||
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||
|
||||
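The wrappers above expose the raw sgl_kernel mscclpp bindings. Below is a minimal sketch of how they fit together for the single-node case, mirroring what PyMscclppCommunicator (added later in this diff) does; the helper name, the gloo CPU group argument, and the buffer-sizing choices are illustrative assumptions, not part of this diff.

import torch
import torch.distributed as dist
from sglang.srt import _custom_ops as ops

def init_mscclpp_1node(cpu_group, rank, world_size, max_bytes=2**20):
    # Rank 0 creates the TCP bootstrap unique id; the others receive it.
    uid = [ops.mscclpp_generate_unique_id() if rank == 0 else None]
    dist.broadcast_object_list(uid, src=0, group=cpu_group)
    # Scratch / put buffers sized as in PyMscclppCommunicator (max_bytes * 8).
    scratch = torch.empty(max_bytes * 8, dtype=torch.uint8, device="cuda")
    put_buffer = torch.empty(max_bytes * 8 // world_size, dtype=torch.uint8, device="cuda")
    return ops.mscclpp_init_context(
        uid[0], rank, world_size, scratch, put_buffer,
        world_size,                          # nranks_per_node (single node)
        [0] * world_size,                    # rank_to_node
        [r % 8 for r in range(world_size)],  # rank_to_ib
        1,                                   # MscclContextSelection.MSCCL1SHOT1NODELL
    )

# Usage sketch:
# ctx = init_mscclpp_1node(cpu_group, rank, world_size)
# out = torch.empty_like(x)
# ops.mscclpp_allreduce(ctx, x, out, nthreads=512, nblocks=21)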
python/sglang/srt/distributed/device_communicators/pymscclpp.py (new file, 315 lines)
@@ -0,0 +1,315 @@
|
||||
import bisect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
mscclpp_is_available = False
|
||||
if _is_hip:
|
||||
# TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
|
||||
mscclpp_is_available = False
|
||||
if _is_cuda:
|
||||
try:
|
||||
import sgl_kernel
|
||||
|
||||
mscclpp_is_available = True
|
||||
except:
|
||||
mscclpp_is_available = False
|
||||
|
||||
|
||||
class MscclContextSelection(IntEnum):
|
||||
MSCCL1SHOT1NODELL = 1
|
||||
MSCCL1SHOT2NODELL = 2
|
||||
|
||||
|
||||
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
def mscclpp_convert_to_bytes(size_str):
|
||||
"""
|
||||
Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
|
||||
into the equivalent number of bytes using binary units.
|
||||
|
||||
Args:
|
||||
size_str (str): A string representing size with unit (KB, MB, GB).
|
||||
|
||||
Returns:
|
||||
int: Number of bytes.
|
||||
"""
|
||||
size_str = size_str.strip().lower()
|
||||
|
||||
if not size_str:
|
||||
raise ValueError("Empty input string")
|
||||
|
||||
# Extract numeric part and unit
|
||||
for i in range(len(size_str)):
|
||||
if not size_str[i].isdigit() and size_str[i] != ".":
|
||||
break
|
||||
num_str = size_str[:i]
|
||||
unit = size_str[i:].strip()
|
||||
|
||||
try:
|
||||
num = float(num_str)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid numeric value in '{size_str}'")
|
||||
|
||||
# Conversion factors
|
||||
if unit == "b":
|
||||
return int(num)
|
||||
elif unit == "kb":
|
||||
return int(num * 1024)
|
||||
elif unit == "mb":
|
||||
return int(num * 1024 * 1024)
|
||||
elif unit == "gb":
|
||||
return int(num * 1024 * 1024 * 1024)
|
||||
else:
|
||||
raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
|
||||
|
||||
|
||||
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
|
||||
# warmup
|
||||
for _ in range(warmup_niter):
|
||||
func()
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier()
|
||||
start_event.record()
|
||||
for _ in range(test_niter):
|
||||
func()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
|
||||
return func_cost_us
|
||||
|
||||
|
||||
class PyMscclppCommunicator:
|
||||
_SUPPORTED_WORLD_SIZES = [8, 16]
|
||||
_MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
|
||||
_SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]
|
||||
|
||||
# max_bytes: the maximum message size (in bytes) supported by the mscclpp allreduce.
# On A100, mscclpp is faster than NCCL only when the message size is smaller than 1MB.
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_bytes=_MAX_BYTES,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not mscclpp_is_available:
|
||||
# disable because of missing mscclpp library
|
||||
# e.g. in a non-cuda environment
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert (
|
||||
dist.get_backend(group) != dist.Backend.NCCL
|
||||
), "CustomAllreduce should be attached to a non-NCCL group."
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
if world_size == 1:
|
||||
# No need to initialize mscclpp for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"PyMscclpp is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_mscclpp=True explicitly.",
|
||||
world_size,
|
||||
str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
self.ranks = torch.distributed.get_process_group_ranks(group)
|
||||
self.nranks_per_node = torch.cuda.device_count()
|
||||
# for now mscclpp with stride in the communicator is not tested
|
||||
if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
|
||||
logger.warning(
|
||||
"PyMscclpp is disabled due to an unsupported group %s."
|
||||
"Please ensure all ranks in the group are consecutive."
|
||||
"To silence this warning, specify disable_mscclpp=True explicitly.",
|
||||
str(self.ranks),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
self.max_bytes = max_bytes
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
if dist.get_rank(group) == 0:
|
||||
unique_id = [ops.mscclpp_generate_unique_id()]
|
||||
else:
|
||||
unique_id = [None]
|
||||
dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
|
||||
self.unique_id = unique_id[0]
|
||||
self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
|
||||
range(world_size)
|
||||
)
|
||||
for r in range(world_size):
|
||||
self.rank_to_node[r] = r // 8
|
||||
self.rank_to_ib[r] = self.rank % 8
|
||||
|
||||
self._context = None
|
||||
self.context_selection = None
|
||||
self.msg_size_for_finetune = [
|
||||
2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
|
||||
]
|
||||
self.msg_size2best_config = {}
|
||||
if world_size == 8:
|
||||
self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
|
||||
elif world_size == 16:
|
||||
self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
|
||||
if not _is_hip:
|
||||
self.scratch = torch.empty(
|
||||
self.max_bytes * 8,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
self.put_buffer = torch.empty(
|
||||
self.max_bytes * 8 // self.nranks_per_node,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
self._context = ops.mscclpp_init_context(
|
||||
self.unique_id,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.scratch,
|
||||
self.put_buffer,
|
||||
self.nranks_per_node,
|
||||
self.rank_to_node,
|
||||
self.rank_to_ib,
|
||||
int(self.context_selection),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("HIP Mscclpp is not supported yet.")
|
||||
|
||||
self.msg_size2best_config = {}
|
||||
self.pre_tune_config()
|
||||
if dist.get_rank(group) == 0:
|
||||
msg_size2best_config = [self.msg_size2best_config]
|
||||
else:
|
||||
msg_size2best_config = [None]
|
||||
dist.broadcast_object_list(
|
||||
msg_size2best_config, src=self.ranks[0], group=self.group
|
||||
)
|
||||
self.msg_size2best_config = msg_size2best_config[0]
|
||||
|
||||
# PyMscclpp is enabled only in cuda graph
|
||||
self.disabled = True
|
||||
|
||||
def pre_tune_config(self, dtype=torch.bfloat16) -> bool:
|
||||
logger.debug(f"start to pre-tune configs for rank {self.rank}")
|
||||
nthreads_to_try = [256, 512, 1024]
|
||||
nblocks_to_try = [21, 42, 84]
|
||||
inp_randn = torch.ones(
|
||||
self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
|
||||
)
|
||||
oup_randn = torch.empty_like(inp_randn)
|
||||
for msg_size in self.msg_size_for_finetune:
|
||||
mock_inp, mock_outp = (
|
||||
inp_randn[: msg_size // dtype.itemsize],
|
||||
oup_randn[: msg_size // dtype.itemsize],
|
||||
)
|
||||
best_config, best_time = None, None
|
||||
for nthreads in nthreads_to_try:
|
||||
for nblocks in nblocks_to_try:
|
||||
cur_cost = mscclpp_bench_time(
|
||||
lambda: ops.mscclpp_allreduce(
|
||||
self._context, mock_inp, mock_outp, nthreads, nblocks
|
||||
)
|
||||
)
|
||||
if best_time is None or cur_cost < best_time:
|
||||
best_config = (nthreads, nblocks)
|
||||
best_time = cur_cost
|
||||
self.msg_size2best_config[msg_size] = best_config
|
||||
if self.rank == 0:
|
||||
logger.debug(
|
||||
f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
|
||||
)
|
||||
|
||||
def should_mscclpp_allreduce(
|
||||
self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
|
||||
) -> bool:
|
||||
if self.disabled or self._context is None:
|
||||
return False
|
||||
if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
|
||||
return False
|
||||
if not mscclpp_is_weak_contiguous(inp):
|
||||
return False
|
||||
# only support sum op
|
||||
if op != ReduceOp.SUM:
|
||||
return False
|
||||
if inp.numel() * inp.element_size() > self.max_bytes:
|
||||
return False
|
||||
return True
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
self.graph_input_set.add((tensor.dtype, tensor.numel()))
|
||||
msg_size = tensor.numel() * tensor.itemsize
|
||||
index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
|
||||
msg_size_finetune = self.msg_size_for_finetune[index]
|
||||
nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
|
||||
result = torch.empty_like(tensor)
|
||||
ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
|
||||
return result
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self,
|
||||
enable: Optional[bool] = None,
|
||||
):
|
||||
if enable is None:
|
||||
# guess a default value when not specified
|
||||
enable = self.available
|
||||
|
||||
old_disable = self.disabled
|
||||
self.disabled = not enable
|
||||
|
||||
yield
|
||||
|
||||
self.disabled = old_disable
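How the communicator is driven (see the GroupCoordinator changes below): it is constructed on the CPU (non-NCCL) group, stays disabled in eager mode, and is enabled via change_state() inside graph_capture(), so only small SUM all-reduces captured in a CUDA graph take the mscclpp path. A minimal usage sketch, assuming an already-initialized tensor-parallel group:

import torch
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
)

tp_group = get_tensor_model_parallel_group()
comm = tp_group.pymscclpp_comm  # PyMscclppCommunicator, created with group=cpu_group

x = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
with graph_capture() as capture_ctx:
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=capture_ctx.stream):
        # Inside capture, change_state(enable=True) is active, so the
        # mscclpp kernel is used for this out-of-place all-reduce.
        y = comm.all_reduce(x)
graph.replay()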
|
||||
@@ -190,6 +190,7 @@ class GroupCoordinator:
|
||||
cpu_group: ProcessGroup # group for CPU communication
|
||||
device_group: ProcessGroup # group for device communication
|
||||
use_pynccl: bool # a hint of whether to use PyNccl
|
||||
use_pymscclpp: bool # a hint of whether to use PyMsccl
|
||||
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
||||
use_message_queue_broadcaster: (
|
||||
bool # a hint of whether to use message queue broadcaster
|
||||
@@ -205,6 +206,7 @@ class GroupCoordinator:
|
||||
local_rank: int,
|
||||
torch_distributed_backend: Union[str, Backend],
|
||||
use_pynccl: bool,
|
||||
use_pymscclpp: bool,
|
||||
use_custom_allreduce: bool,
|
||||
use_hpu_communicator: bool,
|
||||
use_xpu_communicator: bool,
|
||||
@@ -244,6 +246,7 @@ class GroupCoordinator:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_hpu_communicator = use_hpu_communicator
|
||||
self.use_xpu_communicator = use_xpu_communicator
|
||||
@@ -265,6 +268,17 @@ class GroupCoordinator:
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import (
|
||||
PyMscclppCommunicator,
|
||||
)
|
||||
|
||||
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
||||
if use_pymscclpp and self.world_size > 1:
|
||||
self.pymscclpp_comm = PyMscclppCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
@@ -373,11 +387,15 @@ class GroupCoordinator:
|
||||
# --------------------------------------------
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# PyMscclpp | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note that custom allreduce will have a runtime check, if the
|
||||
# tensor size is too large, it will fallback to the next
|
||||
# available option.
|
||||
# Note that PyMscclpp needs to register the tensor ahead of time,
# which introduces a large overhead in the eager case;
# therefore it is only supported in the CUDA graph case.
|
||||
# In summary: When using CUDA graph, we use
|
||||
# either custom all-reduce kernel or pynccl. When not using
|
||||
# CUDA graph, we use either custom all-reduce kernel or
|
||||
@@ -392,7 +410,14 @@ class GroupCoordinator:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream()
|
||||
)
|
||||
with maybe_pynccl_context:
|
||||
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
maybe_pymscclpp_context: Any
|
||||
if not pymscclpp_comm:
|
||||
maybe_pymscclpp_context = nullcontext()
|
||||
else:
|
||||
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
|
||||
with maybe_pynccl_context, maybe_pymscclpp_context:
|
||||
yield graph_capture_context
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
@@ -437,6 +462,10 @@ class GroupCoordinator:
|
||||
self.ca_comm is not None
|
||||
and not self.ca_comm.disabled
|
||||
and self.ca_comm.should_custom_ar(input_)
|
||||
) or (
|
||||
self.pymscclpp_comm is not None
|
||||
and not self.pymscclpp_comm.disabled
|
||||
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
||||
):
|
||||
return torch.ops.sglang.outplace_all_reduce(
|
||||
input_, group_name=self.unique_name
|
||||
@@ -447,9 +476,13 @@ class GroupCoordinator:
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
ca_comm = self.ca_comm
|
||||
assert ca_comm is not None
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
assert ca_comm is not None or pymscclpp_comm is not None
|
||||
if ca_comm is not None and not ca_comm.disabled:
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
else:
|
||||
assert not pymscclpp_comm.disabled
|
||||
out = pymscclpp_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
|
||||
@@ -958,6 +991,7 @@ def init_world_group(
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=False,
|
||||
use_pymscclpp=False,
|
||||
use_custom_allreduce=False,
|
||||
use_hpu_communicator=False,
|
||||
use_xpu_communicator=False,
|
||||
@@ -973,14 +1007,18 @@ def init_model_parallel_group(
|
||||
use_custom_allreduce: Optional[bool] = None,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
if use_custom_allreduce is None:
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
if use_mscclpp_allreduce is None:
|
||||
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
|
||||
return GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=not is_npu(),
|
||||
use_pymscclpp=use_mscclpp_allreduce,
|
||||
use_custom_allreduce=use_custom_allreduce,
|
||||
use_hpu_communicator=True,
|
||||
use_xpu_communicator=True,
|
||||
@@ -1037,6 +1075,7 @@ def graph_capture():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = True
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = False
|
||||
|
||||
|
||||
def set_custom_all_reduce(enable: bool):
|
||||
@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def set_mscclpp_all_reduce(enable: bool):
|
||||
global _ENABLE_MSCCLPP_ALL_REDUCE
|
||||
_ENABLE_MSCCLPP_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
|
||||
@@ -98,11 +98,12 @@ def initialize_dp_attention(
|
||||
],
|
||||
local_rank,
|
||||
torch.distributed.get_backend(tp_group.device_group),
|
||||
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
||||
use_pymscclpp=False,
|
||||
use_custom_allreduce=False,
|
||||
use_hpu_communicator=False,
|
||||
use_xpu_communicator=False,
|
||||
use_npu_communicator=False,
|
||||
group_name="attention_tp",
|
||||
)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from sglang.srt.distributed import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
set_custom_all_reduce,
|
||||
set_mscclpp_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
@@ -460,6 +461,7 @@ class ModelRunner:
|
||||
else:
|
||||
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
|
||||
|
||||
if not self.is_draft_worker:
|
||||
# Only initialize the distributed environment on the target model worker.
|
||||
|
||||
@@ -165,6 +165,7 @@ class ServerArgs:
|
||||
enable_tokenizer_batch_encode: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
enable_mscclpp: bool = False
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
@@ -1168,6 +1169,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mscclpp",
|
||||
action="store_true",
|
||||
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-overlap-schedule",
|
||||
action="store_true",
|
||||
|
||||
@@ -73,6 +73,14 @@ FetchContent_Declare(
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flash-attention)
|
||||
# mscclpp
|
||||
FetchContent_Declare(
|
||||
repo-mscclpp
|
||||
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
|
||||
GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-mscclpp)
|
||||
|
||||
# ccache option
|
||||
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
||||
@@ -99,6 +107,7 @@ include_directories(
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
)
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
|
||||
OUTPUT_VARIABLE TORCH_CXX11_ABI
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
if(TORCH_CXX11_ABI STREQUAL "0")
|
||||
message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
else()
|
||||
message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
|
||||
endif()
|
||||
set(MSCCLPP_USE_CUDA ON)
|
||||
set(MSCCLPP_BYPASS_GPU_CHECK ON)
|
||||
set(MSCCLPP_BUILD_TESTS OFF)
|
||||
add_subdirectory(${repo-mscclpp_SOURCE_DIR})
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
|
||||
@@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules
|
||||
@git submodule update --init --recursive
|
||||
|
||||
ln: submodule ## Create compilation database
|
||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
|
||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||
|
||||
|
||||
install: submodule ## Install package in development mode
|
||||
@pip install -e . --no-build-isolation
|
||||
|
||||
build: install-deps submodule ## Build and install wheel package
|
||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
|
||||
clean: ## Remove build artifacts
|
||||
@rm -rf build dist *.egg-info
|
||||
|
||||
@@ -50,6 +50,9 @@ docker run --rm \
|
||||
which cmake
|
||||
cmake --version
|
||||
|
||||
yum install numactl-devel -y && \
|
||||
yum install libibverbs -y && \
|
||||
ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
|
||||
${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
|
||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
|
||||
|
||||
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
enum MscclContextSelection {
|
||||
MSCCL1NODELL = 1,
|
||||
MSCCL2NODELL = 2,
|
||||
};
|
||||
|
||||
class MscclContext {
|
||||
public:
|
||||
MscclContextSelection selection_;
|
||||
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
|
||||
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
|
||||
MscclContext(MscclContextSelection selection) : selection_(selection) {}
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
if (selection_ == MSCCL1NODELL) {
|
||||
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
} else if (selection_ == MSCCL2NODELL) {
|
||||
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
|
||||
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
|
||||
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
|
||||
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Convert a byte tensor back into an mscclpp::UniqueId.
|
||||
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
|
||||
mscclpp::UniqueId unique_id;
|
||||
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
|
||||
return unique_id;
|
||||
}
|
||||
|
||||
torch::Tensor mscclpp_generate_unique_id() {
|
||||
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
|
||||
return _unique_id2tensor(unique_id);
|
||||
}
|
||||
|
||||
fptr_t mscclpp_init_context(
|
||||
const torch::Tensor& unique_id,
|
||||
const int64_t rank,
|
||||
const int64_t world_size,
|
||||
torch::Tensor& scratch,
|
||||
torch::Tensor& put_buffer,
|
||||
const int64_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib,
|
||||
const int64_t context_selection) {
|
||||
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
|
||||
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
|
||||
if (context_selection == MSCCL1NODELL) {
|
||||
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
|
||||
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
|
||||
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
|
||||
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
|
||||
} else if (context_selection == MSCCL2NODELL) {
|
||||
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
|
||||
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
|
||||
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
|
||||
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
|
||||
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
|
||||
uid,
|
||||
rank,
|
||||
world_size,
|
||||
scratch_ptr,
|
||||
scratch_bytes,
|
||||
put_buffer_ptr,
|
||||
put_buffer_bytes,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
} else {
|
||||
throw std::runtime_error("invalid context selection");
|
||||
}
|
||||
return (fptr_t)context_ptr;
|
||||
}
|
||||
|
||||
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
|
||||
}
|
||||
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
|
||||
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
|
||||
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
context->allreduce<float>(
|
||||
stream,
|
||||
reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
context->allreduce<half>(
|
||||
stream,
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
context->allreduce<__nv_bfloat16>(
|
||||
stream,
|
||||
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh (new file, 779 lines)
@@ -0,0 +1,779 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
#pragma once
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#include <hip/hip_fp16.h>
|
||||
#else
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
// Comment out the include below when building test_mscclpp_allreduce.cu standalone.
|
||||
#include "utils.h"
|
||||
|
||||
namespace sglang {
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
|
||||
|
||||
template <typename To, typename From>
|
||||
__forceinline__ __device__ To bit_cast(const From& src) {
|
||||
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
|
||||
|
||||
union {
|
||||
From f;
|
||||
To t;
|
||||
} u;
|
||||
u.f = src;
|
||||
return u.t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T add_elements(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
|
||||
int4 ret;
|
||||
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
|
||||
uint2 ret;
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
|
||||
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int add_vectors(int a, int b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_1node using LLPacket, origin allreduce2
|
||||
// -------------------------------------------------------
|
||||
|
||||
__device__ uint64_t globalFlag = 1;
|
||||
|
||||
template <typename TYPE>
|
||||
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
TYPE* buff,
|
||||
TYPE* scratch,
|
||||
void* resultBuff,
|
||||
int rank,
|
||||
int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
// This version of allreduce only works for single nodes
|
||||
const int nPeers = worldSize - 1;
|
||||
const size_t nPkts = nelems / 2;
|
||||
const int nelemsPerRank = nelems / worldSize;
|
||||
const int nPktsPerRank = nelemsPerRank / 2;
|
||||
// flag for packets. Initially 1
|
||||
const uint32_t flag = (uint32_t)globalFlag;
|
||||
// thread block & channel info
|
||||
const int nBlocksPerPeer = gridDim.x / nPeers;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
|
||||
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
|
||||
size_t scratchResultOffset =
|
||||
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
|
||||
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
|
||||
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
const int remoteRank = index < rank ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
data = add_vectors<TYPE>(val, data);
|
||||
}
|
||||
data = add_vectors<TYPE>(data, src[idx]);
|
||||
dst[idx] = data;
|
||||
|
||||
mscclpp::LLPacket packet;
|
||||
packet.data1 = data.x;
|
||||
packet.flag1 = flag;
|
||||
packet.data2 = data.y;
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
}
|
||||
// step 3: get data result from scratch buffer
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
|
||||
const int dstOffset = remoteRank * nPktsPerRank;
|
||||
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
|
||||
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
|
||||
uint2 data = dstPkt[idx + dstOffset].read(flag);
|
||||
result[idx].x = data.x;
|
||||
result[idx].y = data.y;
|
||||
}
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
globalFlag += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_2node using LLPacket, origin allreduce5
|
||||
// -------------------------------------------------------
|
||||
|
||||
template <typename TYPE>
|
||||
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
mscclpp::PortChannelDeviceHandle* portChans,
|
||||
TYPE* buff,
|
||||
TYPE* scratch,
|
||||
TYPE* putBuff,
|
||||
TYPE* resultBuff,
|
||||
int rank,
|
||||
int nRanksPerNode,
|
||||
int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
// This version of allreduce handles the two-node case: a local reduce-scatter within each node followed by an inter-node exchange.
|
||||
const int nPeersInNode = nRanksPerNode - 1;
|
||||
const int nPkts = nelems / 2;
|
||||
const int nelemsPerLocalRank = nelems / nRanksPerNode;
|
||||
const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
|
||||
const int localRankId = rank % nRanksPerNode;
|
||||
// flag for packets. Initially 1
|
||||
const uint32_t flag = (uint32_t)globalFlag;
|
||||
// thread block & channel info
|
||||
const int nBlocksPerPeer = gridDim.x / nPeersInNode;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
|
||||
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
|
||||
size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
|
||||
size_t scratchResultOffset =
|
||||
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
|
||||
uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
|
||||
uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
if (nRanksPerNode > 1) {
|
||||
memChan.putPackets(
|
||||
scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
}
|
||||
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
|
||||
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
for (int index = 0; index < nPeersInNode; index++) {
|
||||
const int remoteRank = index < localRankId ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
data = add_vectors<TYPE>(val, data);
|
||||
}
|
||||
data = add_vectors<TYPE>(data, src[idx]);
|
||||
putPkt[idx].write(data.x, data.y, flag);
|
||||
dst[idx] = data;
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// step 3. send local reduced data to remote node.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
|
||||
if ((flag & 63) == 0) {
|
||||
portChan.flush();
|
||||
}
|
||||
}
|
||||
// step 4. try to read the data from scratch buffer and write to local peers
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 res = dst[idx];
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
res = add_vectors<TYPE>(res, val);
|
||||
|
||||
mscclpp::LLPacket packet;
|
||||
packet.data1 = res.x;
|
||||
packet.flag1 = flag;
|
||||
packet.data2 = res.y;
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
|
||||
for (int index = 0; index < nPeersInNode; index++) {
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
dst[idx] = res;
|
||||
}
|
||||
|
||||
// step 5: get data result from scratch buffer
|
||||
dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
|
||||
const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
|
||||
uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
|
||||
if (nRanksPerNode > 1) {
|
||||
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
|
||||
idx += blockDim.x * nBlocksPerPeer) {
|
||||
uint2 data = dstPkt[idx + dstOffset].read(flag);
|
||||
result[idx] = data;
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
globalFlag += 1;
|
||||
}
|
||||
}
|
||||
|
||||
static const mscclpp::Transport IBs[] = {
|
||||
mscclpp::Transport::IB0,
|
||||
mscclpp::Transport::IB1,
|
||||
mscclpp::Transport::IB2,
|
||||
mscclpp::Transport::IB3,
|
||||
mscclpp::Transport::IB4,
|
||||
mscclpp::Transport::IB5,
|
||||
mscclpp::Transport::IB6,
|
||||
mscclpp::Transport::IB7};
|
||||
|
||||
class MscclCommGroup {
|
||||
public:
|
||||
std::shared_ptr<mscclpp::Communicator> comm_;
|
||||
const size_t rank_;
|
||||
const size_t world_size_;
|
||||
const std::vector<int64_t> rank_to_node_;
|
||||
const std::vector<int64_t> rank_to_ib_;
|
||||
MscclCommGroup(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
bootstrap->initialize(unique_id);
|
||||
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
}
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
throw std::runtime_error("you should not call allreduce of a base context");
|
||||
}
|
||||
bool is_same_node(int r1, int r2) {
|
||||
return rank_to_node_[r1] == rank_to_node_[r2];
|
||||
}
|
||||
|
||||
void make_connection(
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
|
||||
same_node_connections.clear();
|
||||
cross_node_connections.clear();
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
|
||||
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
|
||||
}
|
||||
comm_->setup();
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
same_node_connections.emplace(r, conn_futures[r].get());
|
||||
} else {
|
||||
cross_node_connections.emplace(r, conn_futures[r].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void make_memory_channels_with_scratch(
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
|
||||
}
|
||||
}
|
||||
void make_port_channels_with_scratch(
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService,
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::PortChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
|
||||
|
||||
mscclpp::TransportFlags flags;
|
||||
for (const auto& [_, conn] : connections) {
|
||||
flags |= conn->transport();
|
||||
}
|
||||
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
|
||||
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
|
||||
std::unordered_map<int, size_t> memory_ids;
|
||||
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
|
||||
for (const auto& [peer, memory] : registered_memories) {
|
||||
if (peer == rank_) continue;
|
||||
memory_ids[peer] = proxyService->addMemory(memory);
|
||||
}
|
||||
for (const auto& [peer, semaphore] : semaphores) {
|
||||
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
|
||||
}
|
||||
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
|
||||
}
|
||||
}
|
||||
|
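// Create one semaphore of the requested type per connection; comm_->setup() completes the handshake.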
||||
template <typename SemaphoreType>
|
||||
void make_semaphores(
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
|
||||
semaphores.clear();
|
||||
for (const auto& [peer, conn] : connections) {
|
||||
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
|
||||
}
|
||||
comm_->setup();
|
||||
}
|
||||
|
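// Register the buffer for every transport in use and exchange RegisteredMemory handles with all connected peers.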
||||
void register_tensor_with_connections(
|
||||
void* tensor_ptr,
|
||||
size_t tensor_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
|
||||
registered_memories.clear();
|
||||
mscclpp::TransportFlags all_transports;
|
||||
for (const auto& [_, connection] : connections) {
|
||||
all_transports |= connection->transport();
|
||||
}
|
||||
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
|
||||
registered_memories[rank_] = buf_reg_mem;
|
||||
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
|
||||
for (const auto& [r, connection] : connections) {
|
||||
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
|
||||
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
|
||||
remote_mem_futures.emplace(r, remoteMemory);
|
||||
}
|
||||
comm_->setup();
|
||||
for (auto& [r, mem_feature] : remote_mem_futures) {
|
||||
registered_memories.emplace(r, mem_feature.get());
|
||||
}
|
||||
}
|
||||
|
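// Rebuild MemoryChannels so they point at a new input pointer (reusing the existing semaphores and scratch
// registrations) and asynchronously copy the refreshed device handles into the given GPU buffer.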
||||
void make_device_memory_handle_base_on_new_ptr(
|
||||
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
|
||||
void* input,
|
||||
void* scratch,
|
||||
const cudaStream_t stream) {
|
||||
memory_channels.clear();
|
||||
for (const auto& [peer, channel] : old_memory_channels) {
|
||||
memory_channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < world_size_; r++) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
memory_channels_list.push_back(memory_channels[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
|
||||
device_memory_handle.data(),
|
||||
memory_channel_handlers.data(),
|
||||
memory_channel_handlers.size(),
|
||||
stream,
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
};
|
||||
|
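// One-shot LL allreduce context for a single node: ranks exchange data through CudaIpc memory channels backed by
// a shared scratch buffer.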
||||
class Msccl1NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
public:
|
||||
Msccl1NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl1NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
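// Outside graph capture the shared d_memHandles_ buffer is refreshed in place for the current input pointer.
// During capture that host-side refresh cannot be replayed, so handles are built once per input pointer and cached.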
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
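// One-shot LL allreduce context spanning two nodes: intra-node traffic uses memory channels over the scratch
// buffer, inter-node traffic uses proxy-backed port channels and a separate put buffer.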
||||
class Msccl2NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
void* put_buffer_;
|
||||
const size_t put_buffer_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
|
||||
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
|
||||
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
|
||||
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
|
||||
public:
|
||||
Msccl2NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
void* put_buffer,
|
||||
const size_t put_buffer_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
put_buffer_(put_buffer),
|
||||
put_buffer_bytes_(put_buffer_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1),
|
||||
d_portHandles_(world_size - nranks_per_node) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
proxyService = std::make_shared<mscclpp::ProxyService>();
|
||||
proxyService->startProxy();
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
comm_group_->make_port_channels_with_scratch(
|
||||
proxyService,
|
||||
put_buffer_,
|
||||
put_buffer_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
cross_node_connections_,
|
||||
port_semaphores_,
|
||||
registered_port_memories_,
|
||||
port_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
std::vector<mscclpp::PortChannel> port_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
} else {
|
||||
port_channels_list.push_back(port_channels_[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
|
||||
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
|
||||
std::transform(
|
||||
port_channels_list.begin(),
|
||||
port_channels_list.end(),
|
||||
port_channel_handlers.begin(),
|
||||
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
|
||||
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl2NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
if (proxyService) {
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
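// Same handle-caching scheme as the 1-node context: refresh in place when not capturing, cache per input pointer
// while a CUDA graph is being captured.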
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans,
|
||||
d_portHandles_.data(),
|
||||
(T*)input,
|
||||
(T*)scratch_,
|
||||
(T*)put_buffer_,
|
||||
output,
|
||||
comm_group_->rank_,
|
||||
nranks_per_node_,
|
||||
comm_group_->world_size_,
|
||||
input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sglang
|
||||
153
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
Normal file
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* this file is used to test mscclpp_allreduce.cuh using mpirun
|
||||
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
|
||||
usage:
|
||||
cd PATH-TO-THIS-FILE
|
||||
export MPI_HOME=/usr/local/mpi
|
||||
# export MPI_HOME=/opt/hpcx/ompi/
|
||||
export MSCCLPP_HOME=/workspace/test/mscclpp
|
||||
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
|
||||
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
|
||||
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
|
||||
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
|
||||
|
||||
# use the mpirun under /opt/hpcx/ompi/bin/ when MPI comes from HPC-X
|
||||
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
|
||||
--map-by ppr:8:node \
|
||||
--mca btl_openib_warn_no_device_params_found 0 \
|
||||
--mca btl_tcp_if_include bond0 \
|
||||
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
|
||||
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
|
||||
*/
|
||||
#include <mpi.h>
|
||||
#include <thrust/detail/raw_pointer_cast.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#ifndef CHECK_CUDA_SUCCESS
|
||||
#define CHECK_CUDA_SUCCESS(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
template <typename T>
|
||||
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
|
||||
return fabs(a - b) <= (atol + rtol * fabs(b));
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
// init mpi
|
||||
MPI_Init(&argc, &argv);
|
||||
printf("MPI Initialized.\n");
|
||||
int nranks, rank;
|
||||
|
||||
// get work size and rank id
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
cudaSetDevice(rank);
|
||||
printf("nranks: %d, rank: %d\n", nranks, rank);
|
||||
|
||||
// init host and device buffers
|
||||
using T = float;
|
||||
using ReduceT = float;
|
||||
const size_t num_elems = 2 * 1024 * 1024;
|
||||
std::vector<T> host_buf(num_elems);
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
||||
host_buf[i] = T(i + rank);
|
||||
}
|
||||
thrust::device_vector<T> device_buf(host_buf);
|
||||
const size_t buf_size_in_bytes = num_elems * sizeof(T);
|
||||
std::vector<T> host_result_buf(num_elems);
|
||||
thrust::device_vector<T> device_result_buf(host_result_buf);
|
||||
|
||||
std::vector<T> host_scratch_buf(num_elems * 8);
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
||||
host_scratch_buf[i] = 1;
|
||||
}
|
||||
thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
|
||||
std::vector<T> host_put_buf(num_elems);
|
||||
thrust::device_vector<T> device_put_buf(host_put_buf);
|
||||
|
||||
mscclpp::UniqueId unique_id;
|
||||
if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
|
||||
MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
std::vector<int64_t> rank_to_node(nranks);
|
||||
std::vector<int64_t> rank_to_ib(nranks);
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
rank_to_node[i] = i / 8;
|
||||
rank_to_ib[i] = i % 8;
|
||||
}
|
||||
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
|
||||
if (nranks == 8) {
|
||||
auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
|
||||
unique_id,
|
||||
rank,
|
||||
nranks,
|
||||
thrust::raw_pointer_cast(device_scratch_buf.data()),
|
||||
buf_size_in_bytes * 8,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
context->allreduce<T>(
|
||||
s,
|
||||
thrust::raw_pointer_cast(device_buf.data()),
|
||||
thrust::raw_pointer_cast(device_result_buf.data()),
|
||||
device_buf.size());
|
||||
} else if (nranks == 16) {
|
||||
// TODO: this branch is untested since there is something wrong with mpirun on my test machine
|
||||
auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
|
||||
unique_id,
|
||||
rank,
|
||||
nranks,
|
||||
thrust::raw_pointer_cast(device_scratch_buf.data()),
|
||||
buf_size_in_bytes * 8,
|
||||
thrust::raw_pointer_cast(device_put_buf.data()),
|
||||
buf_size_in_bytes,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
context->allreduce<T>(
|
||||
s,
|
||||
thrust::raw_pointer_cast(device_buf.data()),
|
||||
thrust::raw_pointer_cast(device_result_buf.data()),
|
||||
device_buf.size());
|
||||
}
|
||||
|
||||
// check result correctness
|
||||
thrust::host_vector<T> host_buf_result = device_result_buf;
|
||||
size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
|
||||
bool nan_detected = false;
|
||||
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
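// rank r contributes (i + r), so the allreduced value is i * nranks + nranks * (nranks - 1) / 2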
||||
T expected = T(i * nranks + (nranks - 1) * nranks / 2);
|
||||
if (std::isnan(float(host_buf_result[i]))) {
|
||||
nan_detected = true;
|
||||
}
|
||||
if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
|
||||
num_results_error_atol_1e_3_rtol_1e_3++;
|
||||
}
|
||||
}
|
||||
float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
|
||||
|
||||
printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
|
||||
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
|
||||
m.def(
|
||||
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
|
||||
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
|
||||
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
|
||||
|
||||
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
|
||||
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
|
||||
@@ -74,6 +74,18 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
torch::Tensor mscclpp_generate_unique_id();
|
||||
fptr_t mscclpp_init_context(
|
||||
const torch::Tensor& unique_id,
|
||||
const int64_t rank,
|
||||
const int64_t world_size,
|
||||
torch::Tensor& scratch,
|
||||
torch::Tensor& put_buffer,
|
||||
const int64_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib,
|
||||
const int64_t context_selection);
|
||||
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
|
||||
#endif
|
||||
|
||||
/*
|
||||
|
||||
@@ -49,6 +49,27 @@ if torch.version.hip is not None:
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
|
||||
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: bytes,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
else:
|
||||
|
||||
def init_custom_ar(
|
||||
@@ -85,3 +106,36 @@ else:
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernel.meta_size.default()
|
||||
|
||||
def mscclpp_generate_unique_id() -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: torch.Tensor,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.mscclpp_init_context.default(
|
||||
unique_id,
|
||||
rank,
|
||||
world_size,
|
||||
scratch,
|
||||
put_buffer,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib,
|
||||
context_selection,
|
||||
)
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.mscclpp_allreduce.default(
|
||||
context, inp, out, nthreads, nblocks
|
||||
)
|
||||
|
||||
146
sgl-kernel/tests/test_mscclpp.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import socket
|
||||
import unittest
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import sgl_kernel.allreduce as custom_ops
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class MscclContextSelection(IntEnum):
|
||||
MSCCL1SHOT1NODELL = 1
|
||||
MSCCL1SHOT2NODELL = 2
|
||||
|
||||
|
||||
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
|
||||
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
group = dist.group.WORLD
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
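# rank 0 generates the mscclpp unique id and broadcasts it over the gloo group so every rank
# bootstraps the same communicator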
||||
if rank == 0:
|
||||
unique_id = [custom_ops.mscclpp_generate_unique_id()]
|
||||
else:
|
||||
unique_id = [None]
|
||||
dist.broadcast_object_list(
|
||||
unique_id, src=0, device=torch.device("cpu"), group=cpu_group
|
||||
)
|
||||
unique_id = unique_id[0]
|
||||
rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
|
||||
for r in range(world_size):
|
||||
rank_to_node[r] = r // 8
|
||||
rank_to_ib[r] = r % 8
|
||||
MAX_BYTES = 2**20
|
||||
scratch = torch.empty(
|
||||
MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
|
||||
)
|
||||
put_buffer = torch.empty(
|
||||
MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
|
||||
)
|
||||
print(f"[{rank}] start mscclpp_context init")
|
||||
nranks_per_node = torch.cuda.device_count()
|
||||
selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
|
||||
mscclpp_context = custom_ops.mscclpp_init_context(
|
||||
unique_id,
|
||||
rank,
|
||||
world_size,
|
||||
scratch,
|
||||
put_buffer,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib,
|
||||
selection,
|
||||
)
|
||||
try:
|
||||
test_loop = 10
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
if sz * dtype.itemsize > MAX_BYTES:
|
||||
continue
|
||||
if rank == 0:
|
||||
print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
|
||||
for _ in range(test_loop):
|
||||
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
|
||||
inp1_ref = inp1.clone()
|
||||
out1 = torch.empty_like(inp1)
|
||||
custom_ops.mscclpp_allreduce(
|
||||
mscclpp_context, inp1, out1, nthreads=512, nblocks=21
|
||||
)
|
||||
dist.all_reduce(inp1_ref, group=group)
|
||||
torch.testing.assert_close(out1, inp1_ref)
|
||||
finally:
|
||||
dist.barrier(group=group)
|
||||
dist.destroy_process_group(group=group)
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("::1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
world_size: int, test_target: Any, target_args: tuple = ()
|
||||
) -> None:
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
procs = []
|
||||
distributed_init_port = get_open_port()
|
||||
for i in range(world_size):
|
||||
proc_args = (world_size, i, distributed_init_port) + target_args
|
||||
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
|
||||
for i in range(world_size):
|
||||
procs[i].join()
|
||||
assert (
|
||||
procs[i].exitcode == 0
|
||||
), f"Process {i} failed with exit code {procs[i].exitcode}"
|
||||
|
||||
|
||||
class TestMSCCLAllReduce(unittest.TestCase):
|
||||
test_sizes = [
|
||||
512,
|
||||
2560,
|
||||
4096,
|
||||
5120,
|
||||
7680,
|
||||
32768,
|
||||
262144,
|
||||
524288,
|
||||
]
|
||||
world_sizes = [8]
|
||||
|
||||
def test_correctness(self):
|
||||
for world_size in self.world_sizes:
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if world_size > available_gpus:
|
||||
print(
|
||||
f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
|
||||
)
|
||||
continue
|
||||
|
||||
print(f"Running test for world_size={world_size}")
|
||||
multi_process_parallel(
|
||||
world_size, _run_correctness_worker, target_args=(self.test_sizes,)
|
||||
)
|
||||
print(f"custom allreduce tp = {world_size}: OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
205
test/srt/test_mscclpp.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""For Now, MSCCL is only supported on TP16 and TP8 case
|
||||
|
||||
if [[ $RANK -eq 0 ]]; then
|
||||
ray start --block --head --port=6379 &
|
||||
python3 test_mscclpp.py;
|
||||
else
|
||||
ray start --block --address=${MASTER_ADDR}:6379;
|
||||
fi
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import unittest
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from sglang.srt.distributed import init_distributed_environment
|
||||
from sglang.srt.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
|
||||
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
graph_capture,
|
||||
initialize_model_parallel,
|
||||
set_custom_all_reduce,
|
||||
set_mscclpp_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.utils import StatelessProcessGroup
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
world_size: int,
|
||||
master_addr: str,
|
||||
cls: Any,
|
||||
test_target: Any,
|
||||
) -> None:
|
||||
|
||||
# Using ray helps debugging the error when it failed
|
||||
# as compared to multiprocessing.
|
||||
# NOTE: We need to set working_dir for distributed tests,
|
||||
# otherwise we may get import errors on ray workers
|
||||
|
||||
ray.init(log_to_driver=True)
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
for rank in range(world_size):
|
||||
refs.append(
|
||||
test_target.remote(
|
||||
cls, world_size, master_addr, rank, distributed_init_port
|
||||
)
|
||||
)
|
||||
ray.get(refs)
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
class TestMSCCLAllReduce(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
random.seed(42)
|
||||
# 1KB to 1MB
|
||||
cls.test_sizes = [512, 4096, 32768, 262144, 524288]
|
||||
cls.world_sizes = [8]
|
||||
TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0"))
|
||||
if TEST_TP16:
|
||||
cls.world_sizes = [16]
|
||||
cls.test_loop = 10
|
||||
|
||||
def test_graph_allreduce(self):
|
||||
TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
|
||||
for world_size in self.world_sizes:
|
||||
if world_size not in [8, 16]:
|
||||
continue
|
||||
multi_process_parallel(
|
||||
world_size, TEST_MASTER_ADDR, self, self.graph_allreduce
|
||||
)
|
||||
|
||||
def test_eager_allreduce(self):
|
||||
TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
|
||||
for world_size in self.world_sizes:
|
||||
if world_size not in [8, 16]:
|
||||
continue
|
||||
multi_process_parallel(
|
||||
world_size, TEST_MASTER_ADDR, self, self.eager_allreduce
|
||||
)
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
|
||||
set_mscclpp_all_reduce(True)
|
||||
set_custom_all_reduce(False)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=rank % torch.cuda.device_count(),
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
# (e.g. NCCL). This will ensure that the communicator is initialized
|
||||
# before any communication happens, so that this group can be used for
|
||||
# graph capture immediately.
|
||||
data = torch.zeros(1)
|
||||
data = data.to(device=device)
|
||||
torch.distributed.all_reduce(data, group=group)
|
||||
torch.cuda.synchronize()
|
||||
del data
|
||||
|
||||
for sz in self.test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(self.test_loop):
|
||||
with graph_capture() as graph_capture_context:
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(
|
||||
1,
|
||||
16,
|
||||
(sz,),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
inp2 = torch.randint(
|
||||
1,
|
||||
16,
|
||||
(sz,),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(
|
||||
graph, stream=graph_capture_context.stream
|
||||
):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
# the input buffer is immediately modified to test
|
||||
# synchronization
|
||||
dist.all_reduce(inp1, group=group)
|
||||
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||
dist.all_reduce(inp2, group=group)
|
||||
graph.replay()
|
||||
torch.testing.assert_close(out1, inp1)
|
||||
torch.testing.assert_close(out2, inp2)
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
|
||||
set_mscclpp_all_reduce(True)
|
||||
set_custom_all_reduce(False)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=rank,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
for sz in self.test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(self.test_loop):
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
dist.all_reduce(inp1, group=group)
|
||||
torch.testing.assert_close(out1, inp1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|