From 6cb3974e77524ee2b291919ca6a8b55547bca8e0 Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Thu, 16 Jan 2025 03:04:25 +0800
Subject: [PATCH] optimize custom allreduce kernel (#2904)

---
 sgl-kernel/pyproject.toml                      |   2 +-
 sgl-kernel/setup.py                            |   2 +-
 sgl-kernel/src/sgl-kernel/__init__.py          |   4 +
 .../src/sgl-kernel/csrc/sgl_kernel_ops.cu      |  10 +-
 .../sgl-kernel/csrc/trt_reduce_internal.cu     | 139 +++++++++++-------
 .../sgl-kernel/csrc/trt_reduce_internal.cuh    |   8 +-
 .../src/sgl-kernel/csrc/trt_reduce_kernel.cu   | 117 ++++++++++++++-
 sgl-kernel/src/sgl-kernel/ops/__init__.py      |  20 ++-
 sgl-kernel/tests/test_trt_reduce.py            |  22 +--
 9 files changed, 244 insertions(+), 80 deletions(-)

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index b03b4c02b..6a6a0d1fe 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post12"
+version = "0.0.2.post13"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 83025d6d6..2d2d9258a 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -40,7 +40,7 @@ nvcc_flags = [
     "-U__CUDA_NO_HALF2_OPERATORS__",
 ]
 cxx_flags = ["-O3"]
-libraries = ["c10", "torch", "torch_python"]
+libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
 ext_modules = [
     CUDAExtension(
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 62c366731..0c744982d 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,9 +1,11 @@
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
     moe_align_block_size,
+    register_graph_buffers,
     sampling_scaling_penalties,
 )
 
@@ -14,4 +16,6 @@ __all__ = [
     "custom_reduce",
     "int8_scaled_mm",
     "sampling_scaling_penalties",
+    "get_graph_buffer_ipc_meta",
+    "register_graph_buffers",
 ]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index fbfe51442..b9879b114 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -2,10 +2,14 @@
 
 // trt_reduce
 using fptr_t = int64_t;
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
 void dispose(fptr_t _fa);
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
 
 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
@@ -25,6 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
   m.def("dispose", &dispose, "dispose custom allreduce meta");
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
m.def("register_graph_buffers", ®ister_graph_buffers, "custom all reduce register graph buffers"); // moe_align_block_size m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); // sampling_scaling_penalties diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index a6f2d5216..006c3200d 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const __syncthreads(); } +template __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx, int const grid_size, - bool start = true, bool need_fence = false) { - if (!start) { + size_t const world_size, int const tidx, int const bidx, int const grid_size) { + if constexpr (!start) { __syncthreads(); } // After this function, the block of id == bidx of each GPU has reached the barrier @@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag // Block broadcast its flag (local_rank on emitting dimension) to all receivers uint32_t flag_block_offset = world_size + bidx * world_size; - if (flag % 2 == 1) { - flag_block_offset += (grid_size + 1) * world_size; - } + flag_block_offset += (grid_size + 1) * world_size * (flag % 2); - if (need_fence) { - st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); - } else { - st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); - } - // Blocks check that corresponding blocks on other GPUs have also set the flag uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; - - if (need_fence) { + // Blocks check that corresponding blocks on other GPUs have also set the flag + if constexpr (need_fence) { + st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); while (ld_flag_acquire(peer_barrier_d) != flag) { } } else { + st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); while (ld_flag_volatile(peer_barrier_d) != flag) { } } @@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag __syncthreads(); } -template /* COPY_INPUT = false, PUSH_MODE = false */ +template static __global__ void oneShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start four blocks. // The message is partitioned into chunks as detailed below: @@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { int const bidx = blockIdx.x; int const tidx = threadIdx.x; + int const grid_size = gridDim.x; // The number of elements packed into one for comms static constexpr int NUM_ELTS = 16 / sizeof(T); @@ -201,18 +196,23 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { using PackedStruct = typename PackedOn16Bytes::Type; // The source pointers. Distributed round-robin for the different warps. 
-  T const* buffers[RANKS_PER_NODE];
-
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   // Start and end offsets of the thread
   size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
   size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
-#pragma unroll
-  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
-    buffers[ii] = reinterpret_cast<T const*>(params.peer_comm_buffer_ptrs[rank]);
-  }
-  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+  if constexpr (COPY_INPUT) {
+    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+    // Copy from local buffer to shareable buffer
+    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
+          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
+    }
+  }
+  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
@@ -220,7 +220,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
     PackedStruct vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -229,8 +229,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
     // Store to the destination buffer.
@@ -238,7 +237,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   }
 }
 
-template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
   static constexpr int PACKED_ELTS = 16 / sizeof(T);
   using PackedType = typename PackedOn16Bytes<T>::Type;
 
-  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
 
   size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
   size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
 
   T* buffers[RANKS_PER_NODE];
+  T* buffers_unorder[RANKS_PER_NODE];
   int ranks[RANKS_PER_NODE];
 #pragma unroll
   for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
     // A mapping of the ranks to scatter reads as much as possible
     int rank = (params.local_rank + ii) % RANKS_PER_NODE;
     ranks[ii] = rank;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
+    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
   }
 
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -308,8 +311,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #endif
 #endif
 
-  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+  if constexpr (COPY_INPUT) {
+    // Copy all blocks from local buffer to shareable buffer
+    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+        if (offset_rank >= params.elts_total) {
+          continue;
+        }
+        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
+      }
+    }
+  }
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     PackedType vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
-    // Store to the local buffer.
-    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    // Store to the local buffer or tmp buffer
+    if constexpr (COPY_INPUT) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    } else {
+      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
+    }
   }
 
-  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size, false, true);
+  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
+                             bidx, grid_size);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       if (offset_rank >= params.elts_total) {
         continue;
       }
-
-      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      if constexpr (COPY_INPUT) {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&buffers[ii][offset_rank]);
+      } else {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
+      }
     }
   }
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
-      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
     case AllReduceStrategyType::TWOSHOT: {
-      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
     }
   }
 }
 
-template <typename T>
-void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
-  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
-  void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECK_CUDA_SUCCESS(
-      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-
-  CHECK_CUDA_SUCCESS(cudaGetLastError());
-
+template <typename T, bool COPY_INPUT>
+void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
   switch (param.ranks_per_node) {
     case 2:
-      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 4:
-      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
-      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
-      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     default:
       break;
   }
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  if (param.is_capturing) {
+    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
+  } else {
+    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
+  }
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
index 1c7c714dc..9d6f9722e 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -36,6 +36,10 @@ enum class AllReduceStrategyType : int8_t {
   AUTO = 3,
 };
 
+struct RankData {
+  void* ptrs[MAX_RANKS_PER_NODE];
+};
+
 struct AllReduceParams {
   size_t elts_size;
   size_t elts_total;
@@ -46,9 +50,11 @@ struct AllReduceParams {
   uint32_t barrier_flag;
   uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
   uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
-  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
+  RankData* peer_comm_buffer_ptrs;
   void* local_input_buffer_ptr;
   void* local_output_buffer_ptr;
+  bool is_capturing;
 };
 
 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
index 59b548c77..d80beedec 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -12,25 +12,46 @@
 using namespace trt_llm;
 
 using fptr_t = int64_t;
+using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
 
 class AllReduceMeta {
  public:
-  AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
+  AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                const std::vector<fptr_t>& barrier_out) {
     this->rank_id = (int)rank_id;
     this->world_size = (int)world_size;
-    this->buffers = buffers;
     this->barrier_in = barrier_in;
     this->barrier_out = barrier_out;
+    this->tmp_result_buffers = tmp_result_buffers;
+
+    this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
+    RankData data;
+    for (int i = 0; i < world_size; i++) {
+      data.ptrs[i] = (void*)buffers[i];
+    }
+    auto d_data = this->rank_data_base++;
+    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
+    this->buffers = d_data;
+  }
+
+  ~AllReduceMeta() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
+    }
   }
 
  public:
   int world_size;
   int rank_id;
-  std::vector<fptr_t> buffers;
   std::vector<fptr_t> barrier_in;
   std::vector<fptr_t> barrier_out;
+  std::vector<fptr_t> tmp_result_buffers;
   int barrier_flag = 1;
+  RankData* buffers;
+  RankData* rank_data_base;
+  std::vector<void*> graph_unreg_buffers;
+  std::map<IPC_KEY, char*> ipc_handles_;
 };
 
 // Get the number of bits for a given data type.
@@ -52,9 +73,10 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
   return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
 }
 
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
-  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out) {
+  auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
   return (fptr_t)m;
 }
 
@@ -63,6 +85,75 @@ void dispose(fptr_t _fa) {
   delete fa;
 }
 
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  auto num_buffers = m->graph_unreg_buffers.size();
+  auto handle_sz = sizeof(cudaIpcMemHandle_t);
+  std::string handles(handle_sz * num_buffers, static_cast<char>(0));
+  std::vector<int64_t> offsets(num_buffers);
+  for (int i = 0; i < num_buffers; i++) {
+    auto ptr = m->graph_unreg_buffers[i];
+    void* base_ptr;
+    // note: must share the base address of each allocation, or we get wrong
+    // address
+    if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
+      assert(false && "failed to get pointer attr");
+    }
+
+    CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
+    offsets[i] = ((char*)ptr) - ((char*)base_ptr);
+  }
+  std::vector<int64_t> bytes(handles.begin(), handles.end());
+  return std::make_pair(bytes, offsets);
+}
+
+char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
+  auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
+  if (new_handle) {
+    char* ipc_ptr;
+    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
+                                            cudaIpcMemLazyEnablePeerAccess));
+    it->second = ipc_ptr;
+  }
+  return it->second;
+}
+
+// Note: when registering graph buffers, we intentionally choose to not
+// deduplicate the addresses. That means if the allocator reuses some
+// addresses, they will be registered again. This is to account for the remote
+// possibility of different allocation patterns between ranks. For example,
+// rank 1 may get the same input address for the second allreduce, but rank 2
+// got a different address. IPC handles have internal reference counting
+// mechanism so overhead should be small.
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  std::vector<std::string> handle_bytes;
+  handle_bytes.reserve(handles.size());
+  for (int i = 0; i < handles.size(); i++) {
+    handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
+  }
+  auto num_buffers = m->graph_unreg_buffers.size();
+  std::vector<RankData> rank_data(num_buffers);
+  for (int i = 0; i < num_buffers; i++) {
+    auto self_ptr = m->graph_unreg_buffers[i];
+    auto& rd = rank_data[i];
+    for (int j = 0; j < m->world_size; j++) {
+      if (j != m->rank_id) {
+        char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
+        handle += offsets[j][i];
+        rd.ptrs[j] = handle;
+      } else {
+        rd.ptrs[j] = self_ptr;
+      }
+    }
+  }
+  CHECK_CUDA_SUCCESS(
+      cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
+  m->rank_data_base += num_buffers;
+  m->graph_unreg_buffers.clear();
+}
+
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
   auto stream = c10::cuda::getCurrentCUDAStream().stream();
@@ -87,8 +178,18 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   params.elts_size = inp.element_size();
   params.barrier_flag = ++(m->barrier_flag);
 
+  cudaStreamCaptureStatus status;
+  CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
+  params.is_capturing = (status == cudaStreamCaptureStatusActive);
+  if (params.is_capturing) {
+    params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
+    m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
+  } else {
+    params.peer_comm_buffer_ptrs = m->buffers;
+  }
+
   for (int i = 0; i < world_size; ++i) {
-    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
+    params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
   }
   for (int i = 0; i < world_size; ++i) {
     params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 03a8db80f..6b35f78a4 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -1,15 +1,23 @@
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import (
+    get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
+)
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
 
 
-def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
-    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
+def init_custom_reduce(
+    rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+):
+    return _init_custom_ar(
+        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+    )
 
 
 def custom_dispose(fa):
@@ -20,6 +28,14 @@ def custom_reduce(fa, inp, out):
     _all_reduce(fa, inp, out)
 
 
+def get_graph_buffer_ipc_meta(fa):
+    return _get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(fa, handles, offsets):
+    _register_graph_buffers(fa, handles, offsets)
+
+
 def moe_align_block_size(
     topk_ids,
     num_experts,
diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py
index a5ce1b41d..b79580070 100644
--- a/sgl-kernel/tests/test_trt_reduce.py
+++ b/sgl-kernel/tests/test_trt_reduce.py
@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import ray
 import torch
 import torch.distributed as dist
+from sgl_kernel import ops as custom_ops
 from torch.distributed import ProcessGroup
 from vllm import _custom_ops as vllm_ops
 
@@ -104,35 +105,38 @@ class TestCustomAllReduce(unittest.TestCase):
         multi_process_parallel(world_size, self, self.performance)
 
     def init_custom_allreduce(self, rank, world_size, group):
-        import sgl_kernel
-
         buffer_max_size = 8 * 1024 * 1024
         barrier_max_size = 8 * (24 + 2) * 8
 
         self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+            buffer_max_size, group=group
+        )
         self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
+        self.rank_data = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
+        )
 
-        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
+        self.custom_ptr = custom_ops.init_custom_reduce(
             rank,
             world_size,
+            self.rank_data,
             self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
             self.barrier_in_ptrs,
             self.barrier_out_ptrs,
         )
 
     def custom_allreduce(self, inp, out):
-        import sgl_kernel
-
-        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
+        custom_ops.custom_reduce(self.custom_ptr, inp, out)
 
     def free_custom_allreduce(self, group):
-        import sgl_kernel
-
         self.free_shared_buffer(self.buffer_ptrs, group)
+        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
         self.free_shared_buffer(self.barrier_in_ptrs, group)
         self.free_shared_buffer(self.barrier_out_ptrs, group)
-        sgl_kernel.ops.custom_dispose(self.custom_ptr)
+        custom_ops.custom_dispose(self.custom_ptr)
 
     def init_vllm_allreduce(self, rank, group):
         self.vllm_rank = rank
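
Usage sketch (not part of the patch): the flow below shows how the reworked API fits together per rank, including the CUDA-graph path this patch enables. Only the sgl_kernel.ops entry points come from the diff above; the demo_allreduce wrapper, the pre-created shared-buffer/barrier pointer lists (as produced by create_shared_buffer in the test), and the cross-rank handle exchange via dist.all_gather_object are assumptions for illustration.

import torch
import torch.distributed as dist
from sgl_kernel import ops as custom_ops


def demo_allreduce(rank, world_size, buffer_ptrs, tmp_ptrs, barrier_in, barrier_out):
    # rank_data backs the device-side RankData slots that init_custom_ar and
    # register_graph_buffers cudaMemcpy into; 8 MiB mirrors the test above.
    rank_data = torch.empty(
        8 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{rank}"
    )
    fa = custom_ops.init_custom_reduce(
        rank, world_size, rank_data, buffer_ptrs, tmp_ptrs, barrier_in, barrier_out
    )

    inp = torch.ones(16384, dtype=torch.float16, device=f"cuda:{rank}")
    out = torch.empty_like(inp)

    # Eager path: COPY_INPUT=true, so the kernel itself stages inp into the
    # shared buffer, replacing the old cudaMemcpyAsync staging step.
    custom_ops.custom_reduce(fa, inp, out)

    # Capture path: all_reduce sees cudaStreamIsCapturing() and records the
    # input pointer in graph_unreg_buffers instead of copying (COPY_INPUT=false).
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        custom_ops.custom_reduce(fa, inp, out)

    # After capture, exchange IPC handles of the captured input buffers across
    # ranks and register them so each rank's RankData table points at its peers.
    handle, offset = custom_ops.get_graph_buffer_ipc_meta(fa)
    all_handles = [None] * world_size  # gathered via the default group (assumption)
    all_offsets = [None] * world_size
    dist.all_gather_object(all_handles, handle)
    dist.all_gather_object(all_offsets, offset)
    custom_ops.register_graph_buffers(fa, all_handles, all_offsets)

    g.replay()  # replays the allreduce with no extra input copy
    custom_ops.custom_dispose(fa)

The ordering is the point: register_graph_buffers must run after capture and before the first g.replay(), since capture only reserves RankData slots and records input addresses, and the reserved slots are filled with peer pointers only at registration time.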