sgl-kernel: transfer custom allreduce from trt kernel to vllm kernel (#5079)
@@ -157,8 +157,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

set(SOURCES
  "csrc/allreduce/trt_reduce_internal.cu"
  "csrc/allreduce/trt_reduce_kernel.cu"
  "csrc/allreduce/custom_all_reduce.cu"
  "csrc/attention/lightning_attention_decode_kernel.cu"
  "csrc/elementwise/activation.cu"
  "csrc/elementwise/fused_add_rms_norm_kernel.cu"

137 sgl-kernel/csrc/allreduce/custom_all_reduce.cu Normal file
@@ -0,0 +1,137 @@
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
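// Round-trip sketch (illustrative): the static_assert above is what makes it
// safe to smuggle device pointers through Python as plain int64 values:
//   void* p = ...;                                 // e.g. a cudaMalloc'ed pointer
//   fptr_t handle = reinterpret_cast<fptr_t>(p);   // crosses the binding as int64
//   void* back = reinterpret_cast<void*>(handle);  // back == p on the C++ side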

fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
  int world_size = fake_ipc_ptrs.size();
  if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");

  vllm::Signal* ipc_ptrs[8];
  for (int i = 0; i < world_size; i++) {
    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new vllm::CustomAllreduce(
      ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within ((char*)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of a contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}

/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(
          stream, reinterpret_cast<float*>(reg_buffer), reinterpret_cast<float*>(out.data_ptr()), out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(
          stream, reinterpret_cast<half*>(reg_buffer), reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream,
          reinterpret_cast<nv_bfloat16*>(reg_buffer),
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
          out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
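// Host-side usage sketch (illustrative; assumes the IPC pointers and a
// registered staging buffer were already exchanged across ranks, e.g. from
// Python; reg_ptrs, reg_buffer and reg_sz are hypothetical names):
//   fptr_t fa = init_custom_ar(fake_ipc_ptrs, rank_data, rank, /*full_nvlink=*/true);
//   register_buffer(fa, reg_ptrs);
//   all_reduce(fa, inp, out, reg_buffer, reg_sz);
//   dispose(fa);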

void dispose(fptr_t _fa) {
  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}

int64_t meta_size() {
  return sizeof(vllm::Signal);
}

void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
  void* ipc_ptrs[8];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector<int64_t> bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  fa->register_graph_buffers(bytes, offsets);
}

489 sgl-kernel/csrc/allreduce/custom_all_reduce.cuh Normal file
@@ -0,0 +1,489 @@
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#include "utils.h"

namespace vllm {

constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Two sets of peer counters are needed for two syncs. The reason is that
  // it's possible for a peer GPU block to arrive at the second sync point while
  // the current GPU block hasn't passed the first sync point. Thus, the peer
  // GPU may write counter+1 while the current GPU is busy waiting for counter.
  // We use alternating counter arrays to avoid this possibility.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};

struct __align__(16) RankSignals {
  Signal* signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
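// Illustrative instantiations of the packing math (the asserts below are a
// sanity check added for exposition only):
//   packed_t<half>::P  == array_t<half, 8>   -> one 16-byte ld.128/st.128
//   packed_t<half>::A  == array_t<float, 8>  -> accumulation happens in fp32
//   packed_t<float>::P == array_t<float, 4>
static_assert(sizeof(packed_t<half>::P) == 16);
static_assert(sizeof(packed_t<float>::P) == 16);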

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) {
  return __half2float(val);
}

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float& assign_add(float& a, float b) {
  return a += b;
}

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
  return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}

static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
#endif
  return flag;
}

static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
  if constexpr (!is_start) __syncthreads();
  static_assert(!(is_start && need_fence));  // Start barrier shouldn't need fence.
  if (threadIdx.x < ngpus) {
    // Increment the counter. Technically we only need one counter, but we use
    // multiple per block to eliminate the need to share the counter via smem.
    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
    if constexpr (need_fence) {
      st_flag_release(peer_counter_ptr, val);
      while (ld_flag_acquire(self_counter_ptr) != val)
        ;
    } else {
      st_flag_volatile(peer_counter_ptr, val);
      while (ld_flag_volatile(self_counter_ptr) != val)
        ;
    }
  }
  if constexpr (is_start || need_fence) __syncthreads();
}
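// Typical call pattern (see the kernels below): a plain start/end pair
//   multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);   // start: peer inputs ready
//   ... read peer buffers, write results ...
//   multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);  // end: safe to reuse buffers
// The fenced variant (need_fence = true) is only used between the two stages
// of cross_device_reduce_2stage, where peers must observe this GPU's writes.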

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}
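// Layout note: the temporary buffer returned here starts right after the
// Signal header inside the same IPC allocation, matching the
// "| -- sizeof(Signal) -- | ------ a few MB ----- |" layout documented on the
// CustomAllreduce constructor below.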

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from all
  // ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
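// Worked partitioning example (illustrative): size = 1030 packed elements,
// ngpus = 4 gives part = 257 and largest_part = 257 + 2 = 259. Ranks 0-2 each
// reduce 257 elements; rank 3 reduces 259. In stage 2, "idx < part" guards the
// gather from ranks 0-2 while the last rank is always gathered up to
// largest_part, so the remainder elements are not lost.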

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph capture
  // time. Therefore, during capture, we increment the rank data pointer and use
  // that as the argument to the kernel. The kernel arguments are stored in
  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
  // memory pointed to by the pointers in graph_unreg_buffers_ when
  // the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each address used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffers, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second section
   * is for storing the intermediate results required by some allreduce algos.
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(
      Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
      : rank_(rank),
        world_size_(world_size),
        full_nvlink_(full_nvlink),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have an internal reference counting
  // mechanism so the overhead should be small.
  void
  register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CHECK_CUDA_SUCCESS(
        cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are results after careful grid search. Using
   * 36 blocks gives the best or close to the best runtime on the devices I
   * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
   * use a small number of SMs. Not quite sure of the underlying reason, but my
   * guess is that too many SMs will cause contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
    // TODO(hanzhi713): Threshold is different for A100 and H100.
    // Add per device threshold.
#define REDUCE_CASE(ngpus)                                                                        \
  case ngpus: {                                                                                   \
    if (world_size_ == 2) {                                                                       \
      KL(ngpus, cross_device_reduce_1stage);                                                      \
    } else if (full_nvlink_) {                                                                    \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);                                                    \
      } else {                                                                                    \
        KL(ngpus, cross_device_reduce_2stage);                                                    \
      }                                                                                           \
    }                                                                                             \
    break;                                                                                        \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }
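
  // Dispatch example (illustrative): world_size_ = 4 with full NVLink and an
  // fp16 input of 262144 elements gives d = 8, size = 32768 packed vectors and
  // bytes = 32768 * 16 = 512 KiB. That is not strictly below the 512 KiB
  // threshold, so cross_device_reduce_2stage is launched; halving the message
  // to 131072 elements (256 KiB) would select cross_device_reduce_1stage.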

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 half *, int, int, int);
 */
}  // namespace vllm
@@ -1,532 +0,0 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>

#include "trt_reduce_internal.cuh"
#include "utils.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////

// Type converter that packs the data format into a 128-bit data type
//
using PackedFloat = union {
  int4 packed;
  float unpacked[4];
};

using PackedHalf = union {
  int4 packed;
  half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes {};

template <>
struct PackedOn16Bytes<float> {
  using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half> {
  using Type = PackedHalf;
};

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
  int4 packed;
  __nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16> {
  using Type = PackedBFloat16;
};
#endif

// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
  T c;
  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
  return c.packed;
}
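// Note on the '+' above: for PackedHalf this performs 8 fp16 additions via the
// half2 operator+, which is available here because the build strips
// -D__CUDA_NO_HALF2_OPERATORS__ from the CUDA flags (see the CMakeLists.txt
// hunk above); PackedBFloat16 relies on the analogous __nv_bfloat162 operator.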

__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]
    // Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension

    // Block 0 broadcasts its flag (local_rank on the emitting dimension) to all receivers
    size_t offset = (flag % 2) ? world_size : 0;

    if (bidx == 0) {
      st_flag_release(flag, signals[tidx] + offset + local_rank);
    }

    // All blocks check that the corresponding block 0 on other GPUs has set the flag
    // No deadlock because block #0 is always the first block started
    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
    while (ld_flag_acquire(peer_barrier_d) != flag) {
    }
  }

  __syncthreads();
}

template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx,
    int const grid_size) {
  if constexpr (!start) {
    __syncthreads();
  }
  // After this function, the block of id == bidx of each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
    // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
    // Dimension 0 is the "listening" dimension, dimension 3 is the "emitting" dimension

    // Block broadcasts its flag (local_rank on the emitting dimension) to all receivers
    uint32_t flag_block_offset = world_size + bidx * world_size;

    flag_block_offset += (grid_size + 1) * world_size * (flag % 2);

    uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
    // Blocks check that the corresponding blocks on other GPUs have also set the flag
    if constexpr (need_fence) {
      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_acquire(peer_barrier_d) != flag) {
      }
    } else {
      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_volatile(peer_barrier_d) != flag) {
      }
    }
  }
  if constexpr (start || need_fence) {
    __syncthreads();
  }
}
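// Offset example (illustrative): world_size = 4, grid_size = 8, bidx = 2.
// An even flag lands at flag_block_offset = 4 + 2 * 4 = 12; an odd flag adds
// (8 + 1) * 4 = 36 on top, so the two alternating flag sets never alias.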

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
  // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
  // The message is partitioned into chunks as detailed below:
  //               message
  //       |-------------------|
  // GPU 0 | B0 | B1 | B2 | B3 |
  // GPU 1 | B0 | B1 | B2 | B3 |
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 copies the chunk it is responsible for, from local_input to the shareable buffer
  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
  // 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
  //
  // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
  // We only need to know if the other GPU has arrived at the AR kernel, as that would mean that the data is ready
  //
  // With PUSH_MODE, we consider that the shared buffer is of size:
  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 pushes the chunk it is responsible for into all other GPUs:
  //    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
  // 2. block sync so the chunk is visible to the other GPUs
  // 3. Reduce along the second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]

  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one for comms
  static constexpr int NUM_ELTS = 16 / sizeof(T);

  // Packed data type for comms
  using PackedStruct = typename PackedOn16Bytes<T>::Type;

  // The source pointers. Distributed round-robin for the different warps.
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  // Start and end offsets of the thread
  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);

  if constexpr (COPY_INPUT) {
    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
    // Copy from local buffer to shareable buffer
    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
    }
  }
  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
    // Iterate over the different ranks/devices on the node to load the values.
    PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
    }

    // Sum the values from the different ranks.
    PackedStruct sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the destination buffer.
    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
  }
  block_barrier<false>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
}

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
  // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
  // The message is partitioned into chunks as detailed below:
  //                  message
  //          |-------------------|
  //          |--GPU 0--|--GPU 1--| (GPU responsibility parts)
  //  GPU 0   | B0 | B1 | B0 | B1 |
  //  GPU 1   | B0 | B1 | B0 | B1 |
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 copies all chunks it is responsible for, from local_input to the shareable buffer
  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
  // 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
  //    part (the first half of the message, see the GPU responsibility row above)
  // 3bis. Likewise, B0 on GPU 1 copies and sums the chunks from GPU 0
  //    where GPU 1 is responsible: the second half of the message.
  // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
  // 5. B0 writes the result to local_output. It gathers each chunk from its responsible GPU.
  //    For example, here it reads the first chunk from GPU 0 and the second chunk from GPU 1.
  //
  // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
  // We only need to know if the other GPU has arrived at the AR kernel, as that would mean that the data is ready
  // to be read.
  //
  // Note that compared to one-shot, one block (CTA) writes multiple input chunks and multiple output chunks.
  // However, it's only responsible for the summation of a single chunk.
  //
  // With PUSH_MODE, we consider that the shared buffer is of size:
  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
  //    params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
  // 2. block sync so the chunks have been shared with the other GPUs
  // 3. Reduce along the second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
  // 4. block barrier (corresponding blocks have finished reduction)
  // 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
  //    written at index 0 of 2nd dim)

  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one for comms
  static constexpr int PACKED_ELTS = 16 / sizeof(T);
  using PackedType = typename PackedOn16Bytes<T>::Type;

  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);

  size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
  size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);

  T* buffers[RANKS_PER_NODE];
  T* buffers_unorder[RANKS_PER_NODE];
  int ranks[RANKS_PER_NODE];
#pragma unroll
  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
    // A mapping of the ranks to scatter reads as much as possible
    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
    ranks[ii] = rank;
    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
  }

#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
#endif

  if constexpr (COPY_INPUT) {
    // Copy all blocks from local buffer to shareable buffer
    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
        if (offset_rank >= params.elts_total) {
          continue;
        }
        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
      }
    }
  }
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
    size_t const responsible_block_offset = local_offset + params.rank_offset;

    // Iterate over the different ranks/devices on the node to load the values.
    PackedType vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
    }

    // Sum the values from the different ranks.
    PackedType sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the local buffer or tmp buffer
    if constexpr (COPY_INPUT) {
      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
    } else {
      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
    }
  }

  block_barrier<false, true>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Gather all needed elts from other intra-node ranks
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      // use round-robin gathering from other ranks
      size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
      if (offset_rank >= params.elts_total) {
        continue;
      }
      if constexpr (COPY_INPUT) {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
      } else {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
      }
    }
  }
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int divUp(int a, int b) {
  return (a + b - 1) / b;
}

inline int roundUp(int a, int n) {
  return divUp(a, n) * n;
}
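// e.g. divUp(1000, 512) == 2 and roundUp(1000, 512) == 1024.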

std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      assert(params.elts_total % elts_per_thread == 0);
      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
      params.elts_per_rank = params.elts_total;
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
      size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);

      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_rank = params.elts_total / params.ranks_per_node;
      params.rank_offset = params.local_rank * params.elts_per_rank;
      params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
      break;
    }
    default:
      assert(false && "Algorithm not supported here.");
  }

  return std::make_tuple(blocks_per_grid, threads_per_block);
}
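// Worked ONESHOT example (illustrative; assumes DEFAULT_BLOCK_SIZE = 512 and
// WARP_SIZE = 32 from trt_reduce_internal.cuh): fp16 with elts_total = 8192
// gives elts_per_thread = 8, total_threads = roundUp(1024, 32) = 1024,
// threads_per_block = 512, blocks_per_grid = 2, and
// elts_per_block = roundUp(divUp(8192, 2), 8) = 4096.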

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
    AllReduceStrategyType algo,
    AllReduceParams& param,
    int blocks_per_grid,
    int threads_per_block,
    cudaStream_t stream) {
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
  }
}

template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
  size_t elts_per_thread = 16 / sizeof(T);
  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
  switch (param.ranks_per_node) {
    case 2:
      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
}

template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
  if (param.is_capturing) {
    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
  } else {
    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
  }
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}

void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }

  switch (data_type) {
    case at::ScalarType::Float:
      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
      break;
    case at::ScalarType::Half:
      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
      break;
#endif
    default:
      assert(false && "Unsupported data type");
  }
}
}  // namespace trt_llm
@@ -1,226 +0,0 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h

#include <c10/cuda/CUDAStream.h>

#include <cassert>

#include "trt_reduce_internal.cuh"
#include "utils.h"

using namespace trt_llm;

using fptr_t = int64_t;
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;

class AllReduceMeta {
 public:
  AllReduceMeta(
      int64_t rank_id,
      int64_t world_size,
      torch::Tensor& rank_data,
      const std::vector<fptr_t>& buffers,
      const std::vector<fptr_t>& tmp_result_buffers,
      const std::vector<fptr_t>& barrier_in,
      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;
    this->barrier_out = barrier_out;
    this->tmp_result_buffers = tmp_result_buffers;

    this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
    RankData data;
    for (int i = 0; i < world_size; i++) {
      data.ptrs[i] = (void*)buffers[i];
    }
    auto d_data = this->rank_data_base++;
    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    this->buffers = d_data;
  }

  ~AllReduceMeta() {
    for (auto [_, ptr] : ipc_handles_) {
      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
    }
  }

 public:
  int world_size;
  int rank_id;
  std::vector<fptr_t> barrier_in;
  std::vector<fptr_t> barrier_out;
  std::vector<fptr_t> tmp_result_buffers;
  int barrier_flag = 1;
  RankData* buffers;
  RankData* rank_data_base;
  std::vector<void*> graph_unreg_buffers;
  std::map<IPC_KEY, char*> ipc_handles_;
};

// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
  switch (dtype) {
    case at::ScalarType::Float:
      return 32;
    case at::ScalarType::Half:
    case at::ScalarType::BFloat16:
      return 16;
    default:
      assert(false && "Unsupported data type");
  }
}

// Check if the customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
  // The customized all-reduce kernel has the following requirement(s).
  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
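// e.g. fp16: 16 / ((16 + 7) / 8) = 16 / 2 = 8, so num_elements must be a
// multiple of 8 (one 16-byte packed vector); fp32 must be a multiple of 4.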
|
||||
|
||||
fptr_t init_custom_ar(
|
||||
int64_t rank_id,
|
||||
int64_t world_size,
|
||||
torch::Tensor& rank_data,
|
||||
const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& tmp_result_buffers,
|
||||
const std::vector<fptr_t>& barrier_in,
|
||||
const std::vector<fptr_t>& barrier_out) {
|
||||
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
|
||||
return (fptr_t)m;
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
auto num_buffers = m->graph_unreg_buffers.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = m->graph_unreg_buffers[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
|
||||
assert(false && "failed to get pointer attr");
|
||||
}
|
||||
|
||||
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
std::vector<int64_t> bytes(handles.begin(), handles.end());
|
||||
return std::make_pair(bytes, offsets);
|
||||
}
|
||||
|
||||
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
|
||||
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  std::vector<std::string> handle_bytes;
  handle_bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  auto num_buffers = m->graph_unreg_buffers.size();
  std::vector<RankData> rank_data(num_buffers);
  for (int i = 0; i < num_buffers; i++) {
    auto self_ptr = m->graph_unreg_buffers[i];
    auto& rd = rank_data[i];
    for (int j = 0; j < m->world_size; j++) {
      if (j != m->rank_id) {
        char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
        handle += offsets[j][i];
        rd.ptrs[j] = handle;
      } else {
        rd.ptrs[j] = self_ptr;
      }
    }
  }
  CHECK_CUDA_SUCCESS(
      cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
  m->rank_data_base += num_buffers;
  m->graph_unreg_buffers.clear();
}

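// Layout recap (descriptive): handles[j] holds rank j's concatenated
// cudaIpcMemHandle_t blobs, one per captured buffer, and offsets[j][i] is
// buffer i's offset inside its allocation on rank j, so that
// rd.ptrs[j] = opened_base + offsets[j][i].
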
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  auto num_elements = inp.numel();
  auto dtype = inp.scalar_type();
  AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);

  // These preconditions should be guaranteed by the Python caller.
  assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
  assert(CanApplyCustomAllReduce(num_elements, dtype));

  // Initialize the all-reduce kernel arguments.
  int world_size = m->world_size;

  AllReduceParams params;
  params.ranks_per_node = world_size;
  params.rank = m->rank_id;
  params.local_rank = m->rank_id;
  params.local_input_buffer_ptr = inp.data_ptr();
  params.local_output_buffer_ptr = out.data_ptr();
  params.elts_total = inp.numel();
  params.elts_size = inp.element_size();
  params.barrier_flag = ++(m->barrier_flag);

  cudaStreamCaptureStatus status;
  CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
  params.is_capturing = (status == cudaStreamCaptureStatusActive);
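
  // During CUDA graph capture the input address is not final at capture time,
  // so the pointer is recorded here and the per-rank pointer table is filled
  // in later by register_graph_buffers() once capture completes.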
  if (params.is_capturing) {
    params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
    m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
  } else {
    params.peer_comm_buffer_ptrs = m->buffers;
  }

  for (int i = 0; i < world_size; ++i) {
    params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
  }

  auto data_type = out.scalar_type();
  trtCustomAllReduce(params, data_type, strategy, stream);
}

@@ -26,15 +26,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  m.def("register_graph_buffers", &register_graph_buffers);
  m.def("dispose", &dispose);
  m.def("meta_size", &meta_size);
  m.def("register_buffer", &register_buffer);

  m.def(
      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
      "barrier_in, int[] barrier_out) -> int");
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool full_nvlink) -> int");
  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
  m.def(
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  m.impl("all_reduce", torch::kCUDA, &all_reduce);
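
  // Schema notes (descriptive): in TORCH_LIBRARY schema strings, "int" maps to
  // int64_t (which is why fptr_t handles travel as plain ints), "int[]" to a
  // list of int64_t, and "Tensor!" marks an argument the op mutates in place.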

  /*
   * From csrc/attention
   */

@@ -21,6 +21,7 @@ limitations under the License.
#include <torch/library.h>
#include <torch/torch.h>

#include <tuple>
#include <vector>

#define _CONCAT(A, B) A##B
@@ -63,18 +64,14 @@ void register_graph_buffers(
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out);
// custom allreduce
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
int64_t meta_size();
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif

@@ -1,109 +0,0 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>

namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;

enum class AllReduceStrategyType : int8_t {
  RING = 0,
  ONESHOT = 1,
  TWOSHOT = 2,
  AUTO = 3,
};

struct RankData {
  void* ptrs[MAX_RANKS_PER_NODE];
};

struct AllReduceParams {
  size_t elts_size;
  size_t elts_total;
  size_t elts_per_rank;
  size_t elts_per_block;
  size_t rank_offset;
  size_t ranks_per_node, rank, local_rank;
  uint32_t barrier_flag;
  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
  RankData* peer_comm_buffer_ptrs;
  void* local_input_buffer_ptr;
  void* local_output_buffer_ptr;
  bool is_capturing;
};

inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  if (world_size <= 2) {
    return 16 * 1024 * 1024;
  }
  return 8 * 1024 * 1024;
}

inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);

  if (message_size > maxWorkspaceSize) {
    assert(false && "Custom allreduce does not support the RING strategy currently");
    return AllReduceStrategyType::RING;
  }

  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }

  if (world_size <= 4) {
    if (message_size < 1 * 1024 * 1024) {
      return AllReduceStrategyType::ONESHOT;
    }
    return AllReduceStrategyType::TWOSHOT;
  }

  if (message_size < 512 * 1024) {
    return AllReduceStrategyType::ONESHOT;
  }
  return AllReduceStrategyType::TWOSHOT;
}

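// Worked examples (derived from the thresholds above):
//   SelectImplementation(256 * 1024, 8);       // ONESHOT (< 512 KiB)
//   SelectImplementation(1024 * 1024, 8);      // TWOSHOT
//   SelectImplementation(2 * 1024 * 1024, 4);  // TWOSHOT (>= 1 MiB)
//   SelectImplementation(2 * 1024 * 1024, 2);  // ONESHOT (world_size <= 2)
// Messages above the 8/16 MiB workspace cap trip the assert above.
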
void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

}  // namespace trt_llm
@@ -50,28 +50,38 @@ if torch.version.hip is not None:
        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

else:
    # TRTLLM custom allreduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):

    def init_custom_ar(
        ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
    ) -> int:
        return torch.ops.sgl_kernel.init_custom_ar.default(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
            ipc_tensors, rank_data, rank, full_nvlink
        )

    def custom_dispose(fa):
    def dispose(fa: int) -> None:
        torch.ops.sgl_kernel.dispose.default(fa)

    def custom_reduce(fa, inp, out):
        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        torch.ops.sgl_kernel.all_reduce.default(
            fa, inp, out, reg_buffer, reg_buffer_sz_bytes
        )

    def get_graph_buffer_ipc_meta(fa):
    def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

    def register_graph_buffers(fa, handles, offsets):
    def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
        return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

    def meta_size() -> int:
        return torch.ops.sgl_kernel.meta_size.default()

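    # Minimal usage sketch (illustrative; mirrors the unit test below, and
    # assumes IPC-shared buffers have already been exchanged across ranks,
    # e.g. with a create_shared_buffer()-style helper):
    #
    #   meta_ptrs = create_shared_buffer(meta_size() + max_size, group=group)
    #   rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
    #   buffer_ptrs = create_shared_buffer(max_size, group=group)
    #   fa = init_custom_ar(meta_ptrs, rank_data, rank, full_nvlink=True)
    #   register_buffer(fa, buffer_ptrs)
    #   all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)
    #   dispose(fa)
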
@@ -16,7 +16,6 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibra
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    ranks = list(range(world_size))
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
@@ -26,39 +25,18 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
    )
    group = dist.group.WORLD

    buffer_max_size = 8 * 1024 * 1024
    barrier_max_size = 8 * (24 + 2) * 8
    buffer_ptrs = None
    tmp_result_buffer_ptrs = None
    barrier_in_ptrs = None
    barrier_out_ptrs = None
    custom_ptr = None

    try:
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
            buffer_max_size, group=group
        )
        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
            buffer_max_size, group=group
        )
        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
            barrier_max_size, group=group
        )
        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
            barrier_max_size, group=group
        device = torch.device(f"cuda:{rank}")
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )

        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)

        custom_ptr = custom_ops.init_custom_reduce(
            rank,
            world_size,
            rank_data,
            buffer_ptrs,
            tmp_result_buffer_ptrs,
            barrier_in_ptrs,
            barrier_out_ptrs,
        )
        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

        test_loop = 10
        for sz in test_sizes:
@@ -68,7 +46,9 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
                inp1_ref = inp1.clone()
                out1 = torch.empty_like(inp1)

                custom_ops.custom_reduce(custom_ptr, inp1, out1)
                custom_ops.all_reduce(
                    custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                )

                dist.all_reduce(inp1_ref, group=group)

@@ -77,15 +57,11 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
    finally:
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.custom_dispose(custom_ptr)
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if tmp_result_buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
        if barrier_in_ptrs:
            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
        if barrier_out_ptrs:
            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)

        dist.destroy_process_group(group=group)

@@ -122,7 +98,18 @@ def multi_process_parallel(


class TestCustomAllReduce(unittest.TestCase):
    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    world_sizes = [2, 4, 8]

    @staticmethod