sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
137
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
Normal file
137
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
Normal file
@@ -0,0 +1,137 @@
|
||||
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
// Opaque pointer handle: native pointers cross the op boundary as 64-bit
// integers. This alias must stay identical to the fptr_t declared in ops.h.
typedef int64_t fptr_t;
static_assert(sizeof(fptr_t) == sizeof(void*), "fptr_t must be exactly pointer-sized");
|
||||
|
||||
/**
 * Builds a sglang::CustomAllreduce and returns it as an opaque handle.
 *
 * fake_ipc_ptrs holds one Signal pointer per rank (encoded as fptr_t); its
 * length defines the world size. rank_data provides the storage handed to the
 * CustomAllreduce constructor (data pointer and numel()).
 *
 * Throws std::invalid_argument when the world size is larger than 8, is odd,
 * or when rank is outside [0, world size).
 */
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
  const int num_ranks = fake_ipc_ptrs.size();
  if (num_ranks > 8) {
    throw std::invalid_argument("world size > 8 is not supported");
  }
  if (num_ranks % 2 != 0) {
    throw std::invalid_argument("Odd num gpus is not supported for now");
  }
  if (rank < 0 || rank >= num_ranks) {
    throw std::invalid_argument("invalid rank passed in");
  }

  // Decode the integer handles back into Signal pointers, one slot per rank.
  sglang::Signal* signal_ptrs[8];
  int slot = 0;
  for (fptr_t raw : fake_ipc_ptrs) {
    signal_ptrs[slot++] = reinterpret_cast<sglang::Signal*>(raw);
  }

  auto* instance = new sglang::CustomAllreduce(
      signal_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, num_ranks, full_nvlink);
  return (fptr_t)instance;
}
|
||||
|
||||
/**
 * Returns true when tensor t is contiguous, or when the bytes of its storage
 * that follow its storage offset amount to exactly
 * t.numel() * t.element_size(). The second condition is slightly weaker than
 * t.is_contiguous(): it also accepts a transpose of a contiguous slice (i.e.
 * slicing the first dimension). This is required because stride information
 * is not passed into the kernels and input tensors are treated as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor& t) {
  if (t.is_contiguous()) {
    return true;
  }
  const auto bytes_after_offset = t.storage().nbytes() - t.storage_offset() * t.element_size();
  const auto bytes_spanned = t.numel() * t.element_size();
  return bytes_after_offset == bytes_spanned;
}
|
||||
|
||||
/**
 * Out-of-place allreduce: reduces inp across ranks and writes the result to out.
 *
 * When _reg_buffer is null, inp.data_ptr() is used directly as the kernel
 * input and is assumed to be IPC-registered already. Otherwise _reg_buffer is
 * assumed to be IPC-registered; inp is copied into it on the current stream
 * (after checking it fits in reg_buffer_sz_bytes) and the copy is reduced.
 *
 * inp and out must have the same dtype and element count and must both be
 * weakly contiguous. Only float32, float16 and bfloat16 are dispatched; any
 * other dtype throws std::runtime_error.
 */
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));

  auto input_size = inp.numel() * inp.element_size();
  void* src = reinterpret_cast<void*>(_reg_buffer);
  if (src == nullptr) {
    // No staging buffer supplied: reduce straight from the input tensor.
    src = inp.data_ptr();
  } else {
    // Stage the input into the registered buffer before reducing.
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(src, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
  }

  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      reducer->allreduce<float>(stream, static_cast<float*>(src), static_cast<float*>(out.data_ptr()), out.numel());
      break;
    }
    case at::ScalarType::Half: {
      reducer->allreduce<half>(stream, static_cast<half*>(src), static_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      reducer->allreduce<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(src), static_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
|
||||
|
||||
// Destroys the CustomAllreduce instance referenced by the opaque handle _fa.
void dispose(fptr_t _fa) {
  auto* instance = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  delete instance;
}
|
||||
|
||||
// Returns sizeof(sglang::Signal) in bytes.
int64_t meta_size() {
  constexpr int64_t kSignalBytes = sizeof(sglang::Signal);
  return kSignalBytes;
}
|
||||
|
||||
/**
 * Registers one set of already-shared peer pointers with the CustomAllreduce
 * behind _fa. fake_ipc_ptrs must contain exactly world_size_ entries, each a
 * pointer encoded as fptr_t.
 */
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == reducer->world_size_);

  // Decode the integer handles into raw pointers, one slot per rank.
  void* peer_ptrs[8];
  void** slot = peer_ptrs;
  for (fptr_t raw : fake_ipc_ptrs) {
    *slot++ = reinterpret_cast<void*>(raw);
  }
  reducer->register_buffer(peer_ptrs);
}
|
||||
|
||||
// Returns the IPC handle bytes and offsets reported by
// CustomAllreduce::get_graph_buffer_ipc_meta(). The handle bytes are widened
// to vector<int64_t> (one element per byte) for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  auto meta = reducer->get_graph_buffer_ipc_meta();
  const std::string& raw_handles = meta.first;

  std::vector<int64_t> widened;
  widened.reserve(raw_handles.size());
  for (char byte : raw_handles) {
    widened.push_back(byte);
  }
  return std::make_tuple(widened, meta.second);
}
|
||||
|
||||
/**
 * Forwards the gathered graph-buffer IPC metadata to
 * CustomAllreduce::register_graph_buffers().
 *
 * handles carries one entry per rank; each entry represents raw handle bytes
 * as a vector<int64_t> (one element per byte) for python binding
 * compatibility, and is narrowed back into a std::string here. offsets is
 * passed through unchanged.
 */
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (const auto& handle : handles) {
    // Each int64_t element holds a single byte of the original handle blob.
    bytes.emplace_back(handle.begin(), handle.end());
  }
  fa->register_graph_buffers(bytes, offsets);
}
|
||||
489
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
Normal file
489
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
Normal file
@@ -0,0 +1,489 @@
|
||||
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace sglang {
|
||||
|
||||
// Number of per-block counter rows in Signal; allreduce() rejects a block
// limit above this value.
constexpr int kMaxBlocks = 36;

// Barrier counter type. The counter may overflow, which is fine because
// unsigned integer overflow is well-defined behavior.
using FlagType = uint32_t;

// Synchronization state used by multi_gpu_barrier: one counter per
// (block, peer) pair.
struct Signal {
  // Counters incremented by the owning GPU on every barrier.
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Counters written by peers. Two sets are needed for two syncs: a peer GPU
  // block may reach the second sync point while the current GPU block has not
  // yet passed the first one, so the peer may write counter+1 while the
  // current GPU is still busy waiting for counter. Alternating between two
  // counter arrays avoids this possibility.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
|
||||
|
||||
// One data pointer per rank (up to 8 ranks) for a single registered buffer.
struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};
|
||||
|
||||
// One Signal pointer per rank (up to 8 ranks).
struct __align__(16) RankSignals {
  Signal* signals[8];
};
|
||||
|
||||
// Fixed-size array similar to std::array, but aligned to alignof(T) * sz so
// the whole array can be moved as a single unit.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  using type = T;
  static constexpr int size = sz;
  T data[sz];
};
|
||||
|
||||
// Packed types used to maximize memory efficiency.
// Goal: generate ld.128 and st.128 instructions.
template <typename T>
struct packed_t {
  // (P)acked type for load/store: 16 bytes worth of T.
  using P = array_t<T, 16 / sizeof(T)>;
  // (A)ccumulator type for reduction: the same element count, held as float.
  using A = array_t<float, 16 / sizeof(T)>;
};
|
||||
|
||||
// Qualifier shorthand for device-only helpers that are force-inlined.
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// Scalar cast functions.
// Widens one half value to float.
DINLINE float upcast_s(half val) {
  const float widened = __half2float(val);
  return widened;
}
|
||||
|
||||
// Narrows one float to T. Only explicit specializations are defined.
template <typename T>
DINLINE T downcast_s(float val);

// float -> half.
template <>
DINLINE half downcast_s(float val) {
  const half narrowed = __float2half(val);
  return narrowed;
}
|
||||
|
||||
// Scalar in-place add functions.
// For some reason the + operator for half and bfloat is disabled when
// compiling with Pytorch, so the intrinsics are called directly.

// a += b for half; returns a reference to a.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}

// a += b for float; returns a reference to a.
DINLINE float& assign_add(float& a, float b) {
  a += b;
  return a;
}
|
||||
|
||||
// bfloat16 overloads of the scalar helpers. Compiled for the host pass and
// for device architectures >= 800 only.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// Widens one bfloat16 value to float.
DINLINE float upcast_s(nv_bfloat16 val) {
  const float widened = __bfloat162float(val);
  return widened;
}

// float -> bfloat16.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  const nv_bfloat16 narrowed = __float2bfloat16(val);
  return narrowed;
}

// a += b for bfloat16; returns a reference to a.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif
|
||||
|
||||
// Element-wise a += b over two packed arrays; returns a reference to a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int lane = 0; lane < N; ++lane) {
    assign_add(a.data[lane], b.data[lane]);
  }
  return a;
}
|
||||
|
||||
// Converts a packed array of T into a packed array of float. A float input is
// returned unchanged; other types are widened element by element via upcast_s.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}
|
||||
|
||||
// Converts a packed array of float into the packed type O. When O holds float
// the input is returned unchanged; otherwise every element is narrowed via
// downcast_s<O::type>.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}
|
||||
|
||||
// Stores flag to *flag_addr with release ordering. Architectures >= 700 use a
// system-scope st.release; everything else issues membar.sys followed by a
// volatile store.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 700
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
|
||||
|
||||
// Loads *flag_addr with acquire ordering. Architectures >= 700 use a
// system-scope ld.acquire; everything else performs a volatile load followed
// by membar.gl.
// NOTE(review): the fallback fence here is membar.gl while st_flag_release's
// fallback uses membar.sys -- confirm the narrower scope is intended.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType observed;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 700
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(observed) : "l"(flag_addr));
#else
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(observed) : "l"(flag_addr));
#endif
  return observed;
}
|
||||
|
||||
// Volatile 32-bit global store of flag to *flag_addr (no fence).
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;"
               :
               : "r"(flag), "l"(flag_addr));
}
|
||||
|
||||
// Volatile 32-bit global load from *flag_addr (no fence).
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType observed;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(observed)
               : "l"(flag_addr));
  return observed;
}
|
||||
|
||||
// Barrier across ngpus ranks, executed independently by every thread block.
//
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, release-acquire
// semantics are used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
  static_assert(!(is_start && need_fence));  // Start barrier shouldn't need fence.
  if constexpr (!is_start) __syncthreads();

  const auto peer = threadIdx.x;
  const auto block = blockIdx.x;
  if (peer < ngpus) {
    // Increment the counter. Technically only one counter is needed, but
    // using one per (block, peer) avoids sharing the counter via smem.
    const FlagType ticket = ++self_sg->self_counter[block][peer];
    const auto parity = ticket % 2;
    // Publish the expected counter value to the peer, then wait until the
    // peer has written the same value into our own slot.
    FlagType* remote_slot = &sg.signals[peer]->peer_counter[parity][block][rank];
    FlagType* local_slot = &self_sg->peer_counter[parity][block][peer];
    if constexpr (need_fence) {
      st_flag_release(remote_slot, ticket);
      while (ld_flag_acquire(local_slot) != ticket) {
      }
    } else {
      st_flag_volatile(remote_slot, ticket);
      while (ld_flag_volatile(local_slot) != ticket) {
      }
    }
  }

  if constexpr (is_start || need_fence) __syncthreads();
}
|
||||
|
||||
// Sums packed element idx across the ngpus source buffers in ptrs. The sum is
// accumulated in the accumulator type A and converted back to P at the end.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A acc = upcast(ptrs[0][idx]);
#pragma unroll
  for (int r = 1; r < ngpus; ++r) {
    packed_assign_add(acc, upcast(ptrs[r][idx]));
  }
  return downcast<P>(acc);
}
|
||||
|
||||
// One-stage allreduce kernel: every rank reads packed element idx from all
// ngpus buffers listed in *_dp and writes the sum to result. size is the
// number of packed (P) elements. The reduction is bracketed by a start
// barrier and an end barrier.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

  // Note: the addresses are not reordered, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results.
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);

  // Do the actual reduction.
  const P** srcs = (const P**)&dp.ptrs[0];
  P* dst = (P*)result;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
    dst[idx] = packed_reduce<P, ngpus, A>(srcs, idx);
  }

  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
|
||||
|
||||
// Returns the address immediately after the Signal object at sg, viewed as P*.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  Signal* past_signal = sg + 1;
  return (P*)past_signal;
}
|
||||
|
||||
// Two-stage allreduce kernel. size is the number of packed (P) elements.
// Stage 1 (reduce scatter): each rank reduces its own slice of the input
// into the temporary buffer that follows its Signal. Stage 2 (allgather):
// each rank copies every rank's reduced slice into result.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  // Slice owned by this rank; the last rank also takes the remainder.
  const int part = size / ngpus;
  const int start = rank * part;
  const int end = (rank == ngpus - 1) ? size : start + part;
  const int largest_part = part + size % ngpus;

  // Rotate the rank order so that index 0 always refers to this rank.
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    const int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  P* const tmp_out = tmps[0];

  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }

  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between the two
  // stages, because visibility across devices is only guaranteed between
  // threads that have the same tid. If thread i computes the sum of start + i
  // in the first stage, then thread i also gathers start + i from all ranks.
  P* const out = (P*)result;
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      const int src_rank = (rank + i) % ngpus;
      // Only the last rank's slice extends past `part` elements.
      const bool in_slice = (src_rank == ngpus - 1) || (idx < part);
      if (in_slice) {
        out[src_rank * part + idx] = tmps[i][idx];
      }
    }
  }
}
|
||||
|
||||
// Byte-array view of a cudaIpcMemHandle_t, usable as an ordered map key.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(cudaIpcMemHandle_t) == sizeof(IPC_KEY), "IPC_KEY size must match cudaIpcMemHandle_t");
static_assert(alignof(cudaIpcMemHandle_t) == alignof(IPC_KEY), "IPC_KEY alignment must match cudaIpcMemHandle_t");
|
||||
|
||||
// Host-side driver for the custom allreduce kernels: tracks the signal
// buffers of all ranks, the registered data buffers and the IPC handles
// opened on behalf of CUDA-graph-captured buffers.
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  RankSignals sg_;
  // Stores a map from a local buffer pointer to the device-side RankData that
  // holds the matching peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph capture
  // time. Therefore, during capture, we increment the rank data pointer and use
  // that as the argument to the kernel. The kernel arguments are stored in
  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
  // memory pointed to by the pointers in graph_unreg_buffers_ when
  // the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each addresses used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second section
   * is for storing the intermediate results required by some allreduce algos.
   *
   * rank_data / rank_data_sz describe the device memory used to store
   * RankData entries (rank_data_sz is in bytes).
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(
      Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
      : rank_(rank),
        world_size_(world_size),
        full_nvlink_(full_nvlink),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  /**
   * Returns the pointer mapped for the given IPC handle, opening the handle on
   * first use and caching the result in ipc_handles_.
   *
   * The cache entry is only created after the handle has been opened, so a
   * failed open never leaves a null pointer behind for later lookups (or for
   * the destructor to close).
   */
  char* open_ipc_handle(const void* ipc_handle) {
    const IPC_KEY key = *reinterpret_cast<const IPC_KEY*>(ipc_handle);
    auto it = ipc_handles_.find(key);
    if (it == ipc_handles_.end()) {
      char* ipc_ptr = nullptr;
      CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
      it = ipc_handles_.emplace(key, ipc_ptr).first;
    }
    return it->second;
  }

  /**
   * Produces the IPC metadata for every buffer recorded during graph capture.
   *
   * Returns a pair of (concatenated cudaIpcMemHandle_t bytes, per-buffer byte
   * offset of the recorded pointer from the base of its allocation).
   * Throws std::runtime_error if the allocation base cannot be queried.
   */
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  // Throws std::runtime_error if fewer than `num` RankData entries remain
  // between d_rank_data_base_ and d_rank_data_end_.
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers. ptrs must hold world_size_ entries;
   * they are copied into the next free RankData slot on the device and the
   * slot is indexed by this rank's own pointer (ptrs[rank_]).
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  //
  // handles[j] holds rank j's concatenated cudaIpcMemHandle_t bytes and
  // offsets[j] the matching per-buffer offsets. Throws std::runtime_error when
  // the metadata is too short for the buffers recorded on this rank, instead
  // of reading past the end of the inputs.
  void
  register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    check_rank_data_capacity(num_buffers);

    // Validate every entry that the loop below dereferences.
    const size_t num_ranks = static_cast<size_t>(world_size_);
    if (handles.size() < num_ranks || offsets.size() < num_ranks)
      throw std::runtime_error(
          "register_graph_buffers expects handles and offsets from all " + std::to_string(world_size_) + " ranks");
    for (int j = 0; j < world_size_; j++) {
      if (j == rank_) continue;  // this rank's own entries are never read
      if (handles[j].size() < num_buffers * handle_sz || offsets[j].size() < num_buffers)
        throw std::runtime_error(
            "graph buffer metadata from rank " + std::to_string(j) + " does not cover " +
            std::to_string(num_buffers) + " buffers");
    }

    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * handle_sz]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CHECK_CUDA_SUCCESS(
        cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * size is the number of T elements and must be a multiple of the packed
   * width. While the stream is being captured, input is recorded in
   * graph_unreg_buffers_ and the kernel is pointed at the RankData slot that
   * register_graph_buffers() fills in later; otherwise input must be present
   * in buffers_.
   *
   * Block and grid default configs are results after careful grid search. Using
   * 36 blocks give the best or close to the best runtime on the devices I
   * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
   * take a small amount of SMs. Not quite sure the underlying reason, but my
   * guess is that too many SMs will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }

    // From here on, size counts packed (P) elements.
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
    // TODO(hanzhi713): Threshold is different for A100 and H100.
    // Add per device threshold.
    // NOTE: when world_size_ > 2 and full_nvlink_ is false, no kernel is
    // launched by the cases below.
#define REDUCE_CASE(ngpus)                                                                        \
  case ngpus: {                                                                                   \
    if (blocks == 0) {                                                                            \
      /* Empty input: nothing to reduce, and a zero-block grid is an invalid launch. */           \
    } else if (world_size_ == 2) {                                                                \
      KL(ngpus, cross_device_reduce_1stage);                                                      \
    } else if (full_nvlink_) {                                                                    \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);                                                    \
      } else {                                                                                    \
        KL(ngpus, cross_device_reduce_2stage);                                                    \
      }                                                                                           \
    }                                                                                             \
    break;                                                                                        \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
    // Kernel launches do not report errors themselves; surface launch
    // failures here instead of at some unrelated later CUDA call.
    CHECK_CUDA_SUCCESS(cudaGetLastError());
  }

  // Closes every IPC handle opened through open_ipc_handle().
  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
    }
  }
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void sglang::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace sglang
|
||||
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
@@ -0,0 +1,180 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/hip/Exceptions.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
|
||||
// Opaque pointer handle passed through the op interface as a 64-bit integer.
// This alias must stay identical to the fptr_t declared in ops.h.
typedef int64_t fptr_t;
static_assert(sizeof(fptr_t) == sizeof(void*), "fptr_t must be exactly pointer-sized");
|
||||
|
||||
/**
 * Builds a sglang::CustomAllreduce from per-rank IPC handles and returns it as
 * an opaque handle.
 *
 * handles[i] carries the raw bytes of rank i's hipIpcMemHandle_t and
 * offsets[i] the matching offset; the length of offsets defines the world
 * size. meta and rank_data supply the local signal buffer and the rank-data
 * storage handed to the constructor.
 *
 * Throws std::invalid_argument when the world size is larger than 8 or odd,
 * when handles and offsets differ in length, when rank is outside
 * [0, world size), or when a handle blob is shorter than a hipIpcMemHandle_t
 * (which would otherwise be read out of bounds).
 */
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  hipIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    // Guard the fixed-size copy below against a truncated handle blob.
    if (handles[i].size() < sizeof(hipIpcMemHandle_t))
      throw std::invalid_argument("IPC handle is shorter than hipIpcMemHandle_t");
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
  }
  return (fptr_t) new sglang::CustomAllreduce(
      reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
/**
 * Dispatches CustomAllreduce::allreduce on out's dtype, reducing from
 * inp.data_ptr() into out.data_ptr() on the given stream. out must be weakly
 * contiguous. Only float32, float16 and bfloat16 are dispatched; any other
 * dtype throws std::runtime_error.
 */
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 hipStream_t stream) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      reducer->allreduce<float>(stream, static_cast<float*>(inp.data_ptr()),
                                static_cast<float*>(out.data_ptr()),
                                out.numel());
      break;
    }
    case at::ScalarType::Half: {
      reducer->allreduce<half>(stream, static_cast<half*>(inp.data_ptr()),
                               static_cast<half*>(out.data_ptr()),
                               out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      reducer->allreduce<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(inp.data_ptr()),
          static_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
|
||||
|
||||
/**
 * Allreduce for an input tensor that is used directly as the kernel input
 * (no staging copy). inp and out must have the same dtype and element count.
 *
 * inp must be weakly contiguous: its data pointer is handed to the kernel as
 * a flat buffer of numel() elements, so any other layout would be reduced
 * incorrectly.
 */
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(inp));
  _all_reduce(_fa, inp, out, stream);
}
|
||||
|
||||
/**
 * Allreduce for an input tensor that is first copied into reg_buffer on the
 * current stream; the copy is then reduced into out. inp and out must have
 * the same dtype and element count, and reg_buffer must be at least as large
 * as inp in bytes.
 *
 * inp must be weakly contiguous: it is copied as one flat span of
 * numel() * element_size() bytes starting at its data pointer.
 */
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out) {
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                               input_size, hipMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}
|
||||
|
||||
// Destroys the CustomAllreduce instance referenced by the opaque handle _fa.
void dispose(fptr_t _fa) {
  delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
}
|
||||
|
||||
// Returns sizeof(sglang::Signal) in bytes.
int64_t meta_size() {
  constexpr int64_t kSignalBytes = sizeof(sglang::Signal);
  return kSignalBytes;
}
|
||||
|
||||
// Registers tensor t's data pointer, together with the per-rank IPC handles
// and offsets, with the CustomAllreduce behind _fa.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  void* self_ptr = t.data_ptr();
  reducer->register_buffer(handles, offsets, self_ptr);
}
|
||||
|
||||
// Returns the graph-buffer IPC metadata of the CustomAllreduce behind _fa:
// the handle bytes copied into a 1-D uint8 CPU tensor, plus the offsets.
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto* reducer = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
  auto [handle_bytes, offsets] = reducer->get_graph_buffer_ipc_meta();

  const auto cpu_bytes =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  const auto num_bytes = static_cast<int64_t>(handle_bytes.size());
  auto handles = torch::empty({num_bytes}, cpu_bytes);
  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
  return {handles, std::move(offsets)};
}
|
||||
|
||||
// Forwards the gathered IPC handles and offsets to
// CustomAllreduce::register_graph_buffers on the object behind `_fa`.
void register_graph_buffers(
    fptr_t _fa,
    const std::vector<std::string>& handles,
    const std::vector<std::vector<int64_t>>& offsets) {
  reinterpret_cast<sglang::CustomAllreduce*>(_fa)->register_graph_buffers(handles, offsets);
}
|
||||
|
||||
// Deleter used for buffers wrapped by allocate_meta_buffer: releases `buffer`
// with hipFree (CUDACHECK terminates the process on failure).
void free_meta_buffer(void* buffer) {
  CUDACHECK(hipFree(buffer));
}
|
||||
|
||||
// Returns the IPC memory handle for `inp`'s data pointer as a 1-D CPU uint8
// tensor of sizeof(hipIpcMemHandle_t) bytes.
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
  const auto handle_len = static_cast<int64_t>(sizeof(hipIpcMemHandle_t));
  auto cpu_bytes = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto data_handle = torch::empty({handle_len}, cpu_bytes);

  // The handle is written directly into the tensor's storage.
  auto* raw_handle = (hipIpcMemHandle_t*)data_handle.data_ptr();
  CUDACHECK(hipIpcGetMemHandle(raw_handle, inp.data_ptr()));
  return data_handle;
}
|
||||
|
||||
// Allocates `size` bytes of uncached device memory on the current device,
// zero-fills it, and wraps it in an int8 tensor whose deleter is
// free_meta_buffer.
//
// The thread's stream-capture mode is switched to relaxed for the duration of
// the allocation and restored afterwards. If any step fails, the capture mode
// is restored and the allocation released before the error propagates
// (previously the thread stayed in relaxed mode and the buffer leaked).
torch::Tensor allocate_meta_buffer(int64_t size) {
  auto device_index = c10::hip::current_device();
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  void* buffer = nullptr;
  hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

  // After this call `mode` holds the previous capture mode.
  AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
  try {
    AT_CUDA_CHECK(hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
    AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
    AT_CUDA_CHECK(hipStreamSynchronize(stream));
  } catch (...) {
    // Best-effort cleanup; the original error is the one worth reporting.
    if (buffer != nullptr) {
      (void)hipFree(buffer);
    }
    (void)hipThreadExchangeStreamCaptureMode(&mode);
    throw;
  }
  // Swap the previous capture mode back in.
  AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));

  auto options = torch::TensorOptions().dtype(torch::kI8).device(torch::kCUDA, device_index);
  return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
|
||||
|
||||
// Returns the PCI bus id string reported by hipDeviceGetPCIBusId for device
// `dev`, as raw bytes with the final terminator slot removed.
std::vector<uint8_t> get_device_bdf(int dev) {
  // The literal only determines the buffer length (characters + terminator).
  constexpr char kBusIdTemplate[] = "0000:00:00.0";
  constexpr size_t kBufLen = sizeof(kBusIdTemplate);

  std::vector<uint8_t> bdf(kBufLen, 0);
  CUDACHECK(hipDeviceGetPCIBusId(reinterpret_cast<char*>(bdf.data()), kBufLen, dev));
  bdf.pop_back();  // drop the last byte, reserved for the string terminator
  return bdf;
}
|
||||
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
@@ -0,0 +1,582 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// Evaluates the HIP runtime call `cmd` exactly once. On any status other than
// hipSuccess it prints the file, line and HIP error string, then terminates
// the process with EXIT_FAILURE.
// The argument is parenthesized so that any expression can be passed safely.
#define CUDACHECK(cmd) \
  do { \
    hipError_t e = (cmd); \
    if (e != hipSuccess) { \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)
|
||||
|
||||
namespace sglang {
|
||||
|
||||
// Upper bound on the number of thread blocks an all-reduce launch may use;
// it sizes the per-block rows of Signal below.
constexpr int kMaxBlocks = 64;

// Cross-rank synchronization flags, indexed as [block][rank].
// Note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links.
struct Signal {
  // Written by peers in start_sync, polled locally.
  alignas(128) uint32_t start[kMaxBlocks][8];
  // Written by peers in end_sync, polled locally.
  alignas(128) uint32_t end[kMaxBlocks][8];
  // Incremental flag per block; read and advanced by the ROCm sync path.
  alignas(128) uint32_t _flag[kMaxBlocks];
};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
struct __align__(16) RankData {
|
||||
const void* ptrs[8];
|
||||
};
|
||||
#else
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
#endif
|
||||
|
||||
// One Signal pointer per rank (up to 8). The pointee is volatile-qualified in
// the non-ROCm build only; the ROCm path uses explicit atomic builtins instead.
struct __align__(16) RankSignals {
#ifdef USE_ROCM
  Signal* signals[8];
#else
  volatile Signal* signals[8];
#endif
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) {
|
||||
return a += b;
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__hip_atomic_load(
|
||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__HIP_MEMORY_SCOPE_AGENT) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
#ifdef USE_ROCM
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
#else
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
#endif
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(
|
||||
Signal* meta,
|
||||
void* rank_data,
|
||||
size_t rank_data_sz,
|
||||
const hipIpcMemHandle_t* handles,
|
||||
const std::vector<int64_t>& offsets,
|
||||
int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal* rank_sg;
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal*)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(hipIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(hipIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (hipPointerGetAttribute(
|
||||
&base_ptr,
|
||||
#ifdef USE_ROCM
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#else
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#endif
|
||||
(hipDeviceptr_t)ptr) != hipSuccess)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void
|
||||
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
hipStream_t stream,
|
||||
T* input,
|
||||
T* output,
|
||||
int size,
|
||||
#ifndef USE_ROCM
|
||||
int threads = 512,
|
||||
int block_limit = 36){
|
||||
#else
|
||||
int threads = 512,
|
||||
int block_limit = 16) {
|
||||
#endif
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error(
|
||||
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
hipStreamCaptureStatus status;
|
||||
CUDACHECK(hipStreamIsCapturing(stream, &status));
|
||||
if (status == hipStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = ::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
hipLaunchKernelGGL( \
|
||||
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
}; // namespace sglang
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace sglang
|
||||
140
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu
Normal file
140
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu
Normal file
@@ -0,0 +1,140 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
// Selects which context implementation a MscclContext dispatches to.
// The numeric values are passed in from callers as plain integers.
enum MscclContextSelection {
  MSCCL1NODELL = 1,  // use sglang::Msccl1NodeLLcontext
  MSCCL2NODELL = 2,  // use sglang::Msccl2NodeLLcontext
};
|
||||
|
||||
class MscclContext {
|
||||
public:
|
||||
MscclContextSelection selection_;
|
||||
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
|
||||
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
|
||||
MscclContext(MscclContextSelection selection) : selection_(selection) {}
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
if (selection_ == MSCCL1NODELL) {
|
||||
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
} else if (selection_ == MSCCL2NODELL) {
|
||||
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
// Copies the bytes of an mscclpp unique id into a new 1-D CPU byte tensor.
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
  const auto nbytes = unique_id.size();
  auto cpu_bytes = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
  auto tensor = torch::empty({static_cast<int64_t>(nbytes)}, cpu_bytes);
  std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), nbytes);
  return tensor;
}
|
||||
|
||||
// Function to convert vector of int32_t back to array of uint8_t
|
||||
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
|
||||
mscclpp::UniqueId unique_id;
|
||||
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
|
||||
return unique_id;
|
||||
}
|
||||
|
||||
// Creates a fresh TcpBootstrap unique id and returns it as a CPU byte tensor.
torch::Tensor mscclpp_generate_unique_id() {
  return _unique_id2tensor(mscclpp::TcpBootstrap::createUniqueId());
}
|
||||
|
||||
// Creates an MscclContext of the requested kind and returns it as an opaque
// integer handle. Ownership of the context passes to the caller.
//
// unique_id          - byte tensor produced by mscclpp_generate_unique_id.
// scratch            - scratch tensor handed to the context (pointer + bytes).
// put_buffer         - extra buffer, used only by the 2-node context.
// context_selection  - a MscclContextSelection value.
// The remaining arguments are forwarded unchanged to the context constructor.
//
// Throws std::runtime_error("invalid context selection") for an unknown
// selection. The context is held by a unique_ptr until it is fully built, so
// nothing leaks if that error or a constructor failure occurs.
fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection) {
  auto context = std::make_unique<MscclContext>(static_cast<MscclContextSelection>(context_selection));
  mscclpp::UniqueId uid = _tensor2unique_id(unique_id);

  if (context_selection == MSCCL1NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    context->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
  } else if (context_selection == MSCCL2NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
    const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
    context->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        uid,
        rank,
        world_size,
        scratch_ptr,
        scratch_bytes,
        put_buffer_ptr,
        put_buffer_bytes,
        nranks_per_node,
        rank_to_node,
        rank_to_ib);
  } else {
    throw std::runtime_error("invalid context selection");
  }
  // Hand the fully constructed context over to the caller.
  return (fptr_t)context.release();
}
|
||||
|
||||
// True when `t` is contiguous, or when the bytes from its storage offset to
// the end of its storage equal exactly numel * element_size.
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
  if (t.is_contiguous()) {
    return true;
  }
  const auto bytes_after_offset = t.storage().nbytes() - t.storage_offset() * t.element_size();
  const auto bytes_needed = t.numel() * t.element_size();
  return bytes_after_offset == bytes_needed;
}
|
||||
// Runs the mscclpp all-reduce from `inp` into `out` on the current CUDA stream
// of `inp`'s device.
//
// _context - handle returned by mscclpp_init_context.
// inp/out  - same dtype and element count, both weakly contiguous;
//            float32, float16 and bfloat16 are dispatched.
// nthreads / nblocks - forwarded to the context as thread and block limits.
//
// Throws std::runtime_error for any other dtype.
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
  MscclContext* context = reinterpret_cast<MscclContext*>(_context);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
  TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));

  // The element type is deduced from the typed pointers, so each case only
  // has to pick the pointer type.
  const auto run = [&](auto* typed_in, auto* typed_out) {
    context->allreduce(stream, typed_in, typed_out, inp.numel(), nthreads, nblocks);
  };

  switch (out.scalar_type()) {
    case at::ScalarType::Float:
      run(reinterpret_cast<float*>(inp.data_ptr()), reinterpret_cast<float*>(out.data_ptr()));
      break;
    case at::ScalarType::Half:
      run(reinterpret_cast<half*>(inp.data_ptr()), reinterpret_cast<half*>(out.data_ptr()));
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      run(reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()), reinterpret_cast<__nv_bfloat16*>(out.data_ptr()));
      break;
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
|
||||
779
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh
Normal file
779
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh
Normal file
@@ -0,0 +1,779 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
#pragma once
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#else
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
// comment this for test_mscclpp_allreduce.cu
|
||||
#include "utils.h"
|
||||
|
||||
namespace sglang {
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
|
||||
|
||||
template <typename To, typename From>
|
||||
__forceinline__ __device__ To bit_cast(const From& src) {
|
||||
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
|
||||
|
||||
union {
|
||||
From f;
|
||||
To t;
|
||||
} u;
|
||||
u.f = src;
|
||||
return u.t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T add_elements(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
|
||||
int4 ret;
|
||||
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
|
||||
uint2 ret;
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
// uint2 entry point: by default each 32-bit lane is treated as one T.
// The __half and __nv_bfloat16 specializations route to the packed pair types.
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
  const uint2 sum = add_vectors_helper<T>(a, b);
  return sum;
}
|
||||
|
||||
// bfloat16 over uint2: each 32-bit lane is handled as one __nv_bfloat162.
// Only compiled for the host pass or for device architectures >= 800.
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
template <>
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
  const uint2 sum = add_vectors_helper<__nv_bfloat162>(a, b);
  return sum;
}
#endif
|
||||
|
||||
// half over uint2: each 32-bit lane is handled as one __half2.
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
  const uint2 sum = add_vectors_helper<__half2>(a, b);
  return sum;
}
|
||||
|
||||
// Adds two 32-bit words: both are reinterpreted as T via bit_cast, summed
// with add_elements, and the bits of the sum are returned as an int.
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
  const T lhs = bit_cast<T, int>(a);
  const T rhs = bit_cast<T, int>(b);
  return bit_cast<int, T>(add_elements(lhs, rhs));
}
|
||||
|
||||
// Single-word entry point: by default the 32-bit word is treated as one T.
// The __half and __nv_bfloat16 specializations route to the packed pair types.
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
  const int sum = add_vectors_helper<T>(a, b);
  return sum;
}
|
||||
|
||||
// bfloat16 over a single word: the word is handled as one __nv_bfloat162.
// Only compiled for the host pass or for device architectures >= 800.
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
template <>
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
  const int sum = add_vectors_helper<__nv_bfloat162>(a, b);
  return sum;
}
#endif
|
||||
|
||||
// half over a single word: the word is handled as one __half2.
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
  const int sum = add_vectors_helper<__half2>(a, b);
  return sum;
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_1node using LLPacket, origin allreduce2
|
||||
// -------------------------------------------------------
|
||||
|
||||
// Device-side round counter shared by the LL all-reduce kernels below.
// Each kernel reads it once as the LL packet flag, uses its low bit to pick
// which half of the double-buffered scratch areas to use, and thread 0 of
// block 0 increments it at the end of the kernel. Starts at 1.
__device__ uint64_t globalFlag = 1;
|
||||
|
||||
// All-reduce over the ranks reachable through `memChans`, using LL packets.
//
// Launch layout: 1-D grid, at most 1024 threads per block (__launch_bounds__).
// The grid is divided into gridDim.x / (worldSize - 1) blocks per peer; every
// block serves the peer selected by blockIdx.x / blocksPerPeer.
// NOTE(review): worldSize == 1 divides by zero, and a gridDim.x that is not a
// multiple of (worldSize - 1) yields a peer slot >= worldSize - 1 for the
// trailing blocks -- confirm callers only use compatible launch configurations.
//
// Parameters:
//   memChans   - one memory-channel handle per peer, indexed by peer slot.
//   buff       - local input, addressed in 32-bit words.
//   scratch    - packet scratch area; the low bit of the flag selects which
//                half of the input region and of the result region is used.
//   resultBuff - output buffer, addressed in 32-bit words.
//   rank       - this rank's index; worldSize - number of participating ranks.
//   nelems     - element count in units of TYPE (converted to words below).
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
    mscclpp::MemoryChannelDeviceHandle* memChans,
    TYPE* buff,
    TYPE* scratch,
    void* resultBuff,
    int rank,
    int worldSize,
    size_t nelems) {
  constexpr size_t kPktBytes = sizeof(mscclpp::LLPacket);

  // From here on `nelems` counts 32-bit words instead of TYPE elements.
  nelems = nelems / (sizeof(int) / sizeof(TYPE));
  const int peerCount = worldSize - 1;
  const size_t totalPkts = nelems / 2;  // one LL packet carries two words
  const int wordsPerRank = nelems / worldSize;
  const int pktsPerRank = wordsPerRank / 2;

  // Packet flag for this round (globalFlag starts at 1).
  const uint32_t flag = (uint32_t)globalFlag;
  const bool oddRound = (flag & 1) != 0;

  // Map this block onto a peer and onto a position within that peer's blocks.
  const int blocksPerPeer = gridDim.x / peerCount;
  const int blockInPeer = blockIdx.x % blocksPerPeer;
  const int peerSlot = blockIdx.x / blocksPerPeer;
  const int peerRank = peerSlot < rank ? peerSlot : peerSlot + 1;
  mscclpp::MemoryChannelDeviceHandle peerChan = memChans[peerSlot];
  const int peerTid = threadIdx.x + blockInPeer * blockDim.x;
  const int peerStride = blockDim.x * blocksPerPeer;
  const int gridStride = blockDim.x * gridDim.x;

  // Double buffering: odd and even rounds use different halves.
  const size_t inputHalfOffset = oddRound ? 0 : totalPkts * kPktBytes;
  char* const inputHalf = reinterpret_cast<char*>(scratch) + inputHalfOffset;
  const size_t mySlotOffset = inputHalfOffset + rank * pktsPerRank * kPktBytes;
  const size_t resultHalfOffset = (oddRound ? 2 : 3) * totalPkts * kPktBytes;
  const size_t peerChunkOffset = peerRank * wordsPerRank * sizeof(int);
  uint2* const myInput = reinterpret_cast<uint2*>(reinterpret_cast<char*>(buff) + rank * wordsPerRank * sizeof(int));
  uint2* const myOutput =
      reinterpret_cast<uint2*>(reinterpret_cast<char*>(resultBuff) + rank * wordsPerRank * sizeof(int));

  // Step 1: send the chunk owned by `peerRank` through that peer's channel.
  peerChan.putPackets(mySlotOffset, peerChunkOffset, wordsPerRank * sizeof(int), peerTid, peerStride, flag);

  // Step 2: reduce the packets received from every peer with the local chunk,
  // store the result locally, and write it into every peer's result area.
  for (int pkt = threadIdx.x + blockIdx.x * blockDim.x; pkt < pktsPerRank; pkt += gridStride) {
    uint2 acc = make_uint2(0, 0);
    for (int slot = 0; slot < peerCount; slot++) {
      const int srcRank = slot < rank ? slot : slot + 1;
      mscclpp::LLPacket* incoming = reinterpret_cast<mscclpp::LLPacket*>(inputHalf) + srcRank * pktsPerRank;
      uint2 received = incoming[pkt].read(flag);
      acc = add_vectors<TYPE>(received, acc);
    }
    acc = add_vectors<TYPE>(acc, myInput[pkt]);
    myOutput[pkt] = acc;

    mscclpp::LLPacket reduced;
    reduced.data1 = acc.x;
    reduced.flag1 = flag;
    reduced.data2 = acc.y;
    reduced.flag2 = flag;
    const size_t pktIndex = resultHalfOffset / kPktBytes + (pkt + rank * pktsPerRank);
    for (int slot = 0; slot < peerCount; slot++) {
      memChans[slot].write(pktIndex, reduced);
    }
  }

  // Step 3: copy the chunk reduced by `peerRank` out of the result area.
  mscclpp::LLPacket* const reducedPkts =
      reinterpret_cast<mscclpp::LLPacket*>(reinterpret_cast<char*>(scratch) + resultHalfOffset);
  const int peerPktBase = peerRank * pktsPerRank;
  uint2* const peerOutput =
      reinterpret_cast<uint2*>(reinterpret_cast<char*>(resultBuff) + peerRank * wordsPerRank * sizeof(int));
  for (int pkt = threadIdx.x + blockInPeer * blockDim.x; pkt < pktsPerRank; pkt += peerStride) {
    uint2 value = reducedPkts[pkt + peerPktBase].read(flag);
    peerOutput[pkt].x = value.x;
    peerOutput[pkt].y = value.y;
  }

  // A single thread advances the round counter for the next launch.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    globalFlag += 1;
  }
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_2node using LLPacket, origin allreduce5
|
||||
// -------------------------------------------------------
|
||||
|
||||
// All-reduce that combines an intra-node exchange over memory channels with
// one cross-node transfer over a port channel, using LL packets.
//
// Launch layout: 1-D grid, at most 1024 threads per block (__launch_bounds__).
// The grid is divided into gridDim.x / (nRanksPerNode - 1) blocks per local
// peer. All blocks must reach the deviceSyncer barrier between steps 2 and 3.
// NOTE(review): nRanksPerNode == 1 divides by zero when computing the blocks
// per peer -- confirm callers never launch with a single rank per node.
//
// Parameters:
//   memChans      - memory-channel handles of the same-node peers.
//   portChans     - port-channel handles; the entry at this rank's local index
//                   is used for the cross-node put.
//   buff          - local input, addressed in 32-bit words.
//   scratch       - packet scratch area (input and result regions, each with
//                   two halves selected by the low bit of the flag).
//   putBuff       - staging area holding the locally reduced packets that the
//                   port channel sends.
//   resultBuff    - output buffer.
//   rank, nRanksPerNode - global rank and ranks per node; the local index is
//                   rank % nRanksPerNode.
//   worldSize     - not used by the kernel body.
//   nelems        - element count in units of TYPE (converted to words below).
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
    mscclpp::MemoryChannelDeviceHandle* memChans,
    mscclpp::PortChannelDeviceHandle* portChans,
    TYPE* buff,
    TYPE* scratch,
    TYPE* putBuff,
    TYPE* resultBuff,
    int rank,
    int nRanksPerNode,
    int worldSize,
    size_t nelems) {
  constexpr size_t kPktBytes = sizeof(mscclpp::LLPacket);

  // From here on `nelems` counts 32-bit words instead of TYPE elements.
  nelems = nelems / (sizeof(int) / sizeof(TYPE));
  const int localPeerCount = nRanksPerNode - 1;
  const int totalPkts = nelems / 2;  // one LL packet carries two words
  const int wordsPerLocalRank = nelems / nRanksPerNode;
  const int pktsPerLocalRank = wordsPerLocalRank / 2;
  const int localRank = rank % nRanksPerNode;

  // Packet flag for this round (globalFlag starts at 1).
  const uint32_t flag = (uint32_t)globalFlag;
  const bool oddRound = (flag & 1) != 0;

  // Map this block onto a local peer and a position within that peer's blocks.
  const int blocksPerPeer = gridDim.x / localPeerCount;
  const int blockInPeer = blockIdx.x % blocksPerPeer;
  const int peerSlot = blockIdx.x / blocksPerPeer;
  const int peerLocalRank = peerSlot < localRank ? peerSlot : peerSlot + 1;
  mscclpp::MemoryChannelDeviceHandle peerChan = memChans[peerSlot];
  mscclpp::PortChannelDeviceHandle remoteChan = portChans[localRank];
  const int peerTid = threadIdx.x + blockInPeer * blockDim.x;
  const int peerStride = blockDim.x * blocksPerPeer;
  const int gridStride = blockDim.x * gridDim.x;

  // Double buffering: odd and even rounds use different halves.
  const size_t inputHalfOffset = oddRound ? 0 : totalPkts * kPktBytes;
  const size_t putHalfOffset = oddRound ? 0 : pktsPerLocalRank * kPktBytes;
  char* const inputHalf = reinterpret_cast<char*>(scratch) + inputHalfOffset;
  const size_t mySlotOffset = inputHalfOffset + localRank * pktsPerLocalRank * kPktBytes;
  const size_t resultHalfOffset = (oddRound ? 2 : 3) * totalPkts * kPktBytes;
  const size_t peerChunkOffset = peerLocalRank * wordsPerLocalRank * sizeof(int);
  uint2* const myInput =
      reinterpret_cast<uint2*>(reinterpret_cast<char*>(buff) + localRank * wordsPerLocalRank * sizeof(int));
  uint2* const myOutput =
      reinterpret_cast<uint2*>(reinterpret_cast<char*>(resultBuff) + localRank * wordsPerLocalRank * sizeof(int));

  // Step 1: send the chunk owned by `peerLocalRank` through that peer's channel.
  if (nRanksPerNode > 1) {
    peerChan.putPackets(mySlotOffset, peerChunkOffset, wordsPerLocalRank * sizeof(int), peerTid, peerStride, flag);
  }

  // Step 2: reduce the packets received from the local peers with the local
  // chunk; keep the partial sum in resultBuff and stage it as packets in putBuff.
  mscclpp::LLPacket* const staged =
      reinterpret_cast<mscclpp::LLPacket*>(reinterpret_cast<char*>(putBuff) + putHalfOffset);
  for (int pkt = threadIdx.x + blockIdx.x * blockDim.x; pkt < pktsPerLocalRank; pkt += gridStride) {
    uint2 acc = make_uint2(0, 0);
    for (int slot = 0; slot < localPeerCount; slot++) {
      const int srcRank = slot < localRank ? slot : slot + 1;
      mscclpp::LLPacket* incoming = reinterpret_cast<mscclpp::LLPacket*>(inputHalf) + srcRank * pktsPerLocalRank;
      uint2 received = incoming[pkt].read(flag);
      acc = add_vectors<TYPE>(received, acc);
    }
    acc = add_vectors<TYPE>(acc, myInput[pkt]);
    staged[pkt].write(acc.x, acc.y, flag);
    myOutput[pkt] = acc;
  }
  deviceSyncer.sync(gridDim.x);

  // Step 3: one thread hands the staged packets to the port channel and
  // flushes it whenever the low six bits of the flag are zero.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    remoteChan.put(mySlotOffset, putHalfOffset, pktsPerLocalRank * kPktBytes);
    if ((flag & 63) == 0) {
      remoteChan.flush();
    }
  }

  // Step 4: read the packets arriving in this rank's own scratch slot, add
  // them to the partial sum, and write the total to every local peer.
  mscclpp::LLPacket* pktView = reinterpret_cast<mscclpp::LLPacket*>(inputHalf) + localRank * pktsPerLocalRank;
  for (int pkt = threadIdx.x + blockIdx.x * blockDim.x; pkt < pktsPerLocalRank; pkt += gridStride) {
    uint2 total = myOutput[pkt];
    uint2 remotePart = pktView[pkt].read(flag);
    total = add_vectors<TYPE>(total, remotePart);

    mscclpp::LLPacket reduced;
    reduced.data1 = total.x;
    reduced.flag1 = flag;
    reduced.data2 = total.y;
    reduced.flag2 = flag;
    const size_t pktIndex = resultHalfOffset / kPktBytes + (pkt + localRank * pktsPerLocalRank);
    for (int slot = 0; slot < localPeerCount; slot++) {
      memChans[slot].write(pktIndex, reduced);
    }
    myOutput[pkt] = total;
  }

  // Step 5: copy the chunk reduced by `peerLocalRank` out of the result area.
  pktView = reinterpret_cast<mscclpp::LLPacket*>(reinterpret_cast<char*>(scratch) + resultHalfOffset);
  const int peerPktBase = peerLocalRank * pktsPerLocalRank;
  uint2* const peerOutput =
      reinterpret_cast<uint2*>(reinterpret_cast<char*>(resultBuff) + peerLocalRank * wordsPerLocalRank * sizeof(int));
  if (nRanksPerNode > 1) {
    for (int pkt = threadIdx.x + blockInPeer * blockDim.x; pkt < pktsPerLocalRank; pkt += peerStride) {
      uint2 value = pktView[pkt + peerPktBase].read(flag);
      peerOutput[pkt] = value;
    }
  }

  // A single thread advances the round counter for the next launch.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    globalFlag += 1;
  }
}
|
||||
|
||||
// InfiniBand transports IB0..IB7 in index order. MscclCommGroup looks up the
// transport for a cross-node peer with IBs[rank_to_ib_[peer]].
static const mscclpp::Transport IBs[] = {
    mscclpp::Transport::IB0,
    mscclpp::Transport::IB1,
    mscclpp::Transport::IB2,
    mscclpp::Transport::IB3,
    mscclpp::Transport::IB4,
    mscclpp::Transport::IB5,
    mscclpp::Transport::IB6,
    mscclpp::Transport::IB7,
};
|
||||
|
||||
class MscclCommGroup {
|
||||
public:
|
||||
std::shared_ptr<mscclpp::Communicator> comm_;
|
||||
const size_t rank_;
|
||||
const size_t world_size_;
|
||||
const std::vector<int64_t> rank_to_node_;
|
||||
const std::vector<int64_t> rank_to_ib_;
|
||||
MscclCommGroup(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
bootstrap->initialize(unique_id);
|
||||
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
}
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
throw std::runtime_error("you should not call allreduce of a base context");
|
||||
}
|
||||
bool is_same_node(int r1, int r2) {
|
||||
return rank_to_node_[r1] == rank_to_node_[r2];
|
||||
}
|
||||
|
||||
void make_connection(
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
|
||||
same_node_connections.clear();
|
||||
cross_node_connections.clear();
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
|
||||
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
|
||||
}
|
||||
comm_->setup();
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
same_node_connections.emplace(r, conn_futures[r].get());
|
||||
} else {
|
||||
cross_node_connections.emplace(r, conn_futures[r].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void make_memory_channels_with_scratch(
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
|
||||
}
|
||||
}
|
||||
void make_port_channels_with_scratch(
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService,
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::PortChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
|
||||
|
||||
mscclpp::TransportFlags flags;
|
||||
for (const auto& [_, conn] : connections) {
|
||||
flags |= conn->transport();
|
||||
}
|
||||
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
|
||||
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
|
||||
std::unordered_map<int, size_t> memory_ids;
|
||||
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
|
||||
for (const auto& [peer, memory] : registered_memories) {
|
||||
if (peer == rank_) continue;
|
||||
memory_ids[peer] = proxyService->addMemory(memory);
|
||||
}
|
||||
for (const auto& [peer, semaphore] : semaphores) {
|
||||
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
|
||||
}
|
||||
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SemaphoreType>
|
||||
void make_semaphores(
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
|
||||
semaphores.clear();
|
||||
for (const auto& [peer, conn] : connections) {
|
||||
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
|
||||
}
|
||||
comm_->setup();
|
||||
}
|
||||
|
||||
void register_tensor_with_connections(
|
||||
void* tensor_ptr,
|
||||
size_t tensor_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
|
||||
registered_memories.clear();
|
||||
mscclpp::TransportFlags all_transports;
|
||||
for (const auto& [_, connection] : connections) {
|
||||
all_transports |= connection->transport();
|
||||
}
|
||||
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
|
||||
registered_memories[rank_] = buf_reg_mem;
|
||||
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
|
||||
for (const auto& [r, connection] : connections) {
|
||||
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
|
||||
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
|
||||
remote_mem_futures.emplace(r, remoteMemory);
|
||||
}
|
||||
comm_->setup();
|
||||
for (auto& [r, mem_feature] : remote_mem_futures) {
|
||||
registered_memories.emplace(r, mem_feature.get());
|
||||
}
|
||||
}
|
||||
|
||||
void make_device_memory_handle_base_on_new_ptr(
|
||||
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
|
||||
void* input,
|
||||
void* scratch,
|
||||
const cudaStream_t stream) {
|
||||
memory_channels.clear();
|
||||
for (const auto& [peer, channel] : old_memory_channels) {
|
||||
memory_channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < world_size_; r++) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
memory_channels_list.push_back(memory_channels[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
|
||||
device_memory_handle.data(),
|
||||
memory_channel_handlers.data(),
|
||||
memory_channel_handlers.size(),
|
||||
stream,
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
};
|
||||
|
||||
class Msccl1NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
public:
|
||||
Msccl1NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl1NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class Msccl2NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
void* put_buffer_;
|
||||
const size_t put_buffer_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
|
||||
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
|
||||
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
|
||||
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
|
||||
public:
|
||||
Msccl2NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
void* put_buffer,
|
||||
const size_t put_buffer_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
put_buffer_(put_buffer),
|
||||
put_buffer_bytes_(put_buffer_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1),
|
||||
d_portHandles_(world_size - nranks_per_node) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
proxyService = std::make_shared<mscclpp::ProxyService>();
|
||||
proxyService->startProxy();
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
comm_group_->make_port_channels_with_scratch(
|
||||
proxyService,
|
||||
put_buffer_,
|
||||
put_buffer_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
cross_node_connections_,
|
||||
port_semaphores_,
|
||||
registered_port_memories_,
|
||||
port_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
std::vector<mscclpp::PortChannel> port_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
} else {
|
||||
port_channels_list.push_back(port_channels_[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
|
||||
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
|
||||
std::transform(
|
||||
port_channels_list.begin(),
|
||||
port_channels_list.end(),
|
||||
port_channel_handlers.begin(),
|
||||
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
|
||||
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl2NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
if (proxyService) {
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans,
|
||||
d_portHandles_.data(),
|
||||
(T*)input,
|
||||
(T*)scratch_,
|
||||
(T*)put_buffer_,
|
||||
output,
|
||||
comm_group_->rank_,
|
||||
nranks_per_node_,
|
||||
comm_group_->world_size_,
|
||||
input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sglang
|
||||
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quick_all_reduce.h"
|
||||
|
||||
// Creates and initializes a quickreduce::DeviceComms instance and returns it
// as an opaque integer handle. Release it with qr_destroy().
//
// Throws std::invalid_argument when world_size is above 8, equal to 6 or odd,
// or when rank is outside [0, world_size). Exceptions thrown by
// DeviceComms::init() propagate after the half-built object is freed.
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
  if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
  if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
  if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");

  quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
  try {
    fptr->init(world_size, rank, qr_max_size);
  } catch (...) {
    // Without this the object would leak whenever init() throws.
    delete fptr;
    throw;
  }
  return (quickreduce::fptr_t)fptr;
}
|
||||
|
||||
// Tears down and frees the DeviceComms behind the handle `_fa`.
// A zero handle is ignored.
void qr_destroy(quickreduce::fptr_t _fa) {
  if (!_fa) {
    return;
  }
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  comms->destroy();
  delete comms;
}
|
||||
|
||||
// Returns the IPC memory handle of the DeviceComms behind `_fa`, packed as
// raw bytes into a 1-D uint8 CPU tensor of sizeof(hipIpcMemHandle_t) elements.
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  hipIpcMemHandle_t ipc_handle = comms->get_handle();

  const auto byte_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  torch::Tensor packed = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, byte_options);
  std::memcpy(packed.data_ptr(), &ipc_handle, sizeof(hipIpcMemHandle_t));
  return packed;
}
|
||||
|
||||
/**
 * Register the IPC memory handles of all ranks with the communicator.
 *
 * @param _fa      Communicator handle returned by init_custom_qr().
 * @param handles  One tensor per rank, each carrying the raw bytes of a
 *                 hipIpcMemHandle_t (the format produced by qr_get_handle()).
 *
 * Every tensor is validated before its bytes are copied on the host: it must
 * live in CPU memory, be contiguous, and hold at least
 * sizeof(hipIpcMemHandle_t) bytes. A violation raises a c10::Error via
 * TORCH_CHECK instead of reading out of bounds or dereferencing a device
 * pointer on the host.
 */
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  std::vector<hipIpcMemHandle_t> ipc_handles;
  ipc_handles.reserve(handles.size());
  for (const auto& handle : handles) {
    TORCH_CHECK(handle.is_cpu(), "qr_open_handles: IPC handle tensors must be CPU tensors");
    TORCH_CHECK(handle.is_contiguous(), "qr_open_handles: IPC handle tensors must be contiguous");
    TORCH_CHECK(
        handle.nbytes() >= sizeof(hipIpcMemHandle_t),
        "qr_open_handles: IPC handle tensor holds ",
        handle.nbytes(),
        " bytes, expected at least ",
        sizeof(hipIpcMemHandle_t));

    hipIpcMemHandle_t ipc_handle;
    std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
    ipc_handles.push_back(ipc_handle);
  }
  fa->open_ipc_handles(ipc_handles);
}
|
||||
|
||||
/**
 * Launch a quick all-reduce of `inp` into `out` on the current stream.
 *
 * @param _fa           Communicator handle returned by init_custom_qr().
 * @param inp           Input tensor (float16 or bfloat16).
 * @param out           Output tensor; same dtype and element count as `inp`.
 * @param quant_level   Quantization level forwarded to DeviceComms::allreduce.
 * @param cast_bf2half  For bfloat16 tensors only: when true, the half-typed
 *                      kernel with in-kernel bf16<->fp16 casting is used;
 *                      otherwise the native bfloat16 kernel is used.
 * @throws std::runtime_error if the dtype is neither float16 nor bfloat16.
 *         TORCH_CHECK_* failures are raised for mismatched dtype/size or an
 *         oversized buffer.
 *
 * The size guard is expressed in bytes because kMaxProblemSize is a byte
 * limit; comparing the element count against it would admit buffers twice as
 * large as the limit for 2-byte dtypes.
 */
void qr_all_reduce(
    quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK_LE(static_cast<int64_t>(out.nbytes()), fa->kMaxProblemSize);

  if (out.scalar_type() == at::ScalarType::Half) {
    fa->allreduce<half, false>(
        reinterpret_cast<half*>(inp.data_ptr()),
        reinterpret_cast<half*>(out.data_ptr()),
        out.numel(),
        quant_level,
        stream);
  } else if (out.scalar_type() == at::ScalarType::BFloat16) {
    if (cast_bf2half) {
      // The data is still bfloat16 in memory; the kernel converts on load/store.
      fa->allreduce<half, true>(
          reinterpret_cast<half*>(inp.data_ptr()),
          reinterpret_cast<half*>(out.data_ptr()),
          out.numel(),
          quant_level,
          stream);
    } else {
      fa->allreduce<quickreduce::nv_bfloat16, false>(
          reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
          out.numel(),
          quant_level,
          stream);
    }
  } else {
    throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
  }
}
|
||||
|
||||
/**
 * Default maximum problem size: 2 GiB (2,147,483,648 bytes).
 * Computed as INT32_MAX + 1 in 64-bit arithmetic so the addition cannot
 * overflow.
 */
int64_t qr_max_size() {
  constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
  return kInt32Max + 1;
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
||||
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
@@ -0,0 +1,633 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "quick_all_reduce_base.h"
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
struct CodecBase {
|
||||
const int thread;
|
||||
const int rank;
|
||||
const int group_leader;
|
||||
__quickreduce_device_inline__ CodecBase(int thread, int rank)
|
||||
: thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
|
||||
set_fp16_ovfl(true);
|
||||
}
|
||||
};
|
||||
|
||||
// Default full precision codec.
|
||||
template <typename T, int world_size>
|
||||
struct CodecFP : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each thread processes atoms of f16x8_t (16B).
|
||||
static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t);
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
__quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
__builtin_nontemporal_store(data[i], send_buffer + thread);
|
||||
send_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
|
||||
*recv_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int4 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ4 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1152;
|
||||
static constexpr int kRankTileScaleOffset = 1024;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/8.0h, -1/8.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-8, -8}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
|
||||
|
||||
// {+7, +7}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
|
||||
|
||||
// {+8, +8}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00080008;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q4 into int32_t
|
||||
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q4 into f16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static constexpr uint kMask000F = 0x000F000F;
|
||||
static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q4, kHalf2_1032);
|
||||
} else {
|
||||
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int6 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ6 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1664;
|
||||
static constexpr int kRankTileQ2Offset = 1024;
|
||||
static constexpr int kRankTileScaleOffset = 1536;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/32.0h, -1/32.0h}, fp16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
|
||||
|
||||
// {1e-7, 1e-7}, fp16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-32, -32}, fp16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
|
||||
|
||||
// {+31, +31}, fp16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
|
||||
|
||||
// {+32, +32}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00200020;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q6 into int32_t + int16_t
|
||||
uint32_t q4w;
|
||||
uint16_t q2w = 0;
|
||||
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
|
||||
{
|
||||
int16_t* tw = reinterpret_cast<int16_t*>(&q);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q2w |= (tw[i] >> 4) << (i * 2);
|
||||
}
|
||||
}
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(q4w, q4w_ptr);
|
||||
__builtin_nontemporal_store(q2w, q2w_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
|
||||
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q6 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask000F = 0x000F000F;
|
||||
static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int32_t q4 = q4w & kMask000F;
|
||||
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
|
||||
q4w >>= 4;
|
||||
q2w >>= 4;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056));
|
||||
} else {
|
||||
int32_t int16_2 = q4 | (q2 << 4);
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
// That's pretty much it...
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int8 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ8 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of f16x8_t (16B),
|
||||
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 2176;
|
||||
static constexpr int kRankTileScaleOffset = 2048;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/128.0h, -1/128.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-128, -128}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
|
||||
// {+127, +127}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
|
||||
|
||||
// {+128, +128}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00800080;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q8 into int32x2_t
|
||||
int32x2_t qw;
|
||||
qw[0] = q[0] | (q[1] << 8);
|
||||
qw[1] = q[2] | (q[3] << 8);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q8 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask00FF = 0x00FF00FF;
|
||||
|
||||
// {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1024 = 0x64006400;
|
||||
|
||||
// {-1152.0, -1152.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1152 = 0xE480E480;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q8, kHalf2_1152);
|
||||
} else {
|
||||
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Twoshot All Reduce
|
||||
template <typename T, class Codec, bool cast_bf2half>
|
||||
struct AllReduceTwoshot {
|
||||
static_assert(sizeof(T) == 2);
|
||||
|
||||
static constexpr int kWorldSize = Codec::kWorldSize;
|
||||
|
||||
__device__ static void
|
||||
run(T const* __restrict__ input,
|
||||
T* __restrict__ output,
|
||||
uint32_t const N, // number of elements
|
||||
int const block, // block index
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
|
||||
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
|
||||
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
|
||||
src_offset += kAtomStride * sizeof(int32x4_t);
|
||||
if constexpr (cast_bf2half) {
|
||||
const nv_bfloat162* bf_buf = reinterpret_cast<const nv_bfloat162*>(&tA[i]);
|
||||
half2 half_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __bfloat1622float2(bf_buf[j]);
|
||||
half_buf[j] = __float22half2_rn(f);
|
||||
}
|
||||
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Phase-1A: Write segment data into the communication buffer of the target
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
// --------------------------------------------------------
|
||||
// Phase-1B: Reduce the segment data from the communication buffers.
|
||||
int32x4_t tR[Codec::kRankAtoms] = {};
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// note: we reuse tA as temp buffer here
|
||||
codec.recv(&recv_buffer, tA);
|
||||
|
||||
for (int i = 0; i < Codec::kRankAtoms; i++) {
|
||||
packed_assign_add<T>(&tR[i], &tA[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase-2: Write the reduced segment to every other rank
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, tR);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
|
||||
// Phase-2: Read the gather segments from the rank's communication buffer.
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Gather all reduced and final rank segments into tA.
|
||||
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Write the result to output.
|
||||
BufferResource dst_buffer(output, N * sizeof(T));
|
||||
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
if constexpr (cast_bf2half) {
|
||||
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
|
||||
nv_bfloat162 bf16_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __half22float2(half_buf[j]);
|
||||
bf16_buf[j] = __float22bfloat162_rn(f);
|
||||
}
|
||||
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
} else {
|
||||
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
}
|
||||
dst_offset += kAtomStride * sizeof(int32x4_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
@@ -0,0 +1,233 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "quick_all_reduce.cuh"
|
||||
|
||||
// Evaluate a HIP runtime call once; on any status other than hipSuccess,
// print the code, source location and error string, then throw
// std::runtime_error("HIP error").
#define HIP_CHECK(call)                                                                 \
  do {                                                                                  \
    hipError_t hip_status_ = (call);                                                    \
    if (hip_status_ != hipSuccess) {                                                    \
      std::printf(                                                                      \
          "HIP error %d at %s:%d. %s\n", hip_status_, __FILE__, __LINE__, hipGetErrorString(hip_status_)); \
      throw std::runtime_error("HIP error");                                            \
    }                                                                                   \
  } while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
|
||||
T const* A,
|
||||
T* B,
|
||||
uint32_t N,
|
||||
uint32_t num_blocks,
|
||||
int rank,
|
||||
uint8_t** dbuffer_list,
|
||||
uint32_t data_offset,
|
||||
uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for TWOSHOT_DISPATCH: launch the two-shot kernel for codec template
// `__codec` specialised on world size `__ws`. Relies on the same names being
// in scope at the expansion site as TWOSHOT_DISPATCH does (T, cast_bf2half,
// grid, stream, A, B, N, num_blocks, rank, dbuffer_list, data_offset,
// flag_color).
#define TWOSHOT_LAUNCH_FOR_WORLD_SIZE_(__codec, __ws)                     \
  {                                                                       \
    using LineCodec = __codec<T, __ws>;                                   \
    using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
    hipLaunchKernelGGL(                                                   \
        (allreduce_prototype_twoshot<AllReduceKernel, T>),                \
        dim3(grid),                                                       \
        dim3(kBlockTwoShot),                                              \
        0,                                                                \
        stream,                                                           \
        A,                                                                \
        B,                                                                \
        N,                                                                \
        num_blocks,                                                       \
        rank,                                                             \
        dbuffer_list,                                                     \
        data_offset,                                                      \
        flag_color);                                                      \
  }

// Select the kernel instantiation matching the runtime `world_size` (2, 4 or
// 8) and launch it. Any other world size launches nothing.
#define TWOSHOT_DISPATCH(__codec)                 \
  if (world_size == 2) {                          \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE_(__codec, 2)    \
  } else if (world_size == 4) {                   \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE_(__codec, 4)    \
  } else if (world_size == 8) {                   \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE_(__codec, 8)    \
  }
|
||||
|
||||
// Codec selector received as a plain integer by DeviceComms::allreduce.
// The numeric values are part of the call interface and must not change.
enum QuickReduceQuantLevel {
  F16 = 0,   // dispatched to CodecFP (also the fallback for unknown values)
  INT8 = 1,  // dispatched to CodecQ8
  INT6 = 2,  // dispatched to CodecQ6
  INT4 = 3,  // dispatched to CodecQ4
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() {
|
||||
return world_size;
|
||||
}
|
||||
int get_rank() {
|
||||
return rank;
|
||||
}
|
||||
bool status() {
|
||||
return initialized;
|
||||
}
|
||||
hipIpcMemHandle_t const get_handle() {
|
||||
return buffer_ipc_handle;
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(
|
||||
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template <typename T, bool cast_bf2half>
|
||||
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
|
||||
if (world_size != 2 && world_size != 4 && world_size != 8) {
|
||||
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
|
||||
}
|
||||
|
||||
// Configuration.
|
||||
uint32_t msg_size = N * sizeof(T);
|
||||
uint32_t num_blocks = divceil(msg_size, kTileSize);
|
||||
uint32_t grid = min(kMaxNumBlocks, num_blocks);
|
||||
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
|
||||
switch (quant_level_) {
|
||||
case QuickReduceQuantLevel::INT8:
|
||||
TWOSHOT_DISPATCH(CodecQ8)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT6:
|
||||
TWOSHOT_DISPATCH(CodecQ6)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT4:
|
||||
TWOSHOT_DISPATCH(CodecQ4)
|
||||
break;
|
||||
default:
|
||||
TWOSHOT_DISPATCH(CodecFP)
|
||||
break;
|
||||
}
|
||||
HIP_CHECK(cudaGetLastError());
|
||||
// Rotate the flag color.
|
||||
flag_color += divceil(N, grid);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
113
sgl-kernel/csrc/allreduce/quick_all_reduce.hip
Normal file
113
sgl-kernel/csrc/allreduce/quick_all_reduce.hip
Normal file
@@ -0,0 +1,113 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/dtk_macros.h>
|
||||
#include <ATen/hip/Exceptions.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quick_all_reduce_hip.h"
|
||||
|
||||
// Creates and initializes a quickreduce::DeviceComms for this rank and returns
// it as an opaque integer handle (release it with qr_destroy).
//
// rank        : this process's rank, must satisfy 0 <= rank < world_size.
// world_size  : number of ranks; odd values, 6 and values above 8 are rejected.
// qr_max_size : optional byte limit forwarded to DeviceComms::init.
//
// Throws std::invalid_argument for unsupported arguments; exceptions raised by
// DeviceComms::init propagate after the partially built object is deleted.
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
  if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
  if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
  if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
  quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
  try {
    fptr->init(world_size, rank, qr_max_size);
  } catch (...) {
    // init() failed: do not leak the host object, then let the caller see the error.
    delete fptr;
    throw;
  }
  return (quickreduce::fptr_t)fptr;
}
|
||||
|
||||
// Releases a handle returned by init_custom_qr. A zero handle is ignored.
void qr_destroy(quickreduce::fptr_t _fa) {
  if (!_fa) {
    return;
  }
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  // Tear down device resources explicitly before freeing the host object.
  comms->destroy();
  delete comms;
}
|
||||
|
||||
// Returns this rank's IPC memory handle as a 1-D CPU uint8 tensor holding the
// raw bytes of the hipIpcMemHandle_t, so it can be exchanged between processes.
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  const hipIpcMemHandle_t exported = comms->get_handle();

  const auto cpu_bytes = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto packed = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, cpu_bytes);
  std::memcpy(packed.data_ptr(), &exported, sizeof(hipIpcMemHandle_t));
  return packed;
}
|
||||
|
||||
// Decodes one IPC handle per rank from `handles` (byte tensors as produced by
// qr_get_handle, indexed by rank) and opens them on the DeviceComms behind _fa.
//
// Each tensor must live in CPU memory and contain at least
// sizeof(hipIpcMemHandle_t) bytes, because its storage is copied with
// std::memcpy on the host.
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  std::vector<hipIpcMemHandle_t> ipc_handles;
  ipc_handles.reserve(handles.size());
  for (auto& handle : handles) {
    // Validate before the raw host-side copy below.
    TORCH_CHECK(handle.is_cpu(), "quick allreduce IPC handle tensors must be CPU tensors");
    TORCH_CHECK(
        handle.nbytes() >= sizeof(hipIpcMemHandle_t), "quick allreduce IPC handle tensor is too small: ", handle.nbytes());
    hipIpcMemHandle_t ipc_handle;
    std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
    ipc_handles.push_back(ipc_handle);
  }
  fa->open_ipc_handles(ipc_handles);
}
|
||||
|
||||
// Enqueues a quick all-reduce of `inp` into `out` on the current stream of
// inp's device.
//
// _fa          : handle returned by init_custom_qr.
// inp / out    : tensors of identical dtype and element count; float16 and
//                bfloat16 are supported.
// quant_level  : QuickReduceQuantLevel value forwarded to DeviceComms::allreduce.
// cast_bf2half : for bfloat16 tensors, selects the allreduce<half, true>
//                instantiation instead of the native bfloat16 one.
//
// Throws std::runtime_error for other dtypes; the TORCH_CHECK_* guards fail on
// mismatched tensors or on a message larger than fa->kMaxProblemSize bytes.
void qr_all_reduce(
    quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
  auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  // kMaxProblemSize is a byte limit, so compare the message size in bytes
  // rather than the element count.
  const int64_t msg_bytes = out.numel() * static_cast<int64_t>(out.element_size());
  TORCH_CHECK_LE(msg_bytes, fa->kMaxProblemSize);
  if (out.scalar_type() == at::ScalarType::Half) {
    fa->allreduce<half, false>(
        reinterpret_cast<half*>(inp.data_ptr()),
        reinterpret_cast<half*>(out.data_ptr()),
        out.numel(),
        quant_level,
        stream);
  } else if (out.scalar_type() == at::ScalarType::BFloat16) {
    if (cast_bf2half) {
      fa->allreduce<half, true>(
          reinterpret_cast<half*>(inp.data_ptr()),
          reinterpret_cast<half*>(out.data_ptr()),
          out.numel(),
          quant_level,
          stream);
    } else {
      fa->allreduce<quickreduce::nv_bfloat16, false>(
          reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
          out.numel(),
          quant_level,
          stream);
    }
  } else {
    throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
  }
}
|
||||
|
||||
// Default upper bound on the all-reduce message size.
int64_t qr_max_size() {
  // The default is 2GB (2,147,483,648 bytes), i.e. INT32_MAX + 1.
  const int64_t int32_ceiling = std::numeric_limits<int32_t>::max();
  return int32_ceiling + 1;
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
||||
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
@@ -0,0 +1,318 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define __quickreduce_device_inline__ __device__ __forceinline__
|
||||
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
|
||||
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
|
||||
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
|
||||
// as per architecture.
|
||||
#if defined(__gfx942__)
|
||||
// CDNA3: Scope bits sc0, sc1
|
||||
#define MUBUF_ACQUIRE 16
|
||||
#define MUBUF_RELEASE 16
|
||||
#elif (defined(__gfx908__) || defined(__gfx90a__))
|
||||
// CDNA1 and CDNA2 - glc bit
|
||||
#define MUBUF_ACQUIRE 1
|
||||
#define MUBUF_RELEASE 0
|
||||
#endif
|
||||
|
||||
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
|
||||
|
||||
// Number of atoms (4xf16x2_t) processed by a single thread
|
||||
static constexpr int kAtoms = 8;
|
||||
|
||||
// We use a workgroup of 256 threads
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kAtomStride = kBlockSize;
|
||||
|
||||
// Size and atom stride of source/destination data that the block will
|
||||
// process.
|
||||
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
|
||||
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
|
||||
|
||||
// Max number of blocks. 304 CUs on MI300
|
||||
static constexpr int kMaxNumBlocks = 304 * 4;
|
||||
|
||||
// Standard CDNA wavefront size.
|
||||
static constexpr int kWavefront = 64;
|
||||
|
||||
// 256 thread, 4 wavefronts.
|
||||
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
|
||||
|
||||
// Number of threads in a group for quantization
|
||||
// It corresponds to 32 F16 elements in quantization block
|
||||
static constexpr int kThreadGroupSize = 8;
|
||||
|
||||
// Methods
|
||||
// Integer division rounded up: the smallest q with q * y >= x. y must be non-zero.
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
  const unsigned long padded = x + (y - 1);
  return padded / y;
}
|
||||
|
||||
union BufferResource {
|
||||
__quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
|
||||
|
||||
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
|
||||
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
|
||||
|
||||
int32x4_t descriptor;
|
||||
struct {
|
||||
void* address; // 8B, out of which first 48b is address, and 16b is stride
|
||||
// (unused)
|
||||
uint32_t range; // Byte range for the buffer resource
|
||||
uint32_t config; // Constant, DFMT=32b
|
||||
};
|
||||
};
|
||||
|
||||
// Declarations bound by asm label to the LLVM AMDGPU raw buffer intrinsics.
// `srsrc` is the descriptor (see BufferResource::descriptor).
// NOTE(review): voffset / soffset are presumably the per-lane and scalar byte
// offsets and `aux` the cache-policy bits (cf. MUBUF_ACQUIRE / MUBUF_RELEASE) —
// confirm against the intrinsic documentation.

// Loads four dwords through the buffer descriptor.
__quickreduce_device_inline__ static int32x4_t
buffer_load_dwordx4(int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
    "llvm.amdgcn.raw.buffer.load.v4i32");

// Stores four dwords through the buffer descriptor.
__quickreduce_device_inline__ static void
buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
    "llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// Writes 1 (value == true) or 0 (value == false) to hardware register field
// 0xdc1 via s_setreg_imm32_b32. Compiles to nothing on targets other than gfx942.
// NOTE(review): judging by the name this toggles the FP16 overflow mode bit —
// confirm the register encoding against the ISA guide.
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
#if defined(__gfx942__)
  if (!value) {
    asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
    return;
  }
  asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
#endif
}
|
||||
union bf162_int_union {
|
||||
int i;
|
||||
nv_bfloat162 bf2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
|
||||
int32x4_t& tR_fragment = A[0];
|
||||
int32x4_t& tA_fragment = B[0];
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tA[i] = __hadd2(tA[i], tB[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmax2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_min(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmin2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
|
||||
half2 wmaxh2 = __builtin_bit_cast(half2, a);
|
||||
half2 wminh2 = __builtin_bit_cast(half2, b);
|
||||
half2 wblockmaxh2;
|
||||
|
||||
wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
|
||||
wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
|
||||
return __builtin_bit_cast(int, wblockmaxh2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
|
||||
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_add(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hadd2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_sub(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
|
||||
int result;
|
||||
|
||||
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hsub2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_mul(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
|
||||
nv_bfloat162 tR = __hmul2(*tA, *tB);
|
||||
return *(reinterpret_cast<int*>(&tR));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_rcp(int a);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
|
||||
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
|
||||
bf162_int_union A, R;
|
||||
A.i = a;
|
||||
R.bf2 = h2rcp(A.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
// changes dtype
|
||||
__quickreduce_device_inline__ float T2float_cast(half a) {
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
|
||||
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
|
||||
|
||||
int wmax, wmin, wblockmax;
|
||||
int a, b;
|
||||
a = packed_max<T>(atom[0], atom[1]);
|
||||
b = packed_max<T>(atom[2], atom[3]);
|
||||
|
||||
wmax = packed_max<T>(a, b);
|
||||
|
||||
a = packed_min<T>(atom[0], atom[1]);
|
||||
b = packed_min<T>(atom[2], atom[3]);
|
||||
|
||||
wmin = packed_min<T>(a, b);
|
||||
|
||||
// Reduce the max among a group of threads
|
||||
// Note: This is basically 2 blocks of values setup as the
|
||||
// upper/lower halves of the f16x2_t
|
||||
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
|
||||
int x = __shfl_down(wmax, i);
|
||||
wmax = packed_max<T>(wmax, x);
|
||||
|
||||
int y = __shfl_down(wmin, i);
|
||||
wmin = packed_min<T>(wmin, y);
|
||||
}
|
||||
wblockmax = packed_abs_max<T>(wmax, wmin);
|
||||
// Share with the cohort
|
||||
wblockmax = __shfl(wblockmax, group_leader);
|
||||
return wblockmax;
|
||||
}
|
||||
|
||||
// Publishes `flag` at `flag_ptr` with an atomic release store.
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
  uint32_t* const slot = flag_ptr;
  __atomic_store_n(slot, flag, __ATOMIC_RELEASE);
}
|
||||
|
||||
// Spins until the word at `flag_ptr` equals `flag`. The polling load is
// relaxed; this function itself provides no acquire ordering.
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
  for (;;) {
    const uint32_t observed = __atomic_load_n(flag_ptr, __ATOMIC_RELAXED);
    if (observed == flag) {
      return;
    }
  }
}
|
||||
|
||||
} // namespace quickreduce
|
||||
235
sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h
Normal file
235
sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h
Normal file
@@ -0,0 +1,235 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/dtk_macros.h>
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "quick_all_reduce.cuh"
|
||||
|
||||
// Evaluates a HIP runtime expression exactly once. When the status is not
// hipSuccess, prints the numeric code, the call site and the
// hipGetErrorString() text via std::printf, then throws
// std::runtime_error("HIP error").
#define HIP_CHECK(err)                                                                 \
  do {                                                                                 \
    const hipError_t hip_status_ = (err);                                              \
    if (hipSuccess != hip_status_) {                                                   \
      std::printf(                                                                     \
          "HIP error %d at %s:%d. %s\n", hip_status_, __FILE__, __LINE__,              \
          hipGetErrorString(hip_status_));                                             \
      throw std::runtime_error("HIP error");                                           \
    }                                                                                  \
  } while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
|
||||
T const* A,
|
||||
T* B,
|
||||
uint32_t N,
|
||||
uint32_t num_blocks,
|
||||
int rank,
|
||||
uint8_t** dbuffer_list,
|
||||
uint32_t data_offset,
|
||||
uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
// Launches allreduce_prototype_twoshot for one compile-time world size.
// Expands to a braced block and picks up, by name, the identifiers available
// at the expansion site: T, cast_bf2half, grid, stream, A, B, N, num_blocks,
// rank, dbuffer_list, data_offset and flag_color.
#define TWOSHOT_LAUNCH_FOR_WORLD_SIZE(__codec, __ws)                      \
  {                                                                       \
    using LineCodec = __codec<T, __ws>;                                   \
    using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
    hipLaunchKernelGGL(                                                   \
        (allreduce_prototype_twoshot<AllReduceKernel, T>),                \
        dim3(grid),                                                       \
        dim3(kBlockTwoShot),                                              \
        0,                                                                \
        stream,                                                           \
        A,                                                                \
        B,                                                                \
        N,                                                                \
        num_blocks,                                                       \
        rank,                                                             \
        dbuffer_list,                                                     \
        data_offset,                                                      \
        flag_color);                                                      \
  }

// Selects the kernel instantiation matching the runtime world_size (2, 4 or 8)
// for the given codec template. Any other world_size launches nothing.
#define TWOSHOT_DISPATCH(__codec)              \
  if (world_size == 2)                         \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE(__codec, 2)  \
  else if (world_size == 4)                    \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE(__codec, 4)  \
  else if (world_size == 8)                    \
    TWOSHOT_LAUNCH_FOR_WORLD_SIZE(__codec, 8)
|
||||
|
||||
// Codec selector received as a plain integer by DeviceComms::allreduce.
// The numeric values are part of the call interface and must not change.
enum QuickReduceQuantLevel {
  F16 = 0,   // dispatched to CodecFP (also the fallback for unknown values)
  INT8 = 1,  // dispatched to CodecQ8
  INT6 = 2,  // dispatched to CodecQ6
  INT4 = 3,  // dispatched to CodecQ4
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() {
|
||||
return world_size;
|
||||
}
|
||||
int get_rank() {
|
||||
return rank;
|
||||
}
|
||||
bool status() {
|
||||
return initialized;
|
||||
}
|
||||
hipIpcMemHandle_t const get_handle() {
|
||||
return buffer_ipc_handle;
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(
|
||||
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template <typename T, bool cast_bf2half>
|
||||
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
|
||||
if (world_size != 2 && world_size != 4 && world_size != 8) {
|
||||
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
|
||||
}
|
||||
|
||||
// Configuration.
|
||||
uint32_t msg_size = N * sizeof(T);
|
||||
uint32_t num_blocks = divceil(msg_size, kTileSize);
|
||||
uint32_t grid = min(kMaxNumBlocks, num_blocks);
|
||||
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
|
||||
switch (quant_level_) {
|
||||
case QuickReduceQuantLevel::INT8:
|
||||
TWOSHOT_DISPATCH(CodecQ8)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT6:
|
||||
TWOSHOT_DISPATCH(CodecQ6)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT4:
|
||||
TWOSHOT_DISPATCH(CodecQ4)
|
||||
break;
|
||||
default:
|
||||
TWOSHOT_DISPATCH(CodecFP)
|
||||
break;
|
||||
}
|
||||
HIP_CHECK(hipGetLastError());
|
||||
// Rotate the flag color.
|
||||
flag_color += divceil(N, grid);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
153
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
Normal file
153
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
Normal file
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* this file is used to test mscclpp_allreduce.cu using mpirun
|
||||
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
|
||||
usage:
|
||||
cd PATH-TO-THIS-FILE
|
||||
export MPI_HOME=/usr/local/mpi
|
||||
# export MPI_HOME=/opt/hpcx/ompi/
|
||||
export MSCCLPP_HOME=/workspace/test/mscclpp
|
||||
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
|
||||
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
|
||||
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
|
||||
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
|
||||
|
||||
/opt/hpcx/ompi/bin/
|
||||
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
|
||||
--map-by ppr:8:node \
|
||||
--mca btl_openib_warn_no_device_params_found 0 \
|
||||
--mca btl_tcp_if_include bond0 \
|
||||
--allow-run-as-root -np 8 \
|
||||
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
|
||||
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
|
||||
*/
|
||||
#include <mpi.h>
|
||||
#include <thrust/detail/raw_pointer_cast.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
// Fallback definition, used only when no other header has provided the macro.
// Evaluates a CUDA runtime call once; on failure prints the call site and the
// cudaGetErrorString() text, then terminates the process with EXIT_FAILURE.
#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd)                                                              \
  do {                                                                                       \
    cudaError_t e = cmd;                                                                     \
    if (cudaSuccess != e) {                                                                  \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e));  \
      exit(EXIT_FAILURE);                                                                    \
    }                                                                                        \
  } while (0)
#endif
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
// Approximate equality: true when |a - b| <= atol + rtol * |b|.
// The tolerance is relative to `b`, so the comparison is not symmetric.
template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
  const auto distance = fabs(a - b);
  const auto allowed = atol + rtol * fabs(b);
  return distance <= allowed;
}
|
||||
|
||||
/**
 * MPI-driven test for the mscclpp all-reduce contexts.
 *
 * Every rank fills a device buffer with `i + rank`, runs one all-reduce through
 * sglang::Msccl1NodeLLcontext (8 ranks) or sglang::Msccl2NodeLLcontext
 * (16 ranks), copies the result back and compares it element-wise against the
 * analytic sum `i * nranks + nranks * (nranks - 1) / 2`.
 *
 * @return EXIT_SUCCESS when no NaN and no out-of-tolerance element was seen on
 *         this rank; EXIT_FAILURE otherwise, or when the world size is not 8
 *         or 16. CUDA API failures terminate the process through
 *         CHECK_CUDA_SUCCESS.
 */
int main(int argc, char* argv[]) {
  // init mpi
  MPI_Init(&argc, &argv);
  printf("MPI Initialized.\n");
  int nranks, rank;

  // get world size and rank id
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  // Bind this process to a GPU. The global rank is folded onto the GPUs that
  // are visible locally so that ranks on a second node (rank >= device count)
  // do not request a non-existent device ordinal.
  // NOTE(review): assumes ranks are placed consecutively on each node, matching
  // the `i / 8`, `i % 8` mapping used for rank_to_node / rank_to_ib below.
  int num_devices = 0;
  CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&num_devices));
  if (num_devices < 1) {
    printf("rank: %d, no CUDA device available\n", rank);
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
  }
  CHECK_CUDA_SUCCESS(cudaSetDevice(rank % num_devices));
  printf("nranks: %d, rank: %d\n", nranks, rank);

  // Only the 8-rank and 16-rank layouts have a context below. Any other world
  // size would skip the all-reduce and validate an untouched result buffer, so
  // fail early instead. Every rank sees the same nranks and takes this path
  // together.
  if (nranks != 8 && nranks != 16) {
    printf("rank: %d, unsupported nranks: %d (expected 8 or 16)\n", rank, nranks);
    MPI_Finalize();
    return EXIT_FAILURE;
  }

  // init host and device buffers
  using T = float;
  using ReduceT = float;
  const size_t num_elems = 2 * 1024 * 1024;
  std::vector<T> host_buf(num_elems);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_buf[i] = T(i + rank);
  }
  thrust::device_vector<T> device_buf(host_buf);
  const size_t buf_size_in_bytes = num_elems * sizeof(T);
  std::vector<T> host_result_buf(num_elems);
  thrust::device_vector<T> device_result_buf(host_result_buf);

  // Scratch space is 8x the payload. Only the first num_elems entries are set
  // to 1; the remaining entries keep the vector's zero initialization.
  std::vector<T> host_scratch_buf(num_elems * 8);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_scratch_buf[i] = 1;
  }
  thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
  std::vector<T> host_put_buf(num_elems);
  thrust::device_vector<T> device_put_buf(host_put_buf);

  // Rank 0 creates the bootstrap id and broadcasts its raw bytes to all ranks.
  mscclpp::UniqueId unique_id;
  if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
  MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);

  // Topology tables passed to the contexts: 8 consecutive ranks per node.
  std::vector<int64_t> rank_to_node(nranks);
  std::vector<int64_t> rank_to_ib(nranks);
  for (int i = 0; i < nranks; i++) {
    rank_to_node[i] = i / 8;
    rank_to_ib[i] = i % 8;
  }

  cudaStream_t s;
  CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
  if (nranks == 8) {
    auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
    // Wait for the work queued on `s` before `context` goes out of scope and
    // before the result buffer is read back.
    CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
  } else {
    // nranks == 16
    // TODO: this branch is untested since there is something wrong with mpirun on my test machine
    auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        thrust::raw_pointer_cast(device_put_buf.data()),
        buf_size_in_bytes,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
    // Same reasoning as in the 8-rank branch.
    CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
  }

  // check result correctness
  thrust::host_vector<T> host_buf_result = device_result_buf;
  size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
  bool nan_detected = false;

  for (uint32_t i = 0; i < num_elems; ++i) {
    // Sum over all ranks r of (i + r).
    T expected = T(i * nranks + (nranks - 1) * nranks / 2);
    if (std::isnan(float(host_buf_result[i]))) {
      nan_detected = true;
    }
    if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
      num_results_error_atol_1e_3_rtol_1e_3++;
    }
  }
  float result_accuracy = 1.0f - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);

  printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);

  // Report failure through the exit status so scripted runs can detect it.
  const bool passed = !nan_detected && num_results_error_atol_1e_3_rtol_1e_3 == 0;

  CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
  MPI_Finalize();
  return passed ? EXIT_SUCCESS : EXIT_FAILURE;
}
|
||||
Reference in New Issue
Block a user