Rename files in sgl kernel to avoid nested folder structure (#4213)

Co-authored-by: zhyncs <me@zhyncs.com>
Lianmin Zheng
2025-03-08 22:54:51 -08:00
committed by GitHub
parent ee132a4515
commit 8abf74e3c9
47 changed files with 184 additions and 199 deletions

View File

@@ -0,0 +1,180 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd number of GPUs is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char*)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows a transpose of a contiguous slice (i.e. slicing along the
* first dimension). Currently, we require this because stride information is
* not passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
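// A hedged worked example of the check above (added for clarity): for
// A = torch.zeros(3, 3, 3) (float32), B = A[1:].permute(2, 0, 1) shares A's
// storage with storage_offset() == 9, storage().nbytes() == 108 and
// numel() == 18, so 108 - 9 * 4 == 18 * 4 and B passes even though it is not
// contiguous, while A[:, 1:, 1:] (numel 12, storage_offset 4) fails because
// its elements do not form a dense tail of the storage.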
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}

View File

@@ -0,0 +1,582 @@
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#else
#include <hip/hip_bf16.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
hipError_t e = cmd; \
if (e != hipSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace vllm {
constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal {
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
};
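// Hedged size note (added commentary): with kMaxBlocks = 64 the start and end
// arrays are 64 * 8 * 4 = 2048 bytes each and _flag adds 64 * 4 = 256 bytes,
// so sizeof(Signal) is 4352 bytes; the host-side meta_size() returns exactly
// this, and the temporary reduction buffer begins right after the Signal
// prefix (see get_tmp_buf below).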
#ifdef USE_ROCM
struct __align__(16) RankData {
const void* ptrs[8];
};
#else
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
#endif
struct __align__(16) RankSignals {
#ifndef USE_ROCM
volatile
#endif
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
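// Hedged illustration (added for clarity, not part of the upstream header):
// packed_t<half>::P is array_t<half, 8> (16 bytes, i.e. one 128-bit load or
// store) and packed_t<half>::A is array_t<float, 8>, the wider accumulator
// used while summing across ranks.
static_assert(sizeof(packed_t<half>::P) == 16, "packed loads should be 128-bit");
static_assert(sizeof(packed_t<float>::P) == 16, "packed loads should be 128-bit");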
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
return __half2float(val);
}
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with PyTorch, the + operator for half and
// bfloat is disabled, so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) {
return a += b;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// wait until every rank's flag reaches the expected value
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
;
}
__syncthreads();
#endif
}
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
__syncthreads();
// guard against the case where prior writes are not yet visible when the
// signals become visible. Note that I did not manage to reproduce this through
// a lot of testing. It might be that the hardware provides a stronger
// guarantee than the memory model.
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
// wait until every rank's flag reaches the expected value
while (__scoped_atomic_load_n(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
__syncthreads();
// guard against the case where prior writes are not yet visible when the
// signals become visible. Note that I did not manage to reproduce this through
// a lot of testing. It might be that the hardware provides a stronger
// guarantee than the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
}
if constexpr (!final_sync) __syncthreads();
#endif
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the addresses, so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
#ifdef USE_ROCM
DINLINE P* get_tmp_buf(Signal* sg) {
#else
DINLINE P* get_tmp_buf(volatile Signal* sg) {
#endif
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
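// Hedged partition example (added for clarity): with size = 103 packed
// elements and ngpus = 4, part = 25, so ranks 0-2 reduce 25 packed elements
// each while the last rank takes the remaining 28 == largest_part; the
// allgather loop below therefore iterates up to largest_part.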
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
end_sync<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
// below are device pointers
RankSignals sg_;
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
*
* There's a total of sizeof(Signal) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(
Signal* meta,
void* rank_data,
size_t rank_data_sz,
const hipIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets,
int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Signal* rank_sg;
if (i != rank_) {
char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal*)handle;
} else {
rank_sg = self_sg_;
}
sg_.signals[i] = rank_sg;
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(hipIpcOpenMemHandle(
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(hipIpcMemHandle_t);
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get the wrong
// address
if (hipPointerGetAttribute(
&base_ptr,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(hipDeviceptr_t)ptr) != hipSuccess)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
data.ptrs[i] = self;
}
}
auto d_data = d_rank_data_base_++;
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
buffers_[self] = d_data;
}
// note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
got a different address. IPC handles have an internal reference-counting
mechanism, so the overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* This is the result after careful grid search. Using 36 blocks gives the best
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
* V100. You'll notice that NCCL kernels also only use a small number of SMs.
* I'm not quite sure of the underlying reason, but my guess is that too many
* SMs will cause contention on the NVLink bus.
*/
template <typename T>
void allreduce(
hipStream_t stream,
T* input,
T* output,
int size,
#ifndef USE_ROCM
int threads = 512,
int block_limit = 36) {
#else
int threads = 512,
int block_limit = 16) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
hipStreamCaptureStatus status;
CUDACHECK(hipStreamIsCapturing(stream, &status));
if (status == hipStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
hipLaunchKernelGGL( \
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(hipIpcCloseMemHandle(ptr));
}
}
}; // class CustomAllreduce
/**
* To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and add
a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
half *, int, int, int);
*/
} // namespace vllm

View File

@@ -0,0 +1,545 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include "trt_reduce_internal.cuh"
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
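// Hedged note on the helpers above (added commentary): the release/acquire
// pair orders data written before st_flag_release so that a peer observing the
// flag through ld_flag_acquire also sees that data, while the volatile
// variants roughly just force the flag access itself to memory and give no
// ordering for other data; block_barrier below chooses between the two based
// on its need_fence parameter.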
namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type Converter that packs data format to 128 bits data type
//
using PackedFloat = union {
int4 packed;
float unpacked[4];
};
using PackedHalf = union {
int4 packed;
half2 unpacked[4];
};
template <typename T>
struct PackedOn16Bytes {};
template <>
struct PackedOn16Bytes<float> {
using Type = PackedFloat;
};
template <>
struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
};
template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
#endif
// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;
if (bidx == 0) {
st_flag_release(flag, signals[tidx] + offset + local_rank);
}
// All blocks check that the corresponding block 0 on other GPUs has set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
}
__syncthreads();
}
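// Hedged worked example for multi_gpu_barrier (added for clarity): with
// world_size = 4 and flag = 3 (odd), offset = 4, so block 0 of the local rank
// stores the flag into signals[tidx] + 4 + local_rank on every peer tidx, and
// every block then spins on signals[local_rank] + 4 + tidx until each peer's
// block 0 has stored; even flags use offset 0, ping-ponging between two slot
// groups so back-to-back barriers never alias.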
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx,
int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
// After this function, the block of id == bidx of each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, 2, num_blocks, world_size]
// (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
// Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension
// Each block broadcasts its flag (local_rank on emitting dimension) to all receivers
uint32_t flag_block_offset = world_size + bidx * world_size;
flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
// Blocks check that corresponding blocks on other GPUs have also set the flag
if constexpr (need_fence) {
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
} else {
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_volatile(peer_barrier_d) != flag) {
}
}
}
__syncthreads();
}
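// Hedged worked example for block_barrier (added for clarity): with
// world_size = 8, grid_size = 4, bidx = 2 and an even flag, flag_block_offset
// = 8 + 2 * 8 = 24, so the block signals slot 24 + local_rank on each peer and
// waits on slot 24 + tidx locally; an odd flag adds (4 + 1) * 8 = 40, moving
// the barrier to a disjoint slot group.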
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only need to know whether the other GPU has arrived at the AR kernel, which means its data is ready
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. block sync so the chunk is shared with the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;
// The number of elements packed into one 16-byte access for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedStruct = typename PackedOn16Bytes<T>::Type;
// The source pointers. Distributed round-robin for the different warps.
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
if constexpr (COPY_INPUT) {
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
// Copy from local buffer to shareable buffer
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
}
// Sum the values from the different ranks.
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
sums.packed = add128b(sums, vals[rank]);
}
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
// GPU 0 | B0 | B1 | B0 | B1 |
// GPU 1 | B0 | B1 | B0 | B1 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies all chunks it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
// 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
// part (the first half of the message, see GPU responsibility row above)
// 3bis. Likewise, B0 on GPU 1 copies and sums the chunks for GPU 0,
// where GPU 1 is responsible: the second half of the message.
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
// 5. B0 writes the result to local_output. It gathers each chunk from its responsible GPU.
// For example, here it reads the first chunk from GPU 0 and the second chunk from GPU 1.
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only need to know whether the other GPU has arrived at the AR kernel, which means its data is ready
// to be read.
//
// Note that compared to one-shot, one block (CTA) copies multiple input chunks and writes multiple output chunks.
// However, it's only responsible for the summation of a single chunk.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
// 2. block sync so the chunks have been shared with the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
// 4. block barrier (corresponding blocks have finished reduction)
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
// written at index 0 of 2nd dim)
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;
// The number of elements packed into one 16-byte access for comms
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedType = typename PackedOn16Bytes<T>::Type;
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
T* buffers[RANKS_PER_NODE];
T* buffers_unorder[RANKS_PER_NODE];
int ranks[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// A mapping of the ranks to scatter reads as much as possible
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
ranks[ii] = rank;
buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
#endif
if constexpr (COPY_INPUT) {
// Copy all blocks from local buffer to shareable buffer
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
}
}
}
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
size_t const responsible_block_offset = local_offset + params.rank_offset;
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
}
// Sum the values from the different ranks.
PackedType sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
sums.packed = add128b(sums, vals[rank]);
}
// Store to the local buffer or tmp buffer
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
} else {
*reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
}
}
block_barrier<false, true>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// use round-robin gathering from other ranks
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
} else {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
}
}
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int divUp(int a, int b) {
return (a + b - 1) / b;
}
inline int roundUp(int a, int n) {
return divUp(a, n) * n;
}
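// Hedged sanity examples for the helpers above (added commentary):
// divUp(10, 4) == 3 and roundUp(10, 4) == 12, i.e. the quotient and the value
// itself rounded up to the next multiple.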
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
assert(params.elts_total % elts_per_thread == 0);
size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
params.elts_per_rank = params.elts_total;
break;
}
case AllReduceStrategyType::TWOSHOT: {
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
/*
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
// NOTE: cap the grid at MAX_ALL_REDUCE_BLOCKS; the per-block element count below grows accordingly
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
params.elts_per_rank = params.elts_total / params.ranks_per_node;
params.rank_offset = params.local_rank * params.elts_per_rank;
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
break;
}
default:
assert(false && "Algorithm not supported here.");
}
return std::make_tuple(blocks_per_grid, threads_per_block);
}
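// Hedged ONESHOT example (added for clarity; assumes DEFAULT_BLOCK_SIZE = 512
// and WARP_SIZE = 32): for 8192 fp16 elements with elts_per_thread = 8,
// total_threads = roundUp(1024, 32) = 1024, threads_per_block = 512,
// blocks_per_grid = divUp(1024, 512) = 2, elts_per_block =
// roundUp(divUp(8192, 2), 8) = 4096, and elts_per_rank = elts_total.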
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
AllReduceStrategyType algo,
AllReduceParams& param,
int blocks_per_grid,
int threads_per_block,
cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
case AllReduceStrategyType::TWOSHOT: {
twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
}
}
template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
if (param.is_capturing) {
dispatchARKernelsCopyInput<T, false>(strat, param, stream);
} else {
dispatchARKernelsCopyInput<T, true>(strat, param, stream);
}
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
switch (data_type) {
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
#endif
default:
assert(false && "Unsupported data type");
}
}
} // namespace trt_llm

View File

@@ -0,0 +1,226 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include "trt_reduce_internal.cuh"
#include "utils.h"
using namespace trt_llm;
using fptr_t = int64_t;
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
public:
AllReduceMeta(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->barrier_in = barrier_in;
this->barrier_out = barrier_out;
this->tmp_result_buffers = tmp_result_buffers;
this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
RankData data;
for (int i = 0; i < world_size; i++) {
data.ptrs[i] = (void*)buffers[i];
}
auto d_data = this->rank_data_base++;
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
this->buffers = d_data;
}
~AllReduceMeta() {
for (auto [_, ptr] : ipc_handles_) {
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
}
}
public:
int world_size;
int rank_id;
std::vector<fptr_t> barrier_in;
std::vector<fptr_t> barrier_out;
std::vector<fptr_t> tmp_result_buffers;
int barrier_flag = 1;
RankData* buffers;
RankData* rank_data_base;
std::vector<void*> graph_unreg_buffers;
std::map<IPC_KEY, char*> ipc_handles_;
};
// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
switch (dtype) {
case at::ScalarType::Float:
return 32;
case at::ScalarType::Half:
case at::ScalarType::BFloat16:
return 16;
default:
assert(false && "Unsupported data type");
}
}
// Check if customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
// The customized all-reduce kernel has the following requirement(s).
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
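// Hedged example (added for clarity): for Half or BFloat16 the divisor is
// 16 / 2 = 8, so num_elements must be a multiple of 8; for Float it is
// 16 / 4 = 4, matching the 128-bit packed accesses used by the kernels.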
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
delete fa;
}
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto num_buffers = m->graph_unreg_buffers.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = m->graph_unreg_buffers[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get the wrong
// address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
assert(false && "failed to get pointer attr");
}
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
std::vector<int64_t> bytes(handles.begin(), handles.end());
return std::make_pair(bytes, offsets);
}
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
got a different address. IPC handles have an internal reference-counting
mechanism, so the overhead should be small.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
std::vector<std::string> handle_bytes;
handle_bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
}
auto num_buffers = m->graph_unreg_buffers.size();
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = m->graph_unreg_buffers[i];
auto& rd = rank_data[i];
for (int j = 0; j < m->world_size; j++) {
if (j != m->rank_id) {
char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CHECK_CUDA_SUCCESS(
cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
m->rank_data_base += num_buffers;
m->graph_unreg_buffers.clear();
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto stream = c10::cuda::getCurrentCUDAStream().stream();
auto num_elements = inp.numel();
auto dtype = inp.scalar_type();
AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
// should be guaranteed by the Python code
assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
assert(CanApplyCustomAllReduce(num_elements, dtype));
// Initialize the all-reduce kernel arguments.
int world_size = m->world_size;
AllReduceParams params;
params.ranks_per_node = world_size;
params.rank = m->rank_id;
params.local_rank = m->rank_id;
params.local_input_buffer_ptr = inp.data_ptr();
params.local_output_buffer_ptr = out.data_ptr();
params.elts_total = inp.numel();
params.elts_size = inp.element_size();
params.barrier_flag = ++(m->barrier_flag);
cudaStreamCaptureStatus status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
params.is_capturing = (status == cudaStreamCaptureStatusActive);
if (params.is_capturing) {
params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
} else {
params.peer_comm_buffer_ptrs = m->buffers;
}
for (int i = 0; i < world_size; ++i) {
params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
}
auto data_type = out.scalar_type();
trtCustomAllReduce(params, data_type, strategy, stream);
}