support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

This commit is contained in:
zyksir
2025-06-05 13:11:24 +08:00
committed by GitHub
parent 4474eaf552
commit 8e3797be1c
20 changed files with 2177 additions and 12 deletions

View File

@@ -73,6 +73,14 @@ FetchContent_Declare(
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)
# mscclpp
FetchContent_Declare(
repo-mscclpp
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-mscclpp)
# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
@@ -99,6 +107,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
)
set(SGL_KERNEL_CUDA_FLAGS
@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
set(SOURCES
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/allreduce/custom_all_reduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
OUTPUT_VARIABLE TORCH_CXX11_ABI
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_CXX11_ABI STREQUAL "0")
message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory(${repo-mscclpp_SOURCE_DIR})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_BACKWARD

View File

@@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules
@git submodule update --init --recursive
ln: submodule ## Create compilation database
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
install: submodule ## Install package in development mode
@pip install -e . --no-build-isolation
build: install-deps submodule ## Build and install wheel package
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
clean: ## Remove build artifacts
@rm -rf build dist *.egg-info

View File

@@ -50,6 +50,9 @@ docker run --rm \
which cmake
cmake --version
yum install numactl-devel -y && \
yum install libibverbs -y && \
ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \

View File

@@ -0,0 +1,140 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>
#include "mscclpp_allreduce.cuh"
enum MscclContextSelection {
MSCCL1NODELL = 1,
MSCCL2NODELL = 2,
};
class MscclContext {
public:
MscclContextSelection selection_;
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
MscclContext(MscclContextSelection selection) : selection_(selection) {}
template <typename T>
void allreduce(
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
if (selection_ == MSCCL1NODELL) {
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
} else if (selection_ == MSCCL2NODELL) {
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
}
}
};
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
return tensor;
}
// Function to convert vector of int32_t back to array of uint8_t
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
mscclpp::UniqueId unique_id;
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
return unique_id;
}
torch::Tensor mscclpp_generate_unique_id() {
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
return _unique_id2tensor(unique_id);
}
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection) {
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
if (context_selection == MSCCL1NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
} else if (context_selection == MSCCL2NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
uid,
rank,
world_size,
scratch_ptr,
scratch_bytes,
put_buffer_ptr,
put_buffer_bytes,
nranks_per_node,
rank_to_node,
rank_to_ib);
} else {
throw std::runtime_error("invalid context selection");
}
return (fptr_t)context_ptr;
}
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
context->allreduce<float>(
stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
case at::ScalarType::Half: {
context->allreduce<half>(
stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
context->allreduce<__nv_bfloat16>(
stream,
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}

View File

@@ -0,0 +1,779 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
// comment this for test_mscclpp_allreduce.cu
#include "utils.h"
namespace sglang {
__device__ mscclpp::DeviceSyncer deviceSyncer;
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
union {
From f;
To t;
} u;
u.f = src;
return u.t;
}
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
return a + b;
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
return __hadd2(a, b);
}
#endif
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
template <typename T>
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
return ret;
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}
// -------------------------------------------------------
// allreduce_LL_1node using LLPacket, origin allreduce2
// -------------------------------------------------------
__device__ uint64_t globalFlag = 1;
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
mscclpp::MemoryChannelDeviceHandle* memChans,
TYPE* buff,
TYPE* scratch,
void* resultBuff,
int rank,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
const int nPeers = worldSize - 1;
const size_t nPkts = nelems / 2;
const int nelemsPerRank = nelems / worldSize;
const int nPktsPerRank = nelemsPerRank / 2;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
// step 1: write to scratch buffer
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
dst[idx] = data;
mscclpp::LLPacket packet;
packet.data1 = data.x;
packet.flag1 = flag;
packet.data2 = data.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
memChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx].x = data.x;
result[idx].y = data.y;
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
// -------------------------------------------------------
// allreduce_LL_2node using LLPacket, origin allreduce5
// -------------------------------------------------------
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
mscclpp::MemoryChannelDeviceHandle* memChans,
mscclpp::PortChannelDeviceHandle* portChans,
TYPE* buff,
TYPE* scratch,
TYPE* putBuff,
TYPE* resultBuff,
int rank,
int nRanksPerNode,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
const int nPeersInNode = nRanksPerNode - 1;
const int nPkts = nelems / 2;
const int nelemsPerLocalRank = nelems / nRanksPerNode;
const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
const int localRankId = rank % nRanksPerNode;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeersInNode;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));
// step 1: write to scratch buffer
if (nRanksPerNode > 1) {
memChan.putPackets(
scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
}
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeersInNode; index++) {
const int remoteRank = index < localRankId ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
putPkt[idx].write(data.x, data.y, flag);
dst[idx] = data;
}
deviceSyncer.sync(gridDim.x);
// step 3. send local reduced data to remote node.
if (threadIdx.x == 0 && blockIdx.x == 0) {
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
if ((flag & 63) == 0) {
portChan.flush();
}
}
// step 4. try to read the data from scratch buffer and write to local peers
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 res = dst[idx];
uint2 val = dstPkt[idx].read(flag);
res = add_vectors<TYPE>(res, val);
mscclpp::LLPacket packet;
packet.data1 = res.x;
packet.flag1 = flag;
packet.data2 = res.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
for (int index = 0; index < nPeersInNode; index++) {
memChans[index].write(offset, packet);
}
dst[idx] = res;
}
// step 5: get data result from scratch buffer
dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
if (nRanksPerNode > 1) {
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx] = data;
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
static const mscclpp::Transport IBs[] = {
mscclpp::Transport::IB0,
mscclpp::Transport::IB1,
mscclpp::Transport::IB2,
mscclpp::Transport::IB3,
mscclpp::Transport::IB4,
mscclpp::Transport::IB5,
mscclpp::Transport::IB6,
mscclpp::Transport::IB7};
class MscclCommGroup {
public:
std::shared_ptr<mscclpp::Communicator> comm_;
const size_t rank_;
const size_t world_size_;
const std::vector<int64_t> rank_to_node_;
const std::vector<int64_t> rank_to_ib_;
MscclCommGroup(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(unique_id);
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
}
template <typename T>
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
throw std::runtime_error("you should not call allreduce of a base context");
}
bool is_same_node(int r1, int r2) {
return rank_to_node_[r1] == rank_to_node_[r2];
}
void make_connection(
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
same_node_connections.clear();
cross_node_connections.clear();
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
}
comm_->setup();
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
same_node_connections.emplace(r, conn_futures[r].get());
} else {
cross_node_connections.emplace(r, conn_futures[r].get());
}
}
}
void make_memory_channels_with_scratch(
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
for (const auto& [peer, _] : connections) {
channels.emplace(
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
}
}
void make_port_channels_with_scratch(
std::shared_ptr<mscclpp::ProxyService> proxyService,
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::PortChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
mscclpp::TransportFlags flags;
for (const auto& [_, conn] : connections) {
flags |= conn->transport();
}
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
std::unordered_map<int, size_t> memory_ids;
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
for (const auto& [peer, memory] : registered_memories) {
if (peer == rank_) continue;
memory_ids[peer] = proxyService->addMemory(memory);
}
for (const auto& [peer, semaphore] : semaphores) {
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
}
for (const auto& [peer, _] : connections) {
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
}
}
template <typename SemaphoreType>
void make_semaphores(
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
semaphores.clear();
for (const auto& [peer, conn] : connections) {
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
}
comm_->setup();
}
void register_tensor_with_connections(
void* tensor_ptr,
size_t tensor_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
registered_memories.clear();
mscclpp::TransportFlags all_transports;
for (const auto& [_, connection] : connections) {
all_transports |= connection->transport();
}
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
registered_memories[rank_] = buf_reg_mem;
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
for (const auto& [r, connection] : connections) {
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
remote_mem_futures.emplace(r, remoteMemory);
}
comm_->setup();
for (auto& [r, mem_feature] : remote_mem_futures) {
registered_memories.emplace(r, mem_feature.get());
}
}
void make_device_memory_handle_base_on_new_ptr(
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
void* input,
void* scratch,
const cudaStream_t stream) {
memory_channels.clear();
for (const auto& [peer, channel] : old_memory_channels) {
memory_channels.emplace(
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
}
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < world_size_; r++) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
memory_channels_list.push_back(memory_channels[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
device_memory_handle.data(),
memory_channel_handlers.data(),
memory_channel_handlers.size(),
stream,
cudaMemcpyHostToDevice);
}
};
class Msccl1NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
public:
Msccl1NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
memory_channels_list.push_back(memory_channels_[r]);
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl1NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
}
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
class Msccl2NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
void* put_buffer_;
const size_t put_buffer_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
std::shared_ptr<mscclpp::ProxyService> proxyService;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
public:
Msccl2NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
void* put_buffer,
const size_t put_buffer_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
put_buffer_(put_buffer),
put_buffer_bytes_(put_buffer_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1),
d_portHandles_(world_size - nranks_per_node) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
proxyService = std::make_shared<mscclpp::ProxyService>();
proxyService->startProxy();
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
comm_group_->make_port_channels_with_scratch(
proxyService,
put_buffer_,
put_buffer_bytes_,
scratch_,
scratch_bytes_,
cross_node_connections_,
port_semaphores_,
registered_port_memories_,
port_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
std::vector<mscclpp::PortChannel> port_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
memory_channels_list.push_back(memory_channels_[r]);
} else {
port_channels_list.push_back(port_channels_[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
std::transform(
port_channels_list.begin(),
port_channels_list.end(),
port_channel_handlers.begin(),
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl2NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
if (proxyService) {
proxyService->stopProxy();
}
}
template <typename T>
void
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
memChans,
d_portHandles_.data(),
(T*)input,
(T*)scratch_,
(T*)put_buffer_,
output,
comm_group_->rank_,
nranks_per_node_,
comm_group_->world_size_,
input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
} // namespace sglang

View File

@@ -0,0 +1,153 @@
/*
* this file is used to test mscclpp_allreduce.cu using mpirun
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
usage:
cd PATH-TO-THIS-FILE
export MPI_HOME=/usr/local/mpi
# export MPI_HOME=/opt/hpcx/ompi/
export MSCCLPP_HOME=/workspace/test/mscclpp
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
/opt/hpcx/ompi/bin/
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
--map-by ppr:8:node \
--mca btl_openib_warn_no_device_params_found 0 \
--mca btl_tcp_if_include bond0 \
--allow-run-as-root -np 8 \
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
*/
#include <mpi.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif
#include <cstdint>
#include "mscclpp_allreduce.cuh"
template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
return fabs(a - b) <= (atol + rtol * fabs(b));
}
int main(int argc, char* argv[]) {
// init mpi
MPI_Init(&argc, &argv);
printf("MPI Initialized.\n");
int nranks, rank;
// get work size and rank id
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
cudaSetDevice(rank);
printf("nranks: %d, rank: %d\n", nranks, rank);
// init host and device buffers
using T = float;
using ReduceT = float;
const size_t num_elems = 2 * 1024 * 1024;
std::vector<T> host_buf(num_elems);
for (uint32_t i = 0; i < num_elems; ++i) {
host_buf[i] = T(i + rank);
}
thrust::device_vector<T> device_buf(host_buf);
const size_t buf_size_in_bytes = num_elems * sizeof(T);
std::vector<T> host_result_buf(num_elems);
thrust::device_vector<T> device_result_buf(host_result_buf);
std::vector<T> host_scratch_buf(num_elems * 8);
for (uint32_t i = 0; i < num_elems; ++i) {
host_scratch_buf[i] = 1;
}
thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
std::vector<T> host_put_buf(num_elems);
thrust::device_vector<T> device_put_buf(host_put_buf);
mscclpp::UniqueId unique_id;
if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
std::vector<int64_t> rank_to_node(nranks);
std::vector<int64_t> rank_to_ib(nranks);
for (int i = 0; i < nranks; i++) {
rank_to_node[i] = i / 8;
rank_to_ib[i] = i % 8;
}
cudaStream_t s;
CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
if (nranks == 8) {
auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
} else if (nranks == 16) {
// TODO: this branch is untested since there is something wrong with mpirun in my test machince
auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
thrust::raw_pointer_cast(device_put_buf.data()),
buf_size_in_bytes,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
}
// check result correctness
thrust::host_vector<T> host_buf_result = device_result_buf;
size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
bool nan_detected = false;
for (uint32_t i = 0; i < num_elems; ++i) {
T expected = T(i * nranks + (nranks - 1) * nranks / 2);
if (std::isnan(float(host_buf_result[i]))) {
nan_detected = true;
}
if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
num_results_error_atol_1e_3_rtol_1e_3++;
}
}
float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
MPI_Finalize();
return 0;
}

View File

@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
m.def(
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/

View File

@@ -74,6 +74,18 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection);
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif
/*

View File

@@ -49,6 +49,27 @@ if torch.version.hip is not None:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
def mscclpp_generate_unique_id() -> bytes:
raise NotImplementedError()
def mscclpp_init_context(
unique_id: bytes,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
raise NotImplementedError()
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
raise NotImplementedError()
else:
def init_custom_ar(
@@ -85,3 +106,36 @@ else:
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def mscclpp_generate_unique_id() -> torch.Tensor:
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
def mscclpp_init_context(
unique_id: torch.Tensor,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
return torch.ops.sgl_kernel.mscclpp_init_context.default(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
context_selection,
)
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
torch.ops.sgl_kernel.mscclpp_allreduce.default(
context, inp, out, nthreads, nblocks
)

View File

@@ -0,0 +1,146 @@
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
class MscclContextSelection(IntEnum):
MSCCL1SHOT1NODELL = 1
MSCCL1SHOT2NODELL = 2
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = dist.group.WORLD
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
if rank == 0:
unique_id = [custom_ops.mscclpp_generate_unique_id()]
else:
unique_id = [None]
dist.broadcast_object_list(
unique_id, src=0, device=torch.device("cpu"), group=cpu_group
)
unique_id = unique_id[0]
rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
for r in range(world_size):
rank_to_node[r] = r // 8
rank_to_ib[r] = rank % 8
MAX_BYTES = 2**20
scratch = torch.empty(
MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
)
put_buffer = torch.empty(
MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
)
print(f"[{rank}] start mscclpp_context init")
nranks_per_node = torch.cuda.device_count()
selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
mscclpp_context = custom_ops.mscclpp_init_context(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
selection,
)
try:
test_loop = 10
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if sz * dtype.itemsize > MAX_BYTES:
continue
if rank == 0:
print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
for _ in range(test_loop):
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
inp1_ref = inp1.clone()
out1 = torch.empty_like(inp1)
custom_ops.mscclpp_allreduce(
mscclpp_context, inp1, out1, nthreads=512, nblocks=21
)
dist.all_reduce(inp1_ref, group=group)
torch.testing.assert_close(out1, inp1_ref)
finally:
dist.barrier(group=group)
dist.destroy_process_group(group=group)
def get_open_port() -> int:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("::1", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
mp.set_start_method("spawn", force=True)
procs = []
distributed_init_port = get_open_port()
for i in range(world_size):
proc_args = (world_size, i, distributed_init_port) + target_args
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
proc.start()
procs.append(proc)
for i in range(world_size):
procs[i].join()
assert (
procs[i].exitcode == 0
), f"Process {i} failed with exit code {procs[i].exitcode}"
class TestMSCCLAllReduce(unittest.TestCase):
test_sizes = [
512,
2560,
4096,
5120,
7680,
32768,
262144,
524288,
]
world_sizes = [8]
def test_correctness(self):
for world_size in self.world_sizes:
available_gpus = torch.cuda.device_count()
if world_size > available_gpus:
print(
f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
)
continue
print(f"Running test for world_size={world_size}")
multi_process_parallel(
world_size, _run_correctness_worker, target_args=(self.test_sizes,)
)
print(f"custom allreduce tp = {world_size}: OK")
if __name__ == "__main__":
unittest.main()