From 3900a94afe635bccab3975852cdfa8d4ffd8fce1 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Mon, 6 Jan 2025 00:47:16 +0800 Subject: [PATCH] Support twoshot kernel (#2688) --- sgl-kernel/pyproject.toml | 2 +- .../sgl-kernel/csrc/trt_reduce_internal.cu | 206 +++++++++++++++++- .../sgl-kernel/csrc/trt_reduce_internal.cuh | 14 +- .../src/sgl-kernel/csrc/trt_reduce_kernel.cu | 2 +- sgl-kernel/tests/test_trt_reduce.py | 13 +- 5 files changed, 216 insertions(+), 21 deletions(-) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 54582a787..359ffafd7 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post10" +version = "0.0.2.post11" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.8" diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 04393c8e7..b4d17ded1 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -41,6 +41,16 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) { return flag; } +static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) { + uint32_t flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + return flag; +} + namespace trt_llm { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,6 +126,45 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const __syncthreads(); } +__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, + size_t const world_size, int const tidx, int const bidx, int const grid_size, + bool start = true, bool need_fence = false) { + if (!start) { + __syncthreads(); + } + // After this function, the block of id == bidx of each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, 2, num_blocks, world_size] + // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier) + // Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension + + // Block broadcast its flag (local_rank on emitting dimension) to all receivers + uint32_t flag_block_offset = world_size + bidx * world_size; + + if (flag % 2 == 1) { + flag_block_offset += (grid_size + 1) * world_size; + } + + if (need_fence) { + st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); + } else { + st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); + } + // Blocks check that corresponding blocks on other GPUs have also set the flag + uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; + + if (need_fence) { + while (ld_flag_acquire(peer_barrier_d) != flag) { + } + } else { + while (ld_flag_volatile(peer_barrier_d) != flag) { + } + } + } + + __syncthreads(); +} + template /* COPY_INPUT = false, PUSH_MODE = false */ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start four blocks. 
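For reference, the index arithmetic that block_barrier() applies to each GPU's signal buffer can be traced on the host. The sketch below is illustrative only (block_flag_index is a hypothetical helper, not part of this patch): it mirrors the flag_block_offset computation, with the per-block flags living after the world_size slots reserved for multi_gpu_barrier and double-buffered by the parity of the flag value.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical host-side mirror of the offset math in block_barrier().
// Layout per GPU signal buffer:
//   [ world_size multi_gpu_barrier flags | parity-0 block flags | parity-1 block flags ]
// Each parity region holds one flag per (block, emitting rank) pair; the
// (grid_size + 1) * world_size stride also skips the leading multi_gpu_barrier slots.
size_t block_flag_index(uint32_t flag, int bidx, int emitting_rank, int world_size, int grid_size) {
  size_t offset = world_size + static_cast<size_t>(bidx) * world_size;  // skip the multi_gpu_barrier flags
  if (flag % 2 == 1) {
    offset += static_cast<size_t>(grid_size + 1) * world_size;  // flags are double-buffered by flag parity
  }
  return offset + emitting_rank;
}

int main() {
  int const world_size = 2, grid_size = 4;
  // Block 3, emitting rank 1: an even flag lands in the first region, an odd flag in the second.
  std::printf("flag=2 -> slot %zu, flag=3 -> slot %zu\n",
              block_flag_index(2, 3, 1, world_size, grid_size),
              block_flag_index(3, 3, 1, world_size, grid_size));
  return 0;
}
```

The emitting GPU writes its flag at this slot in every peer's buffer, then spins on the matching slot in its own buffer until each peer's flag arrives; with need_fence the store/load pair uses release/acquire semantics instead of volatile accesses.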
@@ -189,6 +238,124 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   }
 }
 
+template <typename T, int RANKS_PER_NODE>
+static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
+  // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
+  // The message is partitioned into chunks as detailed below:
+  //                  message
+  //        |-------------------|
+  //        |--GPU 0--|--GPU 1--|   (GPU responsibility parts)
+  //  GPU 0 | B0 | B1 | B0 | B1 |
+  //  GPU 1 | B0 | B1 | B0 | B1 |
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 copies all chunks it is responsible for, from local_input to the shareable buffer
+  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
+  // 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
+  //    part (the first half of the message, see the GPU responsibility row above)
+  // 3bis. Likewise, B0 on GPU 1 copies and sums the chunks for which GPU 1 is responsible:
+  //    the second half of the message.
+  // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
+  // 5. B0 writes the result to local_output. It gathers each chunk from the GPU responsible for it.
+  //    For example, here it reads the first chunk from GPU 0 and the second chunk from GPU 1.
+  //
+  // With COPY_INPUT == false, skip step 1 and use the multi-GPU barrier instead of the block barrier during step 2.
+  // We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready
+  // to be read.
+  //
+  // Note that compared to one-shot, one block (CTA) copies multiple input chunks and writes multiple output chunks.
+  // However, it is only responsible for the summation of a single chunk.
+  //
+  // With PUSH_MODE, we consider that the shared buffer is of size:
+  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
+  //    params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
+  // 2. block sync, so the chunks pushed by the other GPUs are ready to be read
+  // 3. Reduce along the second dimension: params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
+  // 4. block barrier (corresponding blocks have finished their reduction)
+  // 5. pull and write to the local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (the reduction
+  //    result is written at index 0 of the 2nd dim)
+
+  int const bidx = blockIdx.x;
+  int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;
+
+  // The number of elements packed into one 16-byte access for comms
+  static constexpr int PACKED_ELTS = 16 / sizeof(T);
+  using PackedType = typename PackedOn16Bytes<T>::Type;
+
+  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
+
+  size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
+  size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
+
+  T* buffers[RANKS_PER_NODE];
+  int ranks[RANKS_PER_NODE];
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    // A mapping of the ranks to scatter reads as much as possible
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    ranks[ii] = rank;
+    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+
+  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                grid_size);
+
+  // Each block accumulates the values from the different GPUs on the same node.
+  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+    size_t const responsible_block_offset = local_offset + params.rank_offset;
+
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedType vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+    }
+
+    // Sum the values from the different ranks.
+    PackedType sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
+      // Always reduce from rank 0 to ensure stable reduce order.
+      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the local buffer.
+    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+  }
+
+  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                grid_size, false, true);
+
+  // Gather all needed elts from other intra-node ranks
+  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      // use round-robin gathering from other ranks
+      size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+      if (offset_rank >= params.elts_total) {
+        continue;
+      }
+
+      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+    }
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 inline int divUp(int a, int b) {
@@ -211,6 +378,33 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
       params.elts_per_rank = params.elts_total;
       break;
     }
+    case AllReduceStrategyType::TWOSHOT: {
+      assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
+      size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
+
+      /*
+      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
+      blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
+      */
+      while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
+        blocks_per_grid += 1;
+      }
+
+      threads_per_block = total_threads / blocks_per_grid;
+
+      // NOTE: need to adjust here
+      if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
+        size_t iter_factor = 1;
+        while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
+          iter_factor += 1;
+        }
+        blocks_per_grid /= iter_factor;
+      }
+      params.elts_per_rank = params.elts_total / params.ranks_per_node;
+      params.rank_offset = params.local_rank * params.elts_per_rank;
+      params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
+      break;
+    }
     default:
       assert(false && "Algorithm not supported here.");
   }
@@ -223,7 +417,16 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 template <typename T, int RANKS_PER_NODE>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
-  oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+  switch (algo) {
+    case AllReduceStrategyType::ONESHOT: {
+      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      break;
+    }
+    case AllReduceStrategyType::TWOSHOT: {
+      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      break;
+    }
+  }
 }
 
 template <typename T>
@@ -233,7 +436,6 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
   CHECK_CUDA_SUCCESS(
       cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
 
-  assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 
   size_t elts_per_thread = 16 / sizeof(T);
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
index 01652a22a..1c7c714dc 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -25,9 +25,9 @@ namespace trt_llm {
 constexpr size_t WARP_SIZE = 32;
-constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
 constexpr size_t MAX_RANKS_PER_NODE = 8;
-constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+constexpr size_t DEFAULT_BLOCK_SIZE = 512;
 
 enum class AllReduceStrategyType : int8_t {
   RING = 0,
@@ -53,9 +53,9 @@ struct AllReduceParams {
 
 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
   if (world_size <= 2) {
-    return 16 * 1000 * 1000;
+    return 16 * 1024 * 1024;
   }
-  return 8 * 1000 * 1000;
+  return 8 * 1024 * 1024;
 }
 
 inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
@@ -71,17 +71,15 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
   }
 
   if (world_size <= 4) {
-    if (message_size < 1 * 1000 * 1000) {
+    if (message_size < 1 * 1024 * 1024) {
       return AllReduceStrategyType::ONESHOT;
     }
-    assert(false && "Custom allreduce do not twoshot currently");
     return AllReduceStrategyType::TWOSHOT;
   }
 
-  if (message_size < 500 * 1000) {
+  if (message_size < 512 * 1024) {
     return AllReduceStrategyType::ONESHOT;
   }
-  assert(false && "Custom allreduce do not twoshot currently");
   return AllReduceStrategyType::TWOSHOT;
 }
 
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
index 2a2dcebc8..59b548c77 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -71,7 +71,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
 
   // should be guaranteed in Python code
-  assert(strategy == AllReduceStrategyType::ONESHOT);
+  assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
   assert(CanApplyCustomAllReduce(num_elements, dtype));
 
   // Initialize the all-reduce kernel arguments.
diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py
index d6265118f..a5ce1b41d 100644
--- a/sgl-kernel/tests/test_trt_reduce.py
+++ b/sgl-kernel/tests/test_trt_reduce.py
@@ -55,13 +55,8 @@ class TestCustomAllReduce(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         random.seed(42)
-        cls.test_sizes = {
-            2: [512, 4096, 32768, 262144, 2097152],
-            4: [512, 4096, 32768, 131072],
-            6: [512, 4096, 32768, 65536],
-            8: [512, 4096, 32768, 65536],
-        }
-        cls.world_sizes = [2, 4, 6, 8]
+        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+        cls.world_sizes = [2, 4, 8]
 
     @staticmethod
     def create_shared_buffer(
@@ -194,7 +189,7 @@ class TestCustomAllReduce(unittest.TestCase):
         self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
 
         test_loop = 10
-        for sz in self.test_sizes[world_size]:
+        for sz in self.test_sizes:
             for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                 for _ in range(test_loop):
                     inp1 = torch.randint(
@@ -216,7 +211,7 @@ class TestCustomAllReduce(unittest.TestCase):
         self.init_vllm_allreduce(rank, group)
         self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
 
-        for sz in self.test_sizes[world_size]:
+        for sz in self.test_sizes:
             inp1 = torch.randint(
                 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
             )
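For readers new to the two-shot scheme, the data movement of twoShotAllReduceKernel can be sketched on the CPU. This is a minimal illustration under assumptions (the names two_shot_allreduce, shared and out are made up, and each rank's shared buffer is assumed to already hold its input, as the cudaMemcpyAsync in invokeOneOrTwoShotAllReduceKernel arranges): phase 1 is the reduce-scatter that the first block_barrier guards, phase 2 the all-gather that the second one guards.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// CPU-only sketch of the two-shot all-reduce data movement (not part of the patch).
void two_shot_allreduce(std::vector<std::vector<float>>& shared, std::vector<std::vector<float>>& out) {
  size_t const world_size = shared.size();
  size_t const total = shared[0].size();
  assert(total % world_size == 0);
  size_t const per_rank = total / world_size;  // plays the role of params.elts_per_rank

  // Phase 1 (reduce-scatter): rank r sums its slice [r * per_rank, (r + 1) * per_rank)
  // across all peers and keeps the result in its own shared buffer.
  for (size_t r = 0; r < world_size; ++r) {
    for (size_t i = r * per_rank; i < (r + 1) * per_rank; ++i) {
      float sum = 0.f;
      for (size_t peer = 0; peer < world_size; ++peer) sum += shared[peer][i];
      shared[r][i] = sum;
    }
  }

  // Phase 2 (all-gather): every rank copies each reduced slice from the rank that owns it.
  for (size_t r = 0; r < world_size; ++r)
    for (size_t owner = 0; owner < world_size; ++owner)
      for (size_t i = owner * per_rank; i < (owner + 1) * per_rank; ++i) out[r][i] = shared[owner][i];
}

int main() {
  size_t const world_size = 2, total = 8;
  std::vector<std::vector<float>> shared(world_size, std::vector<float>(total, 1.f));
  std::vector<std::vector<float>> out(world_size, std::vector<float>(total, 0.f));
  two_shot_allreduce(shared, out);
  std::printf("out[0][0] = %g (expected %zu)\n", out[0][0], world_size);
  return 0;
}
```

The kernel performs the same two phases in parallel: each block handles a strided chunk within its rank's slice, uses 16-byte packed loads and stores, and separates the phases with the two block_barrier calls.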
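The TWOSHOT branch of kernelLaunchConfig above can also be traced by hand. The sketch below walks one concrete case (1,048,576 fp16 elements on 8 ranks) under stated assumptions: blocks_per_grid is assumed to start at 1 and divUp/roundUp are assumed to have their usual definitions, since neither appears inside the hunk.

```cpp
#include <cstddef>
#include <cstdio>

// Assumed helpers; their bodies are not shown in the diff.
size_t divUp(size_t a, size_t b) { return (a + b - 1) / b; }
size_t roundUp(size_t a, size_t b) { return divUp(a, b) * b; }

int main() {
  size_t const WARP_SIZE = 32, DEFAULT_BLOCK_SIZE = 512, MAX_ALL_REDUCE_BLOCKS = 36;
  // 1M fp16 elements: elts_per_thread = 16 / sizeof(half) = 8, on an 8-GPU node, viewed from rank 0.
  size_t const elts_total = 1048576, elts_per_thread = 8, ranks_per_node = 8, local_rank = 0;

  size_t total_threads = roundUp(elts_total / (elts_per_thread * ranks_per_node), WARP_SIZE);  // 16384
  size_t blocks_per_grid = 1;  // assumed initial value; the declaration sits outside the shown hunk
  while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
    blocks_per_grid += 1;  // smallest divisor of total_threads that keeps blocks at or below 512 threads
  }
  size_t threads_per_block = total_threads / blocks_per_grid;
  if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
    // iter_factor fallback from the hunk; not reached for this size (32 <= 36).
  }

  size_t elts_per_rank = elts_total / ranks_per_node;
  size_t rank_offset = local_rank * elts_per_rank;
  size_t elts_per_block = roundUp(divUp(elts_per_rank, blocks_per_grid), elts_per_thread);
  std::printf("blocks=%zu threads=%zu elts_per_rank=%zu rank_offset=%zu elts_per_block=%zu\n",
              blocks_per_grid, threads_per_block, elts_per_rank, rank_offset, elts_per_block);
  // Prints: blocks=32 threads=512 elts_per_rank=131072 rank_offset=0 elts_per_block=4096
  return 0;
}
```

With DEFAULT_BLOCK_SIZE lowered to 512 and MAX_ALL_REDUCE_BLOCKS raised to 36 in this patch, this case settles on 32 blocks of 512 threads, each block striding over 4096-element chunks of its rank's 131,072-element slice.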
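Finally, the new power-of-two thresholds in SelectImplementation determine which of the test sizes exercise the new kernel. The sketch restates only the branches visible in the trt_reduce_internal.cuh hunk; the earlier checks (workspace size, world_size <= 2) are not shown in this diff and are simply assumed to resolve to ONESHOT for the illustration.

```cpp
#include <cstddef>
#include <cstdio>

enum class Strategy { ONESHOT, TWOSHOT };

// Thresholds as visible in the hunk above; the world_size <= 2 branch is an assumption.
Strategy select(size_t message_bytes, int world_size) {
  if (world_size <= 2) return Strategy::ONESHOT;  // assumed; not shown in the diff
  if (world_size <= 4) return message_bytes < 1 * 1024 * 1024 ? Strategy::ONESHOT : Strategy::TWOSHOT;
  return message_bytes < 512 * 1024 ? Strategy::ONESHOT : Strategy::TWOSHOT;
}

int main() {
  // The fp16 case from test_trt_reduce.py: message bytes = elements * 2.
  size_t const sizes[] = {512, 4096, 32768, 262144, 524288, 1048576, 2097152};
  int const world_sizes[] = {2, 4, 8};
  for (int world_size : world_sizes) {
    for (size_t n : sizes) {
      bool twoshot = select(n * 2, world_size) == Strategy::TWOSHOT;
      std::printf("world_size=%d  %8zu fp16 elts -> %s\n", world_size, n, twoshot ? "TWOSHOT" : "ONESHOT");
    }
  }
  return 0;
}
```

At world_size = 8, the 262,144-element fp16 case (512 KiB) is the first size that picks TWOSHOT; at world_size = 4 it is the 524,288-element case (1 MiB). These are exactly the paths that the relaxed assert in trt_reduce_kernel.cu now admits.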