adapt tensorrt llm custom all reduce to sgl-kernel (#2481)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -1,47 +1,75 @@
|
||||
cmake_minimum_required(VERSION 3.18)
|
||||
project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# Basic settings
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
|
||||
find_package(PythonInterp 3 REQUIRED)
|
||||
find_package(PythonLibs 3 REQUIRED)
|
||||
# Set CUDA architectures
|
||||
set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
|
||||
message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
|
||||
# Find PyTorch
|
||||
execute_process(
|
||||
COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
|
||||
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
|
||||
OUTPUT_VARIABLE TORCH_CMAKE_PATH
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}")
|
||||
message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")
|
||||
|
||||
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
|
||||
# Warp Reduce library
|
||||
add_library(warp_reduce SHARED
|
||||
src/sgl-kernel/csrc/warp_reduce.cc
|
||||
src/sgl-kernel/csrc/warp_reduce_kernel.cu
|
||||
)
|
||||
|
||||
target_include_directories(warp_reduce PRIVATE
|
||||
${CUDA_INCLUDE_DIRS}
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
target_include_directories(warp_reduce
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
|
||||
${CUDA_INCLUDE_DIRS}
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
target_link_libraries(warp_reduce PRIVATE
|
||||
${TORCH_LIBRARIES}
|
||||
${PYTHON_LIBRARIES}
|
||||
target_link_libraries(warp_reduce
|
||||
PRIVATE
|
||||
${TORCH_LIBRARIES}
|
||||
Python3::Python
|
||||
)
|
||||
|
||||
set_target_properties(warp_reduce PROPERTIES
|
||||
CUDA_SEPARABLE_COMPILATION ON
|
||||
# TRT Reduce library
|
||||
add_library(trt_reduce SHARED
|
||||
src/sgl-kernel/csrc/trt_reduce.cc
|
||||
src/sgl-kernel/csrc/trt_reduce_internal.cu
|
||||
src/sgl-kernel/csrc/trt_reduce_kernel.cu
|
||||
)
|
||||
|
||||
target_include_directories(trt_reduce
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
|
||||
${CUDA_INCLUDE_DIRS}
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
target_link_libraries(trt_reduce
|
||||
PRIVATE
|
||||
${TORCH_LIBRARIES}
|
||||
Python3::Python
|
||||
)
|
||||
|
||||
# Set common properties for both libraries
|
||||
foreach(target warp_reduce trt_reduce)
|
||||
set_target_properties(${target} PROPERTIES
|
||||
CUDA_SEPARABLE_COMPILATION ON
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CUDA_RESOLVE_DEVICE_SYMBOLS ON
|
||||
PREFIX ""
|
||||
SUFFIX ".so"
|
||||
)
|
||||
endforeach()
|
||||
|
||||
@@ -10,7 +10,7 @@ install:
|
||||
@pip install -e .
|
||||
|
||||
build:
|
||||
@python3 setup.py bdist_wheel
|
||||
@export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel
|
||||
|
||||
clean:
|
||||
@rm -rf build dist *.egg-info
|
||||
@@ -19,4 +19,4 @@ test:
|
||||
@pytest tests/
|
||||
|
||||
format:
|
||||
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
||||
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.0.2.post4"
|
||||
version = "0.0.2.post5"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
|
||||
@@ -84,7 +84,31 @@ setup(
|
||||
},
|
||||
libraries=["c10", "torch", "torch_python"],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
|
||||
)
|
||||
),
|
||||
CUDAExtension(
|
||||
"sgl_kernel.ops.custom_reduce_cuda",
|
||||
[
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce.cc",
|
||||
],
|
||||
extra_compile_args={
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
"-gencode=arch=compute_75,code=sm_75",
|
||||
"-gencode=arch=compute_80,code=sm_80",
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
],
|
||||
"cxx": ["-O3"],
|
||||
},
|
||||
libraries=["c10", "torch", "torch_python"],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
|
||||
),
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
install_requires=["torch"],
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from .ops import warp_reduce
|
||||
from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce
|
||||
|
||||
__all__ = ["warp_reduce"]
|
||||
__all__ = [
|
||||
"warp_reduce",
|
||||
"init_custom_reduce",
|
||||
"custom_dispose",
|
||||
"custom_reduce",
|
||||
]
|
||||
|
||||
13
sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
Normal file
13
sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
Normal file
@@ -0,0 +1,13 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
|
||||
void dispose(fptr_t _fa);
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
|
||||
// Python bindings for the custom all-reduce extension. Each entry exposes one of
// the free functions declared earlier in this file under the same name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Returns an integer handle (fptr_t) to the all-reduce metadata.
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");

  // Takes a handle returned by init_custom_ar.
  m.def("dispose", &dispose, "dispose custom allreduce meta");

  // Takes a handle plus an input and an output tensor.
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
|
||||
282
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
Normal file
282
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
Normal file
@@ -0,0 +1,282 @@
|
||||
// reference:
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <tuple>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Writes the 32-bit `flag` to global memory at `flag_addr` using the PTX
// `st.global.release.sys` instruction (release semantics at system scope).
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.global.release.sys.b32 [%1], %0;"
               : /* no outputs */
               : "r"(flag), "l"(flag_addr));
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Reads a 32-bit value from global memory at `flag_addr` using the PTX
// `ld.global.acquire.sys` instruction (acquire semantics at system scope) and
// returns it.
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
  uint32_t observed;
  asm volatile("ld.global.acquire.sys.b32 %0, [%1];"
               : "=r"(observed)
               : "l"(flag_addr));
  return observed;
}
|
||||
|
||||
namespace trt_llm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Type Converter that packs data format to 128 bits data type
|
||||
//
|
||||
using PackedFloat = union {
|
||||
int4 packed;
|
||||
float unpacked[4];
|
||||
};
|
||||
|
||||
using PackedHalf = union {
|
||||
int4 packed;
|
||||
half2 unpacked[4];
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PackedOn16Bytes {};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<float> {
|
||||
using Type = PackedFloat;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<half> {
|
||||
using Type = PackedHalf;
|
||||
};
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
using PackedBFloat16 = union {
|
||||
int4 packed;
|
||||
__nv_bfloat162 unpacked[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<__nv_bfloat16> {
|
||||
using Type = PackedBFloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
// add two 128b data
|
||||
template <typename T>
|
||||
inline __device__ int4 add128b(T& a, T& b) {
|
||||
T c;
|
||||
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
|
||||
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
|
||||
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
|
||||
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
|
||||
return c.packed;
|
||||
}
|
||||
|
||||
// Cross-GPU barrier built on per-rank flag arrays.
//
// `signals[r]` is rank r's flag array. It is used as two banks of `world_size`
// slots; the bank is chosen by the parity of `flag`. Within a bank, slot `i` of
// rank r's array is written by rank `i`.
//
// Once this returns, block 0 of every participating GPU has published `flag`
// for the current call.
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
                                             size_t const world_size, int const tidx, int const bidx) {
  // Only the first `world_size` threads of each block take part; thread t deals with peer t.
  if (tidx < world_size) {
    // Odd flags use the second bank, even flags the first.
    size_t const bank = (flag & 1u) ? world_size : 0;

    // Block 0 alone announces this rank's arrival, writing into every peer's array.
    if (bidx == 0) {
      st_flag_release(flag, signals[tidx] + bank + local_rank);
    }

    // Every block spins on this rank's own array until peer `tidx` has announced
    // itself. The original authors note this cannot deadlock because block 0 is
    // always the first block started.
    uint32_t* const watched_slot = signals[local_rank] + bank + tidx;
    while (ld_flag_acquire(watched_slot) != flag) {
      // busy-wait
    }
  }

  // Release the remaining threads of the block once the polling threads are done.
  __syncthreads();
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
|
||||
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
// message
|
||||
// |-------------------|
|
||||
// GPU 0 | B0 | B1 | B2 | B3 |
|
||||
// GPU 1 | B0 | B1 | B2 | B3 |
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
|
||||
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
|
||||
// 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output
|
||||
//
|
||||
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
|
||||
// We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
|
||||
//
|
||||
// With PUSH_MODE, we consider that the shared buffer is of size:
|
||||
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 push the chunk is it responsible for into all other GPUs:
|
||||
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
|
||||
// 2. block sync so the block is shared by other GPUs
|
||||
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
|
||||
|
||||
int const bidx = blockIdx.x;
|
||||
int const tidx = threadIdx.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = 16 / sizeof(T);
|
||||
|
||||
// Packed data type for comms
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
T const* buffers[RANKS_PER_NODE];
|
||||
|
||||
// Start and end offsets of the thread
|
||||
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
}
|
||||
|
||||
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedStruct vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
PackedStruct sums;
|
||||
sums.packed = {0, 0, 0, 0};
|
||||
#pragma unroll
|
||||
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
|
||||
// Always reduce from rank 0 to ensure stable reduce order.
|
||||
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
|
||||
sums.packed = add128b(sums, vals[ii]);
|
||||
}
|
||||
|
||||
// Store to the destination buffer.
|
||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Integer division of `a` by `b`, rounded up (computed as (a + b - 1) / b).
inline int divUp(int a, int b) {
  int const padded = a + b - 1;
  return padded / b;
}

// Smallest multiple of `n` that is >= `a`, computed via divUp.
inline int roundUp(int a, int n) {
  int const groups = divUp(a, n);
  return groups * n;
}
|
||||
|
||||
// Computes the launch geometry for the one-shot all-reduce kernel.
//
// Parameters:
//   algo            - requested strategy; only ONESHOT is supported here.
//   params          - reads `elts_total`; writes `elts_per_block` and
//                     `elts_per_rank` as a side effect.
//   elts_per_thread - number of elements one thread handles per 16-byte pack.
//
// Returns (blocks_per_grid, threads_per_block).
//
// Throws (via TORCH_CHECK) when the strategy is not ONESHOT, when there is
// nothing to reduce, or when `elts_total` is not a multiple of
// `elts_per_thread`. These conditions used to be `assert`s, which vanish when
// NDEBUG is defined and then left `params` partially unset or divided by zero.
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
  TORCH_CHECK(algo == AllReduceStrategyType::ONESHOT, "Algorithm not supported here.");
  TORCH_CHECK(elts_per_thread > 0, "kernelLaunchConfig: elts_per_thread must be positive");
  TORCH_CHECK(params.elts_total > 0, "kernelLaunchConfig: elts_total must be positive");
  TORCH_CHECK(params.elts_total % elts_per_thread == 0, "kernelLaunchConfig: elts_total (", params.elts_total,
              ") must be a multiple of elts_per_thread (", elts_per_thread, ")");

  // One thread per pack, rounded up to whole warps.
  size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);

  // Cap the block size, then cap the grid size.
  int const threads_per_block = static_cast<int>(std::min(DEFAULT_BLOCK_SIZE, total_threads));
  int const blocks_per_grid =
      std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));

  // Each block owns a chunk that is a whole number of packs.
  params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
  params.elts_per_rank = params.elts_total;

  return std::make_tuple(blocks_per_grid, threads_per_block);
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
|
||||
cudaStream_t stream) {
|
||||
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
|
||||
void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
|
||||
void* local_inp_buffer = param.local_input_buffer_ptr;
|
||||
CHECK_CUDA_SUCCESS(
|
||||
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
|
||||
CHECK_CUDA_SUCCESS(cudaGetLastError());
|
||||
|
||||
size_t elts_per_thread = 16 / sizeof(T);
|
||||
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
|
||||
switch (param.ranks_per_node) {
|
||||
case 2:
|
||||
dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 4:
|
||||
dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 6:
|
||||
dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 8:
|
||||
dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
CHECK_CUDA_SUCCESS(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Entry point of the custom all-reduce: dispatches on the tensor element type
// and forwards to invokeOneOrTwoShotAllReduceKernel.
//
// Parameters:
//   params    - all-reduce arguments; returns immediately when
//               `params.elts_total` is zero.
//   data_type - element type; Float, Half and (where compiled in) BFloat16 are
//               handled.
//   strat     - strategy forwarded to the launcher.
//   stream    - CUDA stream forwarded to the launcher.
//
// Throws (via TORCH_CHECK) for any other element type. This used to be an
// `assert`, which turns the call into a silent no-op when NDEBUG is defined.
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
                        cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }

  switch (data_type) {
    case at::ScalarType::Float:
      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
      break;
    case at::ScalarType::Half:
      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
      break;
#endif
    default:
      TORCH_CHECK(false, "Unsupported data type");
  }
}
|
||||
} // namespace trt_llm
|
||||
91
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
Normal file
91
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
Normal file
@@ -0,0 +1,91 @@
|
||||
// reference:
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace trt_llm {
|
||||
// Granularity to which the launch code rounds the total thread count.
constexpr size_t WARP_SIZE = 32;

// Upper bound on the number of blocks launched for one all-reduce.
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;

// Capacity of the per-rank pointer arrays in AllReduceParams.
constexpr size_t MAX_RANKS_PER_NODE = 8;

// Upper bound on the number of threads per block.
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
|
||||
|
||||
// All-reduce strategies. The numeric values are fixed explicitly.
enum class AllReduceStrategyType : int8_t {
  RING = 0,     // returned by SelectImplementation for messages above the workspace limit
  ONESHOT = 1,  // the only strategy the kernels in this port implement
  TWOSHOT = 2,  // returned by SelectImplementation for larger messages on more than 2 ranks
  AUTO = 3,
};
|
||||
|
||||
struct AllReduceParams {
|
||||
size_t elts_size;
|
||||
size_t elts_total;
|
||||
size_t elts_per_rank;
|
||||
size_t elts_per_block;
|
||||
size_t rank_offset;
|
||||
size_t ranks_per_node, rank, local_rank;
|
||||
uint32_t barrier_flag;
|
||||
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
|
||||
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
|
||||
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
|
||||
void* local_input_buffer_ptr;
|
||||
void* local_output_buffer_ptr;
|
||||
};
|
||||
|
||||
// Largest message size (in bytes) the custom all-reduce accepts:
// 16,000,000 for up to two ranks and 8,000,000 otherwise.
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  return (world_size <= 2) ? 16 * 1000 * 1000 : 8 * 1000 * 1000;
}
|
||||
|
||||
// Picks the all-reduce strategy for a message of `message_size` bytes across
// `world_size` ranks.
//
//   message_size > workspace limit       -> RING    (asserts first)
//   world_size <= 2                      -> ONESHOT
//   world_size <= 4, size < 1,000,000    -> ONESHOT
//   world_size  > 4, size <   500,000    -> ONESHOT
//   anything else                        -> TWOSHOT (asserts first)
//
// RING and TWOSHOT are not implemented in this port, hence the assertions;
// when assertions are compiled out the value is returned to the caller.
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  if (message_size > GetMaxRequiredWorkspaceSize(world_size)) {
    assert(false && "Custom allreduce do not ring currently");
    return AllReduceStrategyType::RING;
  }

  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }

  // Above two ranks, one-shot is only chosen below a size cutoff.
  size_t const oneshot_cutoff = (world_size <= 4) ? 1 * 1000 * 1000 : 500 * 1000;
  if (message_size < oneshot_cutoff) {
    return AllReduceStrategyType::ONESHOT;
  }

  assert(false && "Custom allreduce do not twoshot currently");
  return AllReduceStrategyType::TWOSHOT;
}
|
||||
|
||||
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace trt_llm
|
||||
102
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
Normal file
102
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
Normal file
@@ -0,0 +1,102 @@
|
||||
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
|
||||
using namespace trt_llm;
|
||||
|
||||
using fptr_t = int64_t;
|
||||
|
||||
class AllReduceMeta {
|
||||
public:
|
||||
AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
|
||||
this->rank_id = (int)rank_id;
|
||||
this->world_size = (int)world_size;
|
||||
this->buffers = buffers;
|
||||
this->barrier_in = barrier_in;
|
||||
this->barrier_out = barrier_out;
|
||||
}
|
||||
|
||||
public:
|
||||
int world_size;
|
||||
int rank_id;
|
||||
std::vector<fptr_t> buffers;
|
||||
std::vector<fptr_t> barrier_in;
|
||||
std::vector<fptr_t> barrier_out;
|
||||
int barrier_flag = 1;
|
||||
};
|
||||
|
||||
// Get the number of bits for a given data type.
|
||||
// Returns the width in bits of one element of `dtype`:
// 32 for Float, 16 for Half and BFloat16.
//
// Throws (via TORCH_CHECK) for any other type. Previously the default branch
// only asserted, so with NDEBUG defined control fell off the end of a non-void
// function, which is undefined behavior.
inline int get_bits(at::ScalarType dtype) {
  switch (dtype) {
    case at::ScalarType::Float:
      return 32;
    case at::ScalarType::Half:
    case at::ScalarType::BFloat16:
      return 16;
    default:
      break;
  }
  TORCH_CHECK(false, "Unsupported data type: ", c10::toString(dtype));
  return 0;  // not reached; keeps every path returning a value
}
|
||||
|
||||
// Check if customized all-reduce kernels can be applied.
|
||||
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
|
||||
// The customized all-reduce kernel has the following requirement(s).
|
||||
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
|
||||
}
|
||||
|
||||
// Creates the all-reduce metadata for this process and returns it as an opaque
// integer handle. The handle must eventually be passed to dispose().
//
// Parameters:
//   rank_id     - this process's rank, in [0, world_size).
//   world_size  - number of participating ranks, in [1, MAX_RANKS_PER_NODE].
//   buffers     - device addresses of each rank's communication buffer.
//   barrier_in  - device addresses of each rank's "in" barrier flags.
//   barrier_out - device addresses of each rank's "out" barrier flags.
//
// Throws (via TORCH_CHECK) when the arguments are inconsistent. all_reduce
// indexes fixed-size arrays and these vectors with [0, world_size), so
// unchecked values would read or write out of bounds.
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
  TORCH_CHECK(world_size > 0 && world_size <= static_cast<int64_t>(MAX_RANKS_PER_NODE),
              "init_custom_ar: world_size must be in [1, ", MAX_RANKS_PER_NODE, "], got ", world_size);
  TORCH_CHECK(rank_id >= 0 && rank_id < world_size, "init_custom_ar: rank_id ", rank_id,
              " is out of range for world_size ", world_size);

  size_t const ranks = static_cast<size_t>(world_size);
  TORCH_CHECK(buffers.size() >= ranks, "init_custom_ar: expected ", ranks, " buffers, got ", buffers.size());
  TORCH_CHECK(barrier_in.size() >= ranks, "init_custom_ar: expected ", ranks, " barrier_in pointers, got ",
              barrier_in.size());
  TORCH_CHECK(barrier_out.size() >= ranks, "init_custom_ar: expected ", ranks, " barrier_out pointers, got ",
              barrier_out.size());

  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
  return (fptr_t)m;
}
|
||||
|
||||
// Destroys the metadata object behind a handle returned by init_custom_ar.
// The handle must not be used afterwards.
void dispose(fptr_t _fa) {
  delete reinterpret_cast<AllReduceMeta*>(_fa);
}
|
||||
|
||||
// Runs the custom one-shot all-reduce on the current CUDA stream.
//
// Parameters:
//   _fa - handle returned by init_custom_ar.
//   inp - local input tensor.
//   out - tensor receiving the reduced result.
//
// Throws (via TORCH_CHECK) when a precondition is not met. These conditions
// were previously `assert`s or unchecked, so with NDEBUG defined a bad call
// could launch the kernel on mismatched or oversized data.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  TORCH_CHECK(m != nullptr, "all_reduce: handle is null");

  // The loops below index fixed-size arrays and the metadata vectors with [0, world_size).
  int const world_size = m->world_size;
  TORCH_CHECK(world_size > 0 && static_cast<size_t>(world_size) <= MAX_RANKS_PER_NODE,
              "all_reduce: world_size must be in [1, ", MAX_RANKS_PER_NODE, "], got ", world_size);
  TORCH_CHECK(m->rank_id >= 0 && m->rank_id < world_size, "all_reduce: rank_id ", m->rank_id,
              " is out of range for world_size ", world_size);
  size_t const ranks = static_cast<size_t>(world_size);
  TORCH_CHECK(m->buffers.size() >= ranks && m->barrier_in.size() >= ranks && m->barrier_out.size() >= ranks,
              "all_reduce: metadata holds fewer than ", ranks, " pointers per list");

  // The data is copied and reduced as flat memory, so both tensors must be
  // contiguous CUDA tensors of the same type and length.
  TORCH_CHECK(inp.is_cuda() && out.is_cuda(), "all_reduce: inp and out must be CUDA tensors");
  TORCH_CHECK(inp.is_contiguous() && out.is_contiguous(), "all_reduce: inp and out must be contiguous");
  TORCH_CHECK(inp.scalar_type() == out.scalar_type(), "all_reduce: inp and out must have the same dtype");
  TORCH_CHECK(inp.numel() == out.numel(), "all_reduce: inp and out must have the same number of elements");

  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  auto num_elements = inp.numel();
  auto dtype = inp.scalar_type();
  AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), world_size);

  // Callers are expected to guarantee these; enforce them here as well.
  TORCH_CHECK(strategy == AllReduceStrategyType::ONESHOT,
              "all_reduce: message size is outside the range handled by the one-shot kernel");
  TORCH_CHECK(CanApplyCustomAllReduce(num_elements, dtype),
              "all_reduce: number of elements must be a multiple of 16 bytes worth of elements");

  // Initialize the all-reduce kernel arguments (zero-initialized first so no
  // field is left indeterminate).
  AllReduceParams params{};
  params.ranks_per_node = world_size;
  params.rank = m->rank_id;
  params.local_rank = m->rank_id;
  params.local_input_buffer_ptr = inp.data_ptr();
  params.local_output_buffer_ptr = out.data_ptr();
  params.elts_total = inp.numel();
  params.elts_size = inp.element_size();
  // Every call uses a fresh flag value.
  params.barrier_flag = ++(m->barrier_flag);

  for (int i = 0; i < world_size; ++i) {
    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
  }

  auto data_type = out.scalar_type();
  trtCustomAllReduce(params, data_type, strategy, stream);
}
|
||||
36
sgl-kernel/src/sgl-kernel/csrc/utils.hpp
Normal file
36
sgl-kernel/src/sgl-kernel/csrc/utils.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
/**
 * @brief Exception type thrown by CHECK_CUDA_SUCCESS when a CUDA call fails.
 *
 * Behaves exactly like `std::runtime_error`; it exists so callers can catch
 * CUDA failures specifically.
 */
struct cuda_error : public std::runtime_error {
  /**
   * @brief Builds the error from a C string.
   *
   * @param message Null-terminated text returned later by `what()`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}

  /**
   * @brief Builds the error from a `std::string`.
   *
   * @param message Text returned later by `what()`
   */
  cuda_error(std::string const& message) : std::runtime_error(message.c_str()) {}
};
|
||||
|
||||
// Evaluates the CUDA runtime call `cmd` exactly once and throws `cuda_error`
// when it does not return cudaSuccess. The message is the CUDA error string,
// a newline, then the file and line of the macro use.
//
// The argument is parenthesized and the locals carry a trailing-underscore
// suffix so they cannot shadow identifiers that appear inside `cmd`.
#define CHECK_CUDA_SUCCESS(cmd)                                                               \
  do {                                                                                        \
    cudaError_t const cuda_status_ = (cmd);                                                   \
    if (cuda_status_ != cudaSuccess) {                                                        \
      std::stringstream cuda_message_;                                                        \
      cuda_message_ << cudaGetErrorString(cuda_status_) << "\n" << __FILE__ << ':' << __LINE__; \
      throw cuda_error(cuda_message_.str());                                                  \
    }                                                                                         \
  } while (0)
|
||||
|
||||
// Fails (TORCH_CHECK) unless tensor `x` lives on a CUDA device.
#define CHECK_IS_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

// Fails (TORCH_CHECK) unless tensor `x` is contiguous.
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Runs both checks above. Wrapped in do/while(0) so the macro is a single
// statement: under an unbraced `if`, the two-statement form guarded only the
// first check.
#define CHECK_CUDA_INPUT(x) \
  do {                      \
    CHECK_IS_CUDA(x);       \
    CHECK_IS_CONTIGUOUS(x); \
  } while (0)
|
||||
@@ -1,15 +1,11 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
|
||||
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor warp_reduce(torch::Tensor input) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_CUDA_INPUT(input);
|
||||
return warp_reduce_cuda(input);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,20 @@
|
||||
from .custom_reduce_cuda import all_reduce as _all_reduce
|
||||
from .custom_reduce_cuda import dispose as _dispose
|
||||
from .custom_reduce_cuda import init_custom_ar as _init_custom_ar
|
||||
from .warp_reduce_cuda import reduce as _reduce
|
||||
|
||||
|
||||
def warp_reduce(input_tensor):
    """Forward ``input_tensor`` to the compiled ``reduce`` op and return its result."""
    reduced = _reduce(input_tensor)
    return reduced
|
||||
|
||||
|
||||
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
    """Create the custom all-reduce state and return its handle.

    All arguments are forwarded unchanged to the compiled ``init_custom_ar``;
    ``num_devices`` is passed as its world size.
    """
    handle = _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
    return handle
|
||||
|
||||
|
||||
def custom_dispose(fa):
    """Release the state behind handle ``fa`` via the compiled ``dispose``; returns None."""
    _dispose(fa)
    return None
|
||||
|
||||
|
||||
def custom_reduce(fa, inp, out):
    """Call the compiled ``all_reduce`` with handle ``fa``, input ``inp`` and output ``out``.

    Returns None.
    """
    _all_reduce(fa, inp, out)
    return None
|
||||
|
||||
248
sgl-kernel/tests/test_trt_reduce.py
Normal file
248
sgl-kernel/tests/test_trt_reduce.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import time
|
||||
import unittest
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_open_port() -> int:
    """Return a TCP port number that was free on the loopback interface.

    The port is found by binding to port 0 and reading back the number the OS
    assigned. IPv4 loopback is tried first; if that raises ``OSError`` the
    IPv6 loopback address is used instead.

    The socket is closed before returning, so another process may take the
    port before the caller uses it.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6: an AF_INET6 socket needs an IPv6 address, so bind to "::1"
        # rather than the IPv4 literal used above.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    """Run ``test_target`` once per rank as Ray tasks and wait for all of them.

    Args:
        world_size: number of ranks to launch.
        cls: object passed as the first argument of every task.
        test_target: Ray remote callable invoked as
            ``test_target.remote(cls, world_size, rank, port)``.

    Any exception raised by a task propagates out of ``ray.get``. Ray is shut
    down in a ``finally`` block so a failing task does not leave it
    initialized for the next call.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(cls, world_size, rank, distributed_init_port)
            for rank in range(world_size)
        ]
        ray.get(refs)
    finally:
        ray.shutdown()
|
||||
|
||||
|
||||
class TestCustomAllReduce(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
random.seed(42)
|
||||
cls.test_sizes = {
|
||||
2: [512, 4096, 32768, 262144, 2097152],
|
||||
4: [512, 4096, 32768, 131072],
|
||||
6: [512, 4096, 32768, 65536],
|
||||
8: [512, 4096, 32768, 65536],
|
||||
}
|
||||
cls.world_sizes = [2, 4, 6, 8]
|
||||
|
||||
@staticmethod
def create_shared_buffer(
    size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
    """
    Allocate ``size_in_bytes`` of device memory and share it with the group.

    Every process allocates its own buffer, exchanges the IPC handle with the
    other processes through ``all_gather_object`` and opens the handles it
    received. The returned list has one pointer value per rank: the local
    allocation at this process's rank, opened IPC mappings elsewhere.
    """
    cuda_rt = CudaRTLibrary()
    local_ptr = cuda_rt.cudaMalloc(size_in_bytes)
    local_handle = cuda_rt.cudaIpcGetMemHandle(local_ptr)

    n_ranks = dist.get_world_size(group=group)
    my_rank = dist.get_rank(group=group)
    gathered = [None] * n_ranks
    dist.all_gather_object(gathered, local_handle, group=group)

    return [
        local_ptr.value  # type: ignore
        if idx == my_rank
        else cuda_rt.cudaIpcOpenMemHandle(peer_handle).value  # type: ignore
        for idx, peer_handle in enumerate(gathered)
    ]
|
||||
|
||||
@staticmethod
def free_shared_buffer(
    pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
    """Free this process's own entry of ``pointers`` with ``cudaFree``.

    Only the pointer at this process's rank is freed.
    NOTE(review): the entries opened from peers' IPC handles are not closed
    here - confirm whether that is intentional.
    """
    my_rank = dist.get_rank(group=group)
    cuda_rt = CudaRTLibrary()
    own_pointer = ctypes.c_void_p(pointers[my_rank])
    cuda_rt.cudaFree(own_pointer)
|
||||
|
||||
def test_correctness(self):
    """Run ``self.correctness`` for every world size that fits on the visible GPUs."""
    for world_size in self.world_sizes:
        # World sizes needing more GPUs than are available are skipped.
        if world_size <= torch.cuda.device_count():
            multi_process_parallel(world_size, self, self.correctness)
|
||||
|
||||
def test_performance(self):
    """Run ``self.performance`` for every world size that fits on the visible GPUs."""
    for world_size in self.world_sizes:
        # World sizes needing more GPUs than are available are skipped.
        if world_size <= torch.cuda.device_count():
            multi_process_parallel(world_size, self, self.performance)
|
||||
|
||||
def init_custom_allreduce(self, rank, world_size, group):
|
||||
import sgl_kernel
|
||||
|
||||
buffer_max_size = 8 * 1024 * 1024
|
||||
barrier_max_size = 8 * (24 + 2) * 8
|
||||
|
||||
self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
|
||||
self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
|
||||
rank,
|
||||
world_size,
|
||||
self.buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
|
||||
def custom_allreduce(self, inp, out):
|
||||
import sgl_kernel
|
||||
|
||||
sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
|
||||
|
||||
def free_custom_allreduce(self, group):
|
||||
import sgl_kernel
|
||||
|
||||
self.free_shared_buffer(self.buffer_ptrs, group)
|
||||
self.free_shared_buffer(self.barrier_in_ptrs, group)
|
||||
self.free_shared_buffer(self.barrier_out_ptrs, group)
|
||||
sgl_kernel.ops.custom_dispose(self.custom_ptr)
|
||||
|
||||
def init_vllm_allreduce(self, rank, group):
|
||||
self.vllm_rank = rank
|
||||
self.vllm_max_size = 8 * 1024 * 1024
|
||||
self.vllm_meta_ptrs = self.create_shared_buffer(
|
||||
vllm_ops.meta_size() + self.vllm_max_size, group=group
|
||||
)
|
||||
self.vllm_buffer_ptrs = self.create_shared_buffer(
|
||||
self.vllm_max_size, group=group
|
||||
)
|
||||
self.vllm_rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
|
||||
)
|
||||
self.vllm_ptr = vllm_ops.init_custom_ar(
|
||||
self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
|
||||
)
|
||||
vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
|
||||
|
||||
def vllm_allreduce(self, inp, out):
|
||||
vllm_ops.all_reduce(
|
||||
self.vllm_ptr,
|
||||
inp,
|
||||
out,
|
||||
self.vllm_buffer_ptrs[self.vllm_rank],
|
||||
self.vllm_max_size,
|
||||
)
|
||||
|
||||
def free_vllm_allreduce(self, group):
|
||||
vllm_ops.dispose(self.vllm_ptr)
|
||||
self.free_shared_buffer(self.vllm_meta_ptrs, group)
|
||||
self.free_shared_buffer(self.vllm_buffer_ptrs, group)
|
||||
|
||||
@staticmethod
|
||||
def init_distributed_env(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
ranks = [i for i in range(world_size)]
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
return group
|
||||
|
||||
# compare result with torch.distributed
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def correctness(self, world_size, rank, distributed_init_port):
|
||||
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
||||
|
||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
||||
|
||||
test_loop = 10
|
||||
for sz in self.test_sizes[world_size]:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(test_loop):
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
out1 = torch.empty_like(inp1)
|
||||
self.custom_allreduce(inp1, out1)
|
||||
|
||||
dist.all_reduce(inp1, group=group)
|
||||
torch.testing.assert_close(out1, inp1)
|
||||
|
||||
self.free_custom_allreduce(group)
|
||||
|
||||
# compare performance with vllm
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def performance(self, world_size, rank, distributed_init_port):
|
||||
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
||||
|
||||
self.init_vllm_allreduce(rank, group)
|
||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
||||
|
||||
for sz in self.test_sizes[world_size]:
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
|
||||
)
|
||||
out1 = torch.empty_like(inp1)
|
||||
test_loop = 5000
|
||||
start = time.time()
|
||||
for _ in range(test_loop):
|
||||
self.custom_allreduce(inp1, out1)
|
||||
elapse_custom = time.time() - start
|
||||
|
||||
start = time.time()
|
||||
for _ in range(test_loop):
|
||||
self.vllm_allreduce(inp1, out1)
|
||||
elapse_vllm = time.time() - start
|
||||
|
||||
if rank == 0:
|
||||
logger.warning(
|
||||
f"test_size = {sz}, world_size = {world_size}, "
|
||||
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}us,"
|
||||
f"custom time = {elapse_custom * 1000 / test_loop:.4f}us"
|
||||
)
|
||||
|
||||
self.free_custom_allreduce(group)
|
||||
self.free_vllm_allreduce(group)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs the TestCase classes defined in this module.
    unittest.main()
|
||||
Reference in New Issue
Block a user