Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
@@ -0,0 +1,180 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/hip/Exceptions.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
|
||||
// fake pointer type, must match fptr_t type in ops.h
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
hipIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
hipStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, hipMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handles =
|
||||
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
|
||||
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
|
||||
return {handles, std::move(offsets)};
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
|
||||
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
|
||||
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
|
||||
inp.data_ptr()));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
torch::Tensor allocate_meta_buffer(int64_t size) {
|
||||
auto device_index = c10::hip::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
void* buffer;
|
||||
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
AT_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(hipStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(torch::kI8)
|
||||
.device(torch::kCUDA, device_index);
|
||||
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> get_device_bdf(int dev) {
|
||||
char busIdStr[] = "0000:00:00.0";
|
||||
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
|
||||
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
|
||||
bdf.resize(bdf.size() - 1); // remove trailing NULL
|
||||
return bdf;
|
||||
}
|
||||
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
@@ -0,0 +1,582 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
hipError_t e = cmd; \
|
||||
if (e != hipSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
|
||||
};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
struct __align__(16) RankData {
|
||||
const void* ptrs[8];
|
||||
};
|
||||
#else
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
#endif
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) {
|
||||
return a += b;
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(
|
||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__MEMORY_SCOPE_DEVICE) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
#ifdef USE_ROCM
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
#else
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
#endif
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(
|
||||
Signal* meta,
|
||||
void* rank_data,
|
||||
size_t rank_data_sz,
|
||||
const hipIpcMemHandle_t* handles,
|
||||
const std::vector<int64_t>& offsets,
|
||||
int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal* rank_sg;
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal*)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(hipIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(hipIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (hipPointerGetAttribute(
|
||||
&base_ptr,
|
||||
#ifdef USE_ROCM
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#else
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#endif
|
||||
(hipDeviceptr_t)ptr) != hipSuccess)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void
|
||||
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
hipStream_t stream,
|
||||
T* input,
|
||||
T* output,
|
||||
int size,
|
||||
#ifndef USE_ROCM
|
||||
int threads = 512,
|
||||
int block_limit = 36){
|
||||
#else
|
||||
int threads = 512,
|
||||
int block_limit = 16) {
|
||||
#endif
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error(
|
||||
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
hipStreamCaptureStatus status;
|
||||
CUDACHECK(hipStreamIsCapturing(stream, &status));
|
||||
if (status == hipStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = ::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
hipLaunchKernelGGL( \
|
||||
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
}; // namespace vllm
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
||||
545
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
Normal file
545
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
Normal file
@@ -0,0 +1,545 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// reference:
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <tuple>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
|
||||
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
|
||||
uint32_t flag;
|
||||
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
|
||||
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
|
||||
uint32_t flag;
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
namespace trt_llm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Type Converter that packs data format to 128 bits data type
|
||||
//
|
||||
using PackedFloat = union {
|
||||
int4 packed;
|
||||
float unpacked[4];
|
||||
};
|
||||
|
||||
using PackedHalf = union {
|
||||
int4 packed;
|
||||
half2 unpacked[4];
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PackedOn16Bytes {};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<float> {
|
||||
using Type = PackedFloat;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<half> {
|
||||
using Type = PackedHalf;
|
||||
};
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
using PackedBFloat16 = union {
|
||||
int4 packed;
|
||||
__nv_bfloat162 unpacked[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedOn16Bytes<__nv_bfloat16> {
|
||||
using Type = PackedBFloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
// add two 128b data
|
||||
template <typename T>
|
||||
inline __device__ int4 add128b(T& a, T& b) {
|
||||
T c;
|
||||
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
|
||||
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
|
||||
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
|
||||
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
|
||||
return c.packed;
|
||||
}
|
||||
|
||||
__inline__ __device__ void multi_gpu_barrier(
|
||||
uint32_t** signals,
|
||||
uint32_t const flag,
|
||||
size_t const local_rank,
|
||||
size_t const world_size,
|
||||
int const tidx,
|
||||
int const bidx) {
|
||||
// After this function, at least one block in each GPU has reached the barrier
|
||||
if (tidx < world_size) {
|
||||
// we can think of signals having the shape [world_size, world_size]
|
||||
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension
|
||||
|
||||
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
|
||||
size_t offset = (flag % 2) ? world_size : 0;
|
||||
|
||||
if (bidx == 0) {
|
||||
st_flag_release(flag, signals[tidx] + offset + local_rank);
|
||||
}
|
||||
|
||||
// All blocks check that corresponding block 0 on other GPUs have set the flag
|
||||
// No deadlock because block #0 is always the first block started
|
||||
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
|
||||
while (ld_flag_acquire(peer_barrier_d) != flag) {
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <bool start, bool need_fence = false>
|
||||
__inline__ __device__ void block_barrier(
|
||||
uint32_t** signals,
|
||||
uint32_t const flag,
|
||||
size_t const local_rank,
|
||||
size_t const world_size,
|
||||
int const tidx,
|
||||
int const bidx,
|
||||
int const grid_size) {
|
||||
if constexpr (!start) {
|
||||
__syncthreads();
|
||||
}
|
||||
// After this function, the block of id == bidx of each GPU has reached the barrier
|
||||
if (tidx < world_size) {
|
||||
// we can think of signals having the shape [world_size, 2, num_blocks, world_size]
|
||||
// (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
|
||||
// Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension
|
||||
|
||||
// Block broadcast its flag (local_rank on emitting dimension) to all receivers
|
||||
uint32_t flag_block_offset = world_size + bidx * world_size;
|
||||
|
||||
flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
|
||||
|
||||
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
|
||||
// Blocks check that corresponding blocks on other GPUs have also set the flag
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
|
||||
while (ld_flag_acquire(peer_barrier_d) != flag) {
|
||||
}
|
||||
} else {
|
||||
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
|
||||
while (ld_flag_volatile(peer_barrier_d) != flag) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
// message
|
||||
// |-------------------|
|
||||
// GPU 0 | B0 | B1 | B2 | B3 |
|
||||
// GPU 1 | B0 | B1 | B2 | B3 |
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
|
||||
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
|
||||
// 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output
|
||||
//
|
||||
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
|
||||
// We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
|
||||
//
|
||||
// With PUSH_MODE, we consider that the shared buffer is of size:
|
||||
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 push the chunk is it responsible for into all other GPUs:
|
||||
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
|
||||
// 2. block sync so the block is shared by other GPUs
|
||||
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
|
||||
|
||||
int const bidx = blockIdx.x;
|
||||
int const tidx = threadIdx.x;
|
||||
int const grid_size = gridDim.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = 16 / sizeof(T);
|
||||
|
||||
// Packed data type for comms
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
|
||||
// Start and end offsets of the thread
|
||||
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
|
||||
|
||||
if constexpr (COPY_INPUT) {
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
// Copy from local buffer to shareable buffer
|
||||
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
|
||||
*reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
|
||||
}
|
||||
}
|
||||
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
|
||||
block_barrier<true>(
|
||||
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedStruct vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
PackedStruct sums;
|
||||
sums.packed = {0, 0, 0, 0};
|
||||
#pragma unroll
|
||||
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
|
||||
// Always reduce from rank 0 to ensure stable reduce order.
|
||||
sums.packed = add128b(sums, vals[rank]);
|
||||
}
|
||||
|
||||
// Store to the destination buffer.
|
||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
|
||||
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
// message
|
||||
// |-------------------|
|
||||
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
|
||||
// GPU 0 | B0 | B1 | B0 | B1 |
|
||||
// GPU 1 | B0 | B1 | B0 | B1 |
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer
|
||||
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
|
||||
// 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility
|
||||
// part (the first half of the message, see GPU responsibility row above)
|
||||
// 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0,
|
||||
// where GPU 1 is responsible: the second half of the message.
|
||||
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
|
||||
// 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
|
||||
// For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
|
||||
//
|
||||
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
|
||||
// We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
|
||||
// to be read.
|
||||
//
|
||||
// Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks.
|
||||
// However, it's only responsible for the summation of a single chunk.
|
||||
//
|
||||
// With PUSH_MODE, we consider that the shared buffer is of size:
|
||||
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 push the chunks is it responsible for into the corresponding GPUs:
|
||||
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
|
||||
// 2. block sync so the blocks have been shared by other GPUs
|
||||
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
|
||||
// 4. block barrier (corresponding blocks have finished reduction)
|
||||
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
|
||||
// written at index 0 of 2nd dim)
|
||||
|
||||
int const bidx = blockIdx.x;
|
||||
int const tidx = threadIdx.x;
|
||||
int const grid_size = gridDim.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int PACKED_ELTS = 16 / sizeof(T);
|
||||
using PackedType = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
|
||||
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
|
||||
|
||||
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
|
||||
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
|
||||
|
||||
T* buffers[RANKS_PER_NODE];
|
||||
T* buffers_unorder[RANKS_PER_NODE];
|
||||
int ranks[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
// A mapping of the ranks to scatter reads as much as possible
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
ranks[ii] = rank;
|
||||
buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
|
||||
buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
|
||||
}
|
||||
|
||||
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
if constexpr (COPY_INPUT) {
|
||||
// Copy all blocks from local buffer to shareable buffer
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total) {
|
||||
continue;
|
||||
}
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
|
||||
*reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
|
||||
}
|
||||
}
|
||||
}
|
||||
block_barrier<true>(
|
||||
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
|
||||
size_t const responsible_block_offset = local_offset + params.rank_offset;
|
||||
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedType vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
PackedType sums;
|
||||
sums.packed = {0, 0, 0, 0};
|
||||
#pragma unroll
|
||||
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
|
||||
// Always reduce from rank 0 to ensure stable reduce order.
|
||||
sums.packed = add128b(sums, vals[rank]);
|
||||
}
|
||||
|
||||
// Store to the local buffer or tmp buffer
|
||||
if constexpr (COPY_INPUT) {
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
|
||||
} else {
|
||||
*reinterpret_cast<int4*>(¶ms.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
|
||||
}
|
||||
}
|
||||
|
||||
block_barrier<false, true>(
|
||||
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
|
||||
|
||||
// Gather all needed elts from other intra-node ranks
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
|
||||
// use round-robin gathering from other ranks
|
||||
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total) {
|
||||
continue;
|
||||
}
|
||||
if constexpr (COPY_INPUT) {
|
||||
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
|
||||
*reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
|
||||
} else {
|
||||
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
|
||||
*reinterpret_cast<int4*>(¶ms.tmp_result_buffers[ranks[ii]][offset_rank]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline int divUp(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
inline int roundUp(int a, int n) {
|
||||
return divUp(a, n) * n;
|
||||
}
|
||||
|
||||
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
|
||||
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
|
||||
switch (algo) {
|
||||
case AllReduceStrategyType::ONESHOT: {
|
||||
assert(params.elts_total % elts_per_thread == 0);
|
||||
size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
|
||||
params.elts_per_rank = params.elts_total;
|
||||
break;
|
||||
}
|
||||
case AllReduceStrategyType::TWOSHOT: {
|
||||
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
|
||||
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
|
||||
|
||||
/*
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
*/
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
|
||||
blocks_per_grid += 1;
|
||||
}
|
||||
|
||||
threads_per_block = total_threads / blocks_per_grid;
|
||||
|
||||
// NOTE: need to adjust here
|
||||
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
|
||||
size_t iter_factor = 1;
|
||||
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
|
||||
iter_factor += 1;
|
||||
}
|
||||
blocks_per_grid /= iter_factor;
|
||||
}
|
||||
params.elts_per_rank = params.elts_total / params.ranks_per_node;
|
||||
params.rank_offset = params.local_rank * params.elts_per_rank;
|
||||
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Algorithm not supported here.");
|
||||
}
|
||||
|
||||
return std::make_tuple(blocks_per_grid, threads_per_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
|
||||
void dispatchARKernels(
|
||||
AllReduceStrategyType algo,
|
||||
AllReduceParams& param,
|
||||
int blocks_per_grid,
|
||||
int threads_per_block,
|
||||
cudaStream_t stream) {
|
||||
switch (algo) {
|
||||
case AllReduceStrategyType::ONESHOT: {
|
||||
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
break;
|
||||
}
|
||||
case AllReduceStrategyType::TWOSHOT: {
|
||||
twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool COPY_INPUT>
|
||||
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
|
||||
size_t elts_per_thread = 16 / sizeof(T);
|
||||
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
|
||||
switch (param.ranks_per_node) {
|
||||
case 2:
|
||||
dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 4:
|
||||
dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 6:
|
||||
dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
case 8:
|
||||
dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
|
||||
if (param.is_capturing) {
|
||||
dispatchARKernelsCopyInput<T, false>(strat, param, stream);
|
||||
} else {
|
||||
dispatchARKernelsCopyInput<T, true>(strat, param, stream);
|
||||
}
|
||||
CHECK_CUDA_SUCCESS(cudaGetLastError());
|
||||
}
|
||||
|
||||
void trtCustomAllReduce(
|
||||
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
|
||||
if (params.elts_total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (data_type) {
|
||||
case at::ScalarType::Float:
|
||||
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
|
||||
break;
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16:
|
||||
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
assert(false && "Unsupported data type");
|
||||
}
|
||||
}
|
||||
} // namespace trt_llm
|
||||
226
sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
Normal file
226
sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
Normal file
@@ -0,0 +1,226 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace trt_llm;
|
||||
|
||||
using fptr_t = int64_t;
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||
|
||||
class AllReduceMeta {
|
||||
public:
|
||||
AllReduceMeta(
|
||||
int64_t rank_id,
|
||||
int64_t world_size,
|
||||
torch::Tensor& rank_data,
|
||||
const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& tmp_result_buffers,
|
||||
const std::vector<fptr_t>& barrier_in,
|
||||
const std::vector<fptr_t>& barrier_out) {
|
||||
this->rank_id = (int)rank_id;
|
||||
this->world_size = (int)world_size;
|
||||
this->barrier_in = barrier_in;
|
||||
this->barrier_out = barrier_out;
|
||||
this->tmp_result_buffers = tmp_result_buffers;
|
||||
|
||||
this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
data.ptrs[i] = (void*)buffers[i];
|
||||
}
|
||||
auto d_data = this->rank_data_base++;
|
||||
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
this->buffers = d_data;
|
||||
}
|
||||
|
||||
~AllReduceMeta() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
int world_size;
|
||||
int rank_id;
|
||||
std::vector<fptr_t> barrier_in;
|
||||
std::vector<fptr_t> barrier_out;
|
||||
std::vector<fptr_t> tmp_result_buffers;
|
||||
int barrier_flag = 1;
|
||||
RankData* buffers;
|
||||
RankData* rank_data_base;
|
||||
std::vector<void*> graph_unreg_buffers;
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
};
|
||||
|
||||
// Get the number of bits for a given data type.
|
||||
inline int get_bits(at::ScalarType dtype) {
|
||||
switch (dtype) {
|
||||
case at::ScalarType::Float:
|
||||
return 32;
|
||||
case at::ScalarType::Half:
|
||||
case at::ScalarType::BFloat16:
|
||||
return 16;
|
||||
default:
|
||||
assert(false && "Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
// Check if customized all-reduce kernels can be applied.
|
||||
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
|
||||
// The customized all-reduce kernel has the following requirement(s).
|
||||
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
|
||||
}
|
||||
|
||||
fptr_t init_custom_ar(
|
||||
int64_t rank_id,
|
||||
int64_t world_size,
|
||||
torch::Tensor& rank_data,
|
||||
const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& tmp_result_buffers,
|
||||
const std::vector<fptr_t>& barrier_in,
|
||||
const std::vector<fptr_t>& barrier_out) {
|
||||
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
|
||||
return (fptr_t)m;
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
auto num_buffers = m->graph_unreg_buffers.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = m->graph_unreg_buffers[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
|
||||
assert(false && "failed to get pointer attr");
|
||||
}
|
||||
|
||||
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
std::vector<int64_t> bytes(handles.begin(), handles.end());
|
||||
return std::make_pair(bytes, offsets);
|
||||
}
|
||||
|
||||
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
|
||||
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
std::vector<std::string> handle_bytes;
|
||||
handle_bytes.reserve(handles.size());
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
auto num_buffers = m->graph_unreg_buffers.size();
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = m->graph_unreg_buffers[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < m->world_size; j++) {
|
||||
if (j != m->rank_id) {
|
||||
char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_CUDA_SUCCESS(
|
||||
cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
|
||||
m->rank_data_base += num_buffers;
|
||||
m->graph_unreg_buffers.clear();
|
||||
}
|
||||
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
|
||||
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
auto num_elements = inp.numel();
|
||||
auto dtype = inp.scalar_type();
|
||||
AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
|
||||
|
||||
// should be gurantee in python code
|
||||
assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
|
||||
assert(CanApplyCustomAllReduce(num_elements, dtype));
|
||||
|
||||
// Initialize the all-reduce kernel arguments.
|
||||
int world_size = m->world_size;
|
||||
|
||||
AllReduceParams params;
|
||||
params.ranks_per_node = world_size;
|
||||
params.rank = m->rank_id;
|
||||
params.local_rank = m->rank_id;
|
||||
params.local_input_buffer_ptr = inp.data_ptr();
|
||||
params.local_output_buffer_ptr = out.data_ptr();
|
||||
params.elts_total = inp.numel();
|
||||
params.elts_size = inp.element_size();
|
||||
params.barrier_flag = ++(m->barrier_flag);
|
||||
|
||||
cudaStreamCaptureStatus status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
|
||||
params.is_capturing = (status == cudaStreamCaptureStatusActive);
|
||||
if (params.is_capturing) {
|
||||
params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
|
||||
m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
|
||||
} else {
|
||||
params.peer_comm_buffer_ptrs = m->buffers;
|
||||
}
|
||||
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
|
||||
}
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
|
||||
}
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
|
||||
}
|
||||
|
||||
auto data_type = out.scalar_type();
|
||||
trtCustomAllReduce(params, data_type, strategy, stream);
|
||||
}
|
||||
154
sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
Normal file
154
sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
Normal file
@@ -0,0 +1,154 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
template <typename T>
|
||||
__global__ void lightning_attention_decode_kernel(
|
||||
const T* __restrict__ q, // [b, h, 1, d]
|
||||
const T* __restrict__ k, // [b, h, 1, d]
|
||||
const T* __restrict__ v, // [b, h, 1, e]
|
||||
const float* __restrict__ past_kv, // [b, h, d, e]
|
||||
const float* __restrict__ slope, // [h, 1, 1]
|
||||
T* __restrict__ output, // [b, h, 1, e]
|
||||
float* __restrict__ new_kv, // [b, h, d, e]
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int qk_dim,
|
||||
const int v_dim) {
|
||||
extern __shared__ char smem[];
|
||||
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
|
||||
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||
T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||
float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||
T* __restrict__ output_shared =
|
||||
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
||||
|
||||
const int32_t tid = threadIdx.x;
|
||||
const int32_t current_head = blockIdx.x;
|
||||
const int32_t b = current_head / num_heads;
|
||||
const int32_t h = current_head % num_heads;
|
||||
|
||||
if (b >= batch_size) return;
|
||||
|
||||
const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
|
||||
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
||||
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
||||
|
||||
// Load q, k, v into shared memory
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
q_shared[d] = q[qk_offset + d];
|
||||
k_shared[d] = k[qk_offset + d];
|
||||
}
|
||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||
v_shared[e] = v[v_offset + e];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const float ratio = expf(-1.0f * slope[h]);
|
||||
|
||||
// Compute new_kv
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
const T k_val = k_shared[d];
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
const int past_kv_idx = kv_offset + d * v_dim + e;
|
||||
const T v_val = v_shared[e];
|
||||
const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
new_kv_shared[shared_idx] = new_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store new_kv to global memory
|
||||
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
||||
const int d = idx / v_dim;
|
||||
const int e = idx % v_dim;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
const int global_idx = kv_offset + idx;
|
||||
new_kv[global_idx] = new_kv_shared[shared_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute output
|
||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int d = 0; d < qk_dim; ++d) {
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
sum += q_shared[d] * new_kv_shared[shared_idx];
|
||||
}
|
||||
output_shared[e] = static_cast<T>(sum);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store output to global memory
|
||||
if (tid == 0) {
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
output[v_offset + e] = output_shared[e];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void lightning_attention_decode(
|
||||
const torch::Tensor& q,
|
||||
const torch::Tensor& k,
|
||||
const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv,
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv) {
|
||||
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
|
||||
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
|
||||
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
|
||||
TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
|
||||
|
||||
auto batch_size = q.size(0);
|
||||
auto num_heads = q.size(1);
|
||||
auto qk_dim = q.size(3);
|
||||
auto v_dim = v.size(3);
|
||||
|
||||
dim3 block(THREADS_PER_BLOCK);
|
||||
dim3 grid(batch_size * num_heads);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
|
||||
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
|
||||
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
|
||||
q.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(),
|
||||
v.data_ptr<scalar_t>(),
|
||||
past_kv.data_ptr<float>(),
|
||||
slope.data_ptr<float>(),
|
||||
output.data_ptr<scalar_t>(),
|
||||
new_kv.data_ptr<float>(),
|
||||
batch_size,
|
||||
num_heads,
|
||||
qk_dim,
|
||||
v_dim);
|
||||
}));
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename ScaleTileIterator_,
|
||||
typename OutputTileIterator_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol {
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(
|
||||
typename ElementwiseFunctor::Params elementwise_,
|
||||
int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_,
|
||||
int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_),
|
||||
batch_stride_alpha(batch_stride_alpha_),
|
||||
batch_stride_C(batch_stride_C_),
|
||||
batch_stride_D(batch_stride_D_) {}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise),
|
||||
batch_stride_alpha(args.batch_stride_alpha),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D) {}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
bool const with_bias_;
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
|
||||
int column_offset_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(
|
||||
Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col,
|
||||
typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D,
|
||||
bool with_bias,
|
||||
bool per_token_quant,
|
||||
bool per_channel_quant,
|
||||
AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col,
|
||||
typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
|
||||
int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
|
||||
: params_(params),
|
||||
shared_storage_(shared_storage),
|
||||
extent_(problem_size),
|
||||
elementwise_(params.elementwise),
|
||||
with_bias_(with_bias),
|
||||
per_token_quant_(per_token_quant),
|
||||
per_channel_quant_(per_channel_quant),
|
||||
ptr_alpha_row_(ptr_alpha_row),
|
||||
ptr_alpha_col_(ptr_alpha_col),
|
||||
iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
|
||||
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
|
||||
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
|
||||
extent_real_(problem_size_real) {
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (per_channel_quant_) {
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
iterator_C_.load(fragment_C_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_) {
|
||||
int thread_offset_row =
|
||||
iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) {
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_) {
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
} else {
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
|
||||
OutputVector bias = reinterpret_cast<OutputVector*>(&fragment_C_)[column_idx];
|
||||
result = bias_accumulator_(result, bias_converter(bias));
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(
|
||||
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(
|
||||
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputVector::kElements; ++i) {
|
||||
result[i] = accum[i] + bias[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,125 @@
|
||||
// Adapt from
|
||||
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_buildler.hpp
|
||||
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS (BlockScaled Builders)
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,733 @@
|
||||
// clang-format off
|
||||
// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
||||
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm80.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using ElementBlockScale = ElementAccumulator;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
|
||||
static constexpr int NumProducerThreadEvents = 33;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
// Block scaling factors for A and B
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.mma_promotion_interval,
|
||||
args.ptr_scale_A,
|
||||
args.ptr_scale_B
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytesMK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
||||
static constexpr uint32_t TmaTransactionBytesNK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
||||
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
constexpr auto scales_m = Int<ScaleMsPerTile>{};
|
||||
auto tM = get<2>(gA_mkl.shape());
|
||||
auto tN = get<2>(gB_nkl.shape());
|
||||
auto tK = get<3>(gA_mkl.shape());
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
||||
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
|
||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class TensorScaleA, class TensorScaleB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
||||
|
||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
||||
|
||||
Tensor gScaleA = local_tile(
|
||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
||||
Tensor cScaleA = local_tile(
|
||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord));
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
||||
// all scales are valid, since we could have a partial tiles along M)
|
||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Block scaling
|
||||
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
||||
Layout<
|
||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
// Layout of warp group to thread mapping
|
||||
|
||||
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
||||
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
||||
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
||||
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
||||
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
||||
|
||||
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
||||
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
||||
Int<NumThreadsPerWarpGroup>{});
|
||||
|
||||
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Per block scale values for operand A and B
|
||||
|
||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||
|
||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||
ElementBlockScale scale_b;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers.
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
37
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
37
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::
|
||||
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@@ -0,0 +1,356 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/device_kernel.h>
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
|
||||
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
|
||||
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
|
||||
|
||||
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
|
||||
that feature at the moment.
|
||||
*/
|
||||
|
||||
template <typename GemmKernel_>
|
||||
class GemmUniversalBaseCompat {
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
|
||||
using ElementA = typename GemmKernel::ElementA;
|
||||
using LayoutA = typename GemmKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
|
||||
|
||||
using ElementB = typename GemmKernel::ElementB;
|
||||
using LayoutB = typename GemmKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename GemmKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
||||
using Operator = typename GemmKernel::Operator;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
protected:
|
||||
/// Private helper to obtain the grid dimensions with fix-up for split-K
|
||||
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
gemm_k_size = args.problem_size.k();
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalBaseCompat() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args) {
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
|
||||
|
||||
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
// Split-K parallel always requires a temporary workspace
|
||||
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
|
||||
} else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
|
||||
// Serial split-K only requires a temporary workspace if the number of partitions along the
|
||||
// GEMM K dimension is greater than one.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10)) {
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
} else {
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0) {
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
|
||||
<< ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes) {
|
||||
if (!workspace) {
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm) {
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUDA grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result =
|
||||
cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_.update(args, workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,492 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/complex.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/fast_math.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor {
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefAlphaCol ref_alpha_col;
|
||||
TensorRefAlphaRow ref_alpha_row;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmCoord problem_size_,
|
||||
TensorRefA ref_A_,
|
||||
TensorRefB ref_B_,
|
||||
TensorRefAlphaCol ref_alpha_col_,
|
||||
TensorRefAlphaRow ref_alpha_row_,
|
||||
TensorRefC ref_C_,
|
||||
TensorRefC ref_D_,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor_)
|
||||
: mode(GemmUniversalMode::kGemm),
|
||||
problem_size(problem_size_),
|
||||
batch_count(1),
|
||||
ref_A(ref_A_),
|
||||
ref_B(ref_B_),
|
||||
ref_alpha_col(ref_alpha_col_),
|
||||
ref_alpha_row(ref_alpha_row_),
|
||||
ref_C(ref_C_),
|
||||
ref_D(ref_D_),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_D(0),
|
||||
epilogue_visitor(epilogue_visitor_) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void* ptr_A;
|
||||
void* ptr_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_alpha_col(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_alpha_col(nullptr),
|
||||
ptr_alpha_row(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0) {}
|
||||
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
|
||||
: problem_size(args.problem_size),
|
||||
swizzle_log_tile(0),
|
||||
params_A(args.ref_A.layout()),
|
||||
params_B(args.ref_B.layout()),
|
||||
params_alpha_col(args.ref_alpha_col.layout()),
|
||||
params_alpha_row(args.ref_alpha_col.layout()),
|
||||
params_C(args.ref_C.layout()),
|
||||
params_D(args.ref_D.layout()),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(args.problem_size.k()),
|
||||
ptr_A(args.ref_A.data()),
|
||||
ptr_B(args.ref_B.data()),
|
||||
ptr_alpha_col(args.ref_alpha_col.data()),
|
||||
ptr_alpha_row(args.ref_alpha_row.data()),
|
||||
ptr_C(args.ref_C.data()),
|
||||
ptr_D(args.ref_D.data()),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
epilogue_visitor(args.epilogue_visitor) {
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
|
||||
struct {
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmWithEpilogueVisitor() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
|
||||
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (
|
||||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (
|
||||
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (
|
||||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args) {
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage) {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
|
||||
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
} else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
} else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
#endif
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Construct the epilogue visitor
|
||||
//
|
||||
|
||||
bool with_bias = true;
|
||||
if (params.ptr_C == nullptr) {
|
||||
with_bias = false;
|
||||
}
|
||||
|
||||
EpilogueVisitor epilogue_visitor(
|
||||
params.epilogue_visitor,
|
||||
shared_storage.epilogue.visitor,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx,
|
||||
params.params_alpha_col,
|
||||
params.params_C,
|
||||
params.params_D,
|
||||
with_bias,
|
||||
true,
|
||||
true,
|
||||
params.ptr_alpha_row,
|
||||
params.ptr_alpha_col,
|
||||
params.ptr_C,
|
||||
params.ptr_D,
|
||||
threadblock_offset,
|
||||
blockIdx.y * params.problem_size.m());
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) {
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
|
||||
run_kernel_(params, shared_storage);
|
||||
} else {
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage) {
|
||||
run_kernel<ArchTag>(params, shared_storage);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
55
sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
Normal file
55
sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <flashinfer/norm.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
auto device = input.device();
|
||||
CHECK_EQ(residual.device(), device);
|
||||
CHECK_EQ(weight.device(), device);
|
||||
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
|
||||
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
|
||||
CHECK_DIM(1, weight); // weight: (hidden_size)
|
||||
CHECK_EQ(input.size(0), residual.size(0));
|
||||
CHECK_EQ(input.size(1), residual.size(1));
|
||||
CHECK_EQ(input.size(1), weight.size(0));
|
||||
unsigned int batch_size = input.size(0);
|
||||
unsigned int hidden_size = input.size(1);
|
||||
|
||||
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||
// support float16, bfloat16 and float32
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
|
||||
cudaError_t status = norm::FusedAddRMSNorm(
|
||||
static_cast<c_type*>(input.data_ptr()),
|
||||
static_cast<c_type*>(residual.data_ptr()),
|
||||
static_cast<c_type*>(weight.data_ptr()),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
eps,
|
||||
torch_current_stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
return true;
|
||||
});
|
||||
}
|
||||
172
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
Normal file
172
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
Normal file
@@ -0,0 +1,172 @@
|
||||
// References:
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
|
||||
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
|
||||
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
static void check_group_count(
|
||||
const std::vector<torch::Tensor>& inputs,
|
||||
const std::vector<torch::Tensor>& weights,
|
||||
const std::vector<torch::Tensor>& outputs) {
|
||||
TORCH_CHECK(
|
||||
((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
|
||||
"The group count of inputs, weights and outputs should be the same.");
|
||||
}
|
||||
|
||||
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
|
||||
for (const auto& t : tensors) {
|
||||
TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same");
|
||||
TORCH_CHECK(t.is_cuda(), "All tensors should be in Cuda memory");
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
|
||||
std::vector<int> results;
|
||||
for (const auto& t : tensors) {
|
||||
TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
|
||||
results.push_back(t.size(dim));
|
||||
}
|
||||
return std::move(results);
|
||||
}
|
||||
|
||||
static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
|
||||
std::vector<int> results;
|
||||
for (const auto& t : tensors) {
|
||||
results.push_back(t.stride(dim));
|
||||
}
|
||||
return std::move(results);
|
||||
}
|
||||
|
||||
static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
TORCH_CHECK(a[i] == b[i], err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
|
||||
std::vector<void*> ptrs;
|
||||
for (auto& t : tensors) {
|
||||
ptrs.push_back(t.data_ptr());
|
||||
}
|
||||
return std::move(ptrs);
|
||||
}
|
||||
|
||||
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
|
||||
auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
|
||||
torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
|
||||
TORCH_CHECK(
|
||||
cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
|
||||
CUBLAS_STATUS_SUCCESS);
|
||||
return gpu_ptrs;
|
||||
}
|
||||
|
||||
// We want compute input @ weight^T in row major
|
||||
// This is equivalent to computing weight @ input^T in col major
|
||||
// Cublas only accepts matrix in column major, so this arrangement is needed
|
||||
void cublas_grouped_gemm(
|
||||
const std::vector<torch::Tensor>& inputs, // b: (m, k) row major = (k, m) col major
|
||||
const std::vector<torch::Tensor>& weights, // a: (n, k) row major = (n, k)^T col major
|
||||
const std::vector<torch::Tensor>& outputs, // c: (m, n) row major = (n, m) col major
|
||||
const torch::Dtype& out_dtype,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream) {
|
||||
TORCH_CHECK(
|
||||
out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
|
||||
"cublas grouped_gemm can"
|
||||
"only be applied to float16 and bfloat16 dtype");
|
||||
|
||||
int group_count = inputs.size();
|
||||
check_group_count(inputs, weights, outputs);
|
||||
std::vector<int> group_size(group_count, 1);
|
||||
|
||||
// Make sure all tensors are on cuda and use the same dtype
|
||||
check_device_dtype(out_dtype, inputs);
|
||||
check_device_dtype(out_dtype, weights);
|
||||
check_device_dtype(out_dtype, outputs);
|
||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||
|
||||
// Weights should be transposed to (n, k) of column major
|
||||
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
||||
std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);
|
||||
|
||||
// Get dim arrays
|
||||
std::vector<int> m_array = get_dims(weights, 0);
|
||||
std::vector<int> n_array = get_dims(inputs, 0);
|
||||
std::vector<int> k_array = get_dims(inputs, 1);
|
||||
|
||||
// Make sure the dimensions in each group match
|
||||
std::vector<int> m_array1 = get_dims(outputs, 1);
|
||||
std::vector<int> n_array1 = get_dims(outputs, 0);
|
||||
std::vector<int> k_array1 = get_dims(weights, 1);
|
||||
check_equal(m_array, m_array1, "sizes don't match on m dimension");
|
||||
check_equal(n_array, n_array1, "sizes don't match on n dimension");
|
||||
check_equal(k_array, k_array1, "sizes don't match on k dimension");
|
||||
|
||||
// Get leading dimensions
|
||||
std::vector<int> lda_array = get_strides(weights, 0);
|
||||
std::vector<int> ldb_array = get_strides(inputs, 0);
|
||||
std::vector<int> ldc_array = get_strides(outputs, 0);
|
||||
|
||||
// Use default scaling factors
|
||||
std::vector<float> alpha_array(group_count, 1);
|
||||
std::vector<float> beta_array(group_count, 0);
|
||||
|
||||
std::vector<void*> a_array = get_tensor_ptrs(weights);
|
||||
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
||||
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
||||
|
||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
|
||||
// Should allocate tensors for storage of pointers
|
||||
torch::Tensor d_a = create_ptr_pointer(a_array, stream);
|
||||
torch::Tensor d_b = create_ptr_pointer(b_array, stream);
|
||||
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
auto status = cublasGemmGroupedBatchedEx(
|
||||
handle,
|
||||
transa_array.data(),
|
||||
transb_array.data(),
|
||||
m_array.data(),
|
||||
n_array.data(),
|
||||
k_array.data(),
|
||||
alpha_array.data(),
|
||||
(void**)d_a.data_ptr(),
|
||||
cuda_data_type,
|
||||
lda_array.data(),
|
||||
(void**)d_b.data_ptr(),
|
||||
cuda_data_type,
|
||||
ldb_array.data(),
|
||||
beta_array.data(),
|
||||
(void**)d_c.data_ptr(),
|
||||
cuda_data_type,
|
||||
ldc_array.data(),
|
||||
group_count,
|
||||
group_size.data(),
|
||||
CUBLAS_COMPUTE_32F);
|
||||
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
|
||||
TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
|
||||
return;
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
|
||||
}
|
||||
226
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
226
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
@@ -0,0 +1,226 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||
#include <cutlass/gemm/thread/mma.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
|
||||
void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b) {
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
int n = b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementB*>(b.data_ptr());
|
||||
auto o_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideC stride_c;
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
|
||||
|
||||
typename Gemm::Arguments args = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
};
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
|
||||
|
||||
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm90_fp8_blockwise_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b) {
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
|
||||
}
|
||||
|
||||
torch::Tensor fp8_blockwise_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype) {
|
||||
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
|
||||
TORCH_CHECK(
|
||||
(mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
|
||||
TORCH_CHECK(
|
||||
(mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
|
||||
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
|
||||
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
|
||||
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
|
||||
|
||||
auto is_contiguous_vector = [](const torch::Tensor& t) {
|
||||
auto t_sizes = t.sizes();
|
||||
return t.is_contiguous() &&
|
||||
(t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
|
||||
};
|
||||
|
||||
TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
|
||||
TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
|
||||
TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
|
||||
TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
|
||||
TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
|
||||
TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
|
||||
|
||||
torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
|
||||
TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");
|
||||
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
if (sm_version >= 90) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
|
||||
} else {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
|
||||
}
|
||||
859
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
859
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
@@ -0,0 +1,859 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||
#include <cutlass/gemm/thread/mma.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
||||
template <
|
||||
typename ElementType,
|
||||
typename OutElementType,
|
||||
typename AccumElementType,
|
||||
typename CtaShape,
|
||||
typename WarpShape,
|
||||
int Stages,
|
||||
bool WithBias,
|
||||
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
|
||||
template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
|
||||
typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
|
||||
struct DeviceGemmFp8RowwiseSm89 {
|
||||
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
|
||||
|
||||
using ElementA = ElementType;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementType;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementC = OutElementType;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ElementOutput = OutElementType;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
using ElementAccumulator = AccumElementType;
|
||||
using ElementComputeEpilogue = float;
|
||||
using ArchTag = cutlass::arch::Sm89;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
// Number of epilogue stages in EVT
|
||||
static constexpr int EVTEpilogueStages = 1;
|
||||
|
||||
using OutputTileThreadMap = cutlass::epilogue::threadblock::
|
||||
OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;
|
||||
|
||||
// Definition of EVT
|
||||
using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using bScaleSrc = cutlass::epilogue::threadblock::
|
||||
VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
|
||||
using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
|
||||
|
||||
using ComputeAScale = cutlass::epilogue::threadblock::
|
||||
VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using aScaleSrc = cutlass::epilogue::threadblock::
|
||||
VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
|
||||
using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
|
||||
|
||||
// With bias
|
||||
using biasSrc =
|
||||
cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
|
||||
using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add,
|
||||
ElementC,
|
||||
ElementComputeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using EpilogueAScaleWithBias =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
|
||||
|
||||
using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap,
|
||||
ElementC,
|
||||
cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, _1, _0>>;
|
||||
using EpilogueStore = typename cutlass::platform::conditional<
|
||||
WithBias,
|
||||
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
|
||||
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
|
||||
|
||||
using EpilogueOp = EpilogueStore;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
AlignmentB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementAccumulator,
|
||||
ElementComputeEpilogue,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
CtaShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
FP8MathOperator,
|
||||
EVTEpilogueStages>::GemmKernel;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
template <typename Gemm, bool WithBias>
|
||||
typename Gemm::Arguments prepare_sm89_fp8_args(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ElementT = typename Gemm::ElementA;
|
||||
using ElementOutput = typename Gemm::ElementD;
|
||||
using ElementComputeEpilogue = float;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
|
||||
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
|
||||
ElementOutput const* ptr_bias = nullptr;
|
||||
if constexpr (WithBias) {
|
||||
TORCH_CHECK(bias.has_value())
|
||||
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
|
||||
}
|
||||
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
|
||||
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
|
||||
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
|
||||
|
||||
typename Gemm::Arguments args(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // Mode
|
||||
{m, n, k}, // Problem size
|
||||
1, // Split-k factor
|
||||
{}, // Epilogue args
|
||||
ptr_a, // a pointer
|
||||
ptr_b, // b pointer
|
||||
nullptr, // c pointer (unused)
|
||||
nullptr, // d pointer (unused)
|
||||
m * k, // batch stride a (unused)
|
||||
n * k, // batch stride b (unused)
|
||||
m * n, // batch stride c (unused)
|
||||
m * n, // batch stride d (unused)
|
||||
lda, // stride a
|
||||
ldb, // stride b
|
||||
ldc, // stride c (unused)
|
||||
ldc); // stride d (unused)
|
||||
if constexpr (WithBias) {
|
||||
args.epilogue = {
|
||||
{
|
||||
{
|
||||
{}, // Accumulator
|
||||
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
|
||||
{} // Multiplies
|
||||
},
|
||||
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
|
||||
{ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
|
||||
{} // Multiplies
|
||||
},
|
||||
{ptr_d, {n, _1{}, _0{}}}};
|
||||
} else {
|
||||
args.epilogue = {
|
||||
{
|
||||
{
|
||||
{}, // Accumulator
|
||||
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
|
||||
{} // Multiplies
|
||||
},
|
||||
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
|
||||
{} // Multiplies
|
||||
},
|
||||
{ptr_d, {n, _1{}, _0{}}}};
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
template <typename Gemm, bool WithBias>
|
||||
void launch_sm89_fp8_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
|
||||
|
||||
auto status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess)
|
||||
}
|
||||
|
||||
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
|
||||
void sm89_fp8_dispatch_bias(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ElementInput = cutlass::float_e4m3_t;
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
if (bias) {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CtaShape,
|
||||
WarpShape,
|
||||
Stages,
|
||||
true>::Gemm;
|
||||
return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CtaShape,
|
||||
WarpShape,
|
||||
Stages,
|
||||
false>::Gemm;
|
||||
return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm89_fp8_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const n = out.size(1);
|
||||
|
||||
if (m == 1) {
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 16) {
|
||||
// M in (1, 16]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 64) {
|
||||
// M in (16, 64]
|
||||
if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 128) {
|
||||
// M in (64, 128]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 256) {
|
||||
// M in (128, 256]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 64, 128>,
|
||||
cutlass::gemm::GemmShape<64, 32, 128>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 512) {
|
||||
// M in (256, 512)
|
||||
if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
2>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
// M in (512, inf)
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
3>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
2>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
template <
|
||||
typename ElementType,
|
||||
typename OutElementType,
|
||||
typename AccumElementType,
|
||||
typename CTAShape,
|
||||
typename ClusterShape,
|
||||
typename MainloopScheduleType,
|
||||
typename EpilogueScheduleType,
|
||||
typename TileSchedulerType = void,
|
||||
bool WithBias = false>
|
||||
struct DeviceGemmFp8RowwiseSm90 {
|
||||
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = ElementType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = ElementType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = void; // Element type for C matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<OutElementType>::value; // Memory access granularity/alignment of C matrices in
|
||||
// units of elements (up to 16 bytes)
|
||||
|
||||
// Output matrix configuration
|
||||
using ElementOutput = OutElementType; // Element type for output matrix operands
|
||||
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
|
||||
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
// // Auxiliary matrix configuration and other fusion types
|
||||
// using ElementBias = float;
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = AccumElementType; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for compute
|
||||
using ElementComputeEpilogue = float;
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = CTAShape; // Threadblock-level tile size
|
||||
|
||||
static constexpr bool PONG = false;
|
||||
static constexpr bool FAST_ACCUM = true;
|
||||
static constexpr bool USE_BIAS = false;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized
|
||||
// based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
|
||||
// setting in the Collective Builder
|
||||
// Implement rowwise scaling epilogue.
|
||||
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
|
||||
|
||||
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementOutput,
|
||||
ElementOutput,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
ElementComputeEpilogue, // First stage output type.
|
||||
ElementComputeEpilogue, // First stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
ElementOutput,
|
||||
ElementComputeEpilogue, // Second stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
|
||||
|
||||
// With bias
|
||||
using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add,
|
||||
ElementOutput,
|
||||
ElementComputeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
|
||||
|
||||
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementComputeEpilogue,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
AlignmentOutput,
|
||||
cutlass::epilogue::TmaWarpSpecialized,
|
||||
EpilogueEVT>::CollectiveOp;
|
||||
|
||||
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
|
||||
using SlowAccum = DefaultSchedule;
|
||||
using FastAccum = FastPongSchedule; // Default apply Pingpong
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduleType>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
TileSchedulerType>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
template <typename Gemm, bool WithBias>
|
||||
typename Gemm::Arguments prepare_sm90_fp8_args(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ElementT = typename Gemm::ElementA;
|
||||
using ElementOutput = typename Gemm::ElementD;
|
||||
using ElementComputeEpilogue = float;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
|
||||
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
|
||||
ElementOutput const* ptr_bias = nullptr;
|
||||
if constexpr (WithBias) {
|
||||
TORCH_CHECK(bias.has_value())
|
||||
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
|
||||
}
|
||||
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
|
||||
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
|
||||
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
|
||||
StrideC stride_c;
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
|
||||
typename Gemm::Arguments args = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{ptr_a, stride_a, ptr_b, stride_b},
|
||||
{{}, // epilogue.thread
|
||||
nullptr,
|
||||
stride_c,
|
||||
ptr_d,
|
||||
stride_d}};
|
||||
if constexpr (WithBias) {
|
||||
args.epilogue.thread = {
|
||||
{ptr_scales_a},
|
||||
{
|
||||
{ptr_scales_b},
|
||||
{}, // Accumulator
|
||||
{} // Multiplies
|
||||
},
|
||||
{ptr_bias},
|
||||
{}, // Multiplies
|
||||
};
|
||||
} else {
|
||||
args.epilogue.thread = {
|
||||
{ptr_scales_a},
|
||||
{
|
||||
{ptr_scales_b},
|
||||
{}, // Accumulator
|
||||
{} // Multiplies
|
||||
},
|
||||
{}, // Multiplies
|
||||
};
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
template <typename Gemm, bool WithBias>
|
||||
void launch_sm90_fp8_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
|
||||
|
||||
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess)
|
||||
}
|
||||
|
||||
template <
|
||||
typename OutType,
|
||||
typename CTAShape,
|
||||
typename ClusterShape,
|
||||
typename MainloopScheduleType,
|
||||
typename TileSchedulerType>
|
||||
void sm90_fp8_dispatch_bias(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias,
|
||||
bool fast_accum = true,
|
||||
bool use_persistent = false) {
|
||||
using ElementInput = cutlass::float_e4m3_t;
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
if (bias) {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm90<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>::Gemm;
|
||||
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm90<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>::Gemm;
|
||||
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm90_fp8_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
uint32_t const m = a.size(0);
|
||||
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
using BasicTileScheduler = void;
|
||||
if (m <= 1) {
|
||||
return sm90_fp8_dispatch_bias<
|
||||
OutType,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_1, _8, _1>,
|
||||
FastBasicScheduler,
|
||||
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
if (m <= 64) {
|
||||
// m in [1, 64]
|
||||
return sm90_fp8_dispatch_bias<
|
||||
OutType,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_1, _4, _1>,
|
||||
FastPingpongScheduler,
|
||||
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (m <= 256) {
|
||||
// m in (64, 256]
|
||||
return sm90_fp8_dispatch_bias<
|
||||
OutType,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_1, _1, _1>,
|
||||
FastPingpongScheduler,
|
||||
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (m <= 1024) {
|
||||
// m in (256, 1024]
|
||||
return sm90_fp8_dispatch_bias<
|
||||
OutType,
|
||||
Shape<_128, _128, _128>,
|
||||
Shape<_1, _1, _1>,
|
||||
FastPingpongScheduler,
|
||||
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
// m in (1024, inf)
|
||||
return sm90_fp8_dispatch_bias<
|
||||
OutType,
|
||||
Shape<_128, _128, _128>,
|
||||
Shape<_2, _1, _1>,
|
||||
FastPingpongScheduler,
|
||||
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
torch::Tensor fp8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
|
||||
TORCH_CHECK(
|
||||
(mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
|
||||
TORCH_CHECK(
|
||||
(mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
|
||||
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
|
||||
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
|
||||
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
|
||||
|
||||
TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
|
||||
TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
|
||||
TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
|
||||
TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
|
||||
TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
|
||||
TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
|
||||
}
|
||||
|
||||
torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
|
||||
TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");
|
||||
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
if (sm_version >= 90) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
||||
if (sm_version == 89) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
|
||||
}
|
||||
599
sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
Normal file
599
sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
Normal file
@@ -0,0 +1,599 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
||||
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
|
||||
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
typename ElementOutput,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
int NumStages>
|
||||
void cutlass_int8_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
using ElementInputA = int8_t;
|
||||
using ElementInputB = int8_t;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
|
||||
|
||||
using DefaultGemmConf = cutlass::gemm::device::
|
||||
DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
|
||||
using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;
|
||||
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
|
||||
ElementInputA,
|
||||
cutlass::layout::RowMajor,
|
||||
DefaultGemmConf::kAlignmentA,
|
||||
ElementInputB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
DefaultGemmConf::kAlignmentB,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
NumStages,
|
||||
true,
|
||||
typename DefaultGemmConf::Operator>::GemmKernel;
|
||||
|
||||
using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
|
||||
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
|
||||
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
|
||||
GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
|
||||
GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
|
||||
cutlass::sizeof_bits<ElementOutput>::value>,
|
||||
ElementCompute>;
|
||||
|
||||
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
|
||||
ThreadblockShape,
|
||||
GemmKernel_::kThreadCount,
|
||||
AlphaColTileIterator,
|
||||
typename GemmKernel_::Epilogue::OutputTileIterator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
EpilogueOutputOp>;
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
int m = mat_a.size(0);
|
||||
int k = mat_a.size(1);
|
||||
int n = mat_b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
|
||||
auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());
|
||||
|
||||
int64_t lda = mat_a.stride(0);
|
||||
int64_t ldb = mat_b.stride(1);
|
||||
int64_t ldd = out.stride(0);
|
||||
|
||||
ElementOutput* bias_ptr = nullptr;
|
||||
int64_t ldc = 0;
|
||||
if (bias) {
|
||||
bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
|
||||
}
|
||||
|
||||
typename EpilogueOutputOp::Params linearScalingParams;
|
||||
typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};
|
||||
|
||||
typename Gemm::Arguments args{
|
||||
{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
|
||||
|
||||
auto workspace = torch::empty(
|
||||
gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(
|
||||
can_implement == cutlass::Status::kSuccess,
|
||||
"gemm cannot implement, error: ",
|
||||
cutlassGetStatusString(can_implement));
|
||||
|
||||
auto status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
|
||||
}
|
||||
|
||||
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
|
||||
void sm75_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
int m = mat_a.size(0);
|
||||
if (m <= 32) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (m <= 64) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (m <= 256) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
|
||||
void sm80_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
int m = mat_a.size(0);
|
||||
int n = mat_b.size(1);
|
||||
if (m <= 16) {
|
||||
if (n <= 4096) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 32) {
|
||||
if (n <= 4096) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 64) {
|
||||
if (n <= 4096) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 128 && n < 8192) {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 128>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename ElementOutput,
|
||||
typename TileShape,
|
||||
typename ClusterShape,
|
||||
typename MainloopScheduleType,
|
||||
bool WithBias>
|
||||
void cutlass_int8_scaled_mm_sm90(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
using ElementInputA = int8_t;
|
||||
using ElementInputB = int8_t;
|
||||
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileSchedulerType = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using XScale = cutlass::epilogue::fusion::
|
||||
Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using WScale = cutlass::epilogue::fusion::
|
||||
Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::
|
||||
Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
// Scale
|
||||
using Compute0 = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
|
||||
|
||||
// With bias
|
||||
using ComputeWithBias = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
|
||||
|
||||
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
AlignmentC,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
AlignmentOutput,
|
||||
EpilogueScheduleType,
|
||||
EpilogueEVT>::CollectiveOp;
|
||||
|
||||
using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementInputA,
|
||||
cutlass::layout::RowMajor,
|
||||
AlignmentA,
|
||||
ElementInputB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
Stages,
|
||||
MainloopScheduleType>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
TileSchedulerType>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
int m = mat_a.size(0);
|
||||
int k = mat_a.size(1);
|
||||
int n = mat_b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
|
||||
auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
|
||||
StrideC stride_c;
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
|
||||
|
||||
typename Gemm::Arguments args = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{a_ptr, stride_a, b_ptr, stride_b},
|
||||
{{}, // epilogue.thread
|
||||
nullptr,
|
||||
stride_c,
|
||||
o_ptr,
|
||||
stride_d}};
|
||||
|
||||
if constexpr (WithBias) {
|
||||
ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
|
||||
args.epilogue.thread = {
|
||||
{a_s_ptr},
|
||||
{{b_s_ptr}, {}, {}},
|
||||
{bias_ptr},
|
||||
{},
|
||||
};
|
||||
} else {
|
||||
args.epilogue.thread = {
|
||||
{a_s_ptr},
|
||||
{{b_s_ptr}, {}, {}},
|
||||
{},
|
||||
};
|
||||
}
|
||||
|
||||
auto workspace = torch::empty(
|
||||
gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(
|
||||
can_implement == cutlass::Status::kSuccess,
|
||||
"gemm cannot implement, error: ",
|
||||
cutlassGetStatusString(can_implement));
|
||||
|
||||
auto status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
|
||||
}
|
||||
|
||||
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
|
||||
void sm90_dispatch_bias(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
if (bias) {
|
||||
cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementOutput>
|
||||
void sm90_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
int m = mat_a.size(0);
|
||||
int n = mat_b.size(1);
|
||||
if (m <= 32) {
|
||||
if (n < 8192) {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_1, _8, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _128, _128>,
|
||||
Shape<_1, _8, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 64) {
|
||||
if (n < 8192) {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_1, _4, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _64, _256>,
|
||||
Shape<_1, _1, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 128) {
|
||||
if (n <= 4096) {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _64, _128>,
|
||||
Shape<_2, _1, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_64, _128, _128>,
|
||||
Shape<_2, _1, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
return sm90_dispatch_bias<
|
||||
ElementOutput,
|
||||
Shape<_128, _128, _128>,
|
||||
Shape<_2, _1, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor int8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
|
||||
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
|
||||
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
|
||||
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
|
||||
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
|
||||
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
|
||||
TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
|
||||
TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
|
||||
TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment"); // out.stride(0)
|
||||
TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
|
||||
TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
|
||||
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
|
||||
|
||||
TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
|
||||
TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
|
||||
TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
|
||||
TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
|
||||
TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
|
||||
TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
|
||||
}
|
||||
|
||||
torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
|
||||
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
if (sm_version >= 75 && sm_version < 80) {
|
||||
TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
|
||||
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else if (sm_version >= 80 && sm_version < 90) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (sm_version == 90) {
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
// cutlass 3.x
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
#else
|
||||
// fallback to cutlass 2.x
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
|
||||
out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
125
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
125
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
|
||||
float max_value = 0.0f;
|
||||
unsigned int tid = threadIdx.x;
|
||||
unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int grid_size = blockDim.x * gridDim.x;
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
|
||||
const int32_t num_vec_elems = num_elements / vec_size;
|
||||
|
||||
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(input + i * vec_size);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t remaining_start = num_vec_elems * vec_size;
|
||||
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
|
||||
float val = static_cast<float>(input[idx]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
|
||||
max_value = blockReduceMax(max_value);
|
||||
|
||||
if (tid == 0) {
|
||||
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void per_tensor_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
FP8_TYPE* __restrict__ output,
|
||||
const float* __restrict__ scale,
|
||||
const int64_t num_elements) {
|
||||
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int grid_size = blockDim.x * gridDim.x;
|
||||
const float scale_val = 1.0f / (*scale);
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
|
||||
const int32_t num_vec_elems = num_elements / vec_size;
|
||||
|
||||
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(input + i * vec_size);
|
||||
|
||||
FP8_TYPE output_arr[vec_size];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
#ifndef USE_ROCM
|
||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
output_arr[j] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
output[i * vec_size + j] = output_arr[j];
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t remaining_start = num_vec_elems * vec_size;
|
||||
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
|
||||
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
|
||||
#ifndef USE_ROCM
|
||||
output[idx] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
output[idx] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
CHECK_INPUT(output_s);
|
||||
|
||||
const int block_size = 256;
|
||||
const int num_elements = input.numel();
|
||||
const int num_blocks = min((num_elements + block_size - 1) / block_size, 1024);
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(block_size);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
if (is_static == false) {
|
||||
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
|
||||
}
|
||||
|
||||
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()),
|
||||
static_cast<FP8_TYPE*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
num_elements);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
105
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
Normal file
105
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) {
|
||||
smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
|
||||
if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
|
||||
if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
|
||||
if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
|
||||
return smem[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void per_token_group_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
void* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const int group_size,
|
||||
const int num_groups,
|
||||
const float eps,
|
||||
const float fp8_min,
|
||||
const float fp8_max) {
|
||||
const int groups_per_block = 16;
|
||||
const int block_group_id = blockIdx.x * groups_per_block;
|
||||
const int tid = threadIdx.x;
|
||||
const int local_group_id = tid / 16;
|
||||
const int local_tid = tid % 16;
|
||||
|
||||
__shared__ float s_absmax[16][17];
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
if (block_group_id + local_group_id < num_groups) {
|
||||
const T* group_input = input + (block_group_id + local_group_id) * group_size;
|
||||
FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
|
||||
float* scale_output = output_s + block_group_id + local_group_id;
|
||||
|
||||
for (int i = local_tid; i < group_size; i += 16) {
|
||||
float val = static_cast<float>(group_input[i]);
|
||||
float abs_val = fabsf(val);
|
||||
local_absmax = fmaxf(local_absmax, abs_val);
|
||||
}
|
||||
|
||||
s_absmax[local_group_id][local_tid] = local_absmax;
|
||||
__syncthreads();
|
||||
|
||||
if (local_tid < 8) {
|
||||
GroupReduceMax(&s_absmax[local_group_id][0], local_tid);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float group_absmax = s_absmax[local_group_id][0];
|
||||
const float y_s = group_absmax / fp8_max;
|
||||
|
||||
if (local_tid == 0) {
|
||||
*scale_output = y_s;
|
||||
}
|
||||
|
||||
for (int i = local_tid; i < group_size; i += 16) {
|
||||
float val = static_cast<float>(group_input[i]);
|
||||
float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
|
||||
group_output[i] = FP8_TYPE(q_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_per_token_group_quant_fp8(
|
||||
torch::Tensor input,
|
||||
torch::Tensor output_q,
|
||||
torch::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double fp8_min,
|
||||
double fp8_max) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
CHECK_INPUT(output_s);
|
||||
|
||||
const int num_groups = input.numel() / group_size;
|
||||
|
||||
CHECK_EQ(input.numel() % group_size, 0);
|
||||
|
||||
dim3 grid((num_groups + 15) / 16);
|
||||
dim3 block(256);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()),
|
||||
output_q.data_ptr(),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
group_size,
|
||||
num_groups,
|
||||
(float)eps,
|
||||
(float)fp8_min,
|
||||
(float)fp8_max);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
111
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
111
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void per_token_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
FP8_TYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const int64_t hidden_dim,
|
||||
const int64_t num_tokens) {
|
||||
const int token_idx = blockIdx.x;
|
||||
|
||||
if (token_idx >= num_tokens) return;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int block_dim = blockDim.x;
|
||||
|
||||
const T* token_input = input + token_idx * hidden_dim;
|
||||
FP8_TYPE* token_output = output_q + token_idx * hidden_dim;
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
for (int i = tid; i < hidden_dim; i += block_dim) {
|
||||
float val = static_cast<float>(token_input[i]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
|
||||
max_value = blockReduceMax(max_value);
|
||||
|
||||
__shared__ float block_max;
|
||||
if (tid == 0) {
|
||||
block_max = max_value / FP8_E4M3_MAX;
|
||||
output_s[token_idx] = block_max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_val = 1.0f / block_max;
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
|
||||
const int32_t num_vec_elems = hidden_dim / vec_size;
|
||||
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(token_input + i * vec_size);
|
||||
|
||||
FP8_TYPE output_arr[vec_size];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
#ifndef USE_ROCM
|
||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
output_arr[j] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
token_output[i * vec_size + j] = output_arr[j];
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t remaining_start = num_vec_elems * vec_size;
|
||||
for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
|
||||
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
|
||||
#ifndef USE_ROCM
|
||||
token_output[idx] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
token_output[idx] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
CHECK_INPUT(output_s);
|
||||
|
||||
const auto input_sizes = input.sizes();
|
||||
const int64_t num_tokens = input_sizes[0];
|
||||
const int64_t hidden_dim = input_sizes[1];
|
||||
|
||||
const int block_size = 128;
|
||||
const int num_blocks = num_tokens;
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(block_size);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()),
|
||||
static_cast<FP8_TYPE*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
139
sgl-kernel/csrc/moe/moe_align_kernel.cu
Normal file
139
sgl-kernel/csrc/moe/moe_align_kernel.cu
Normal file
@@ -0,0 +1,139 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ cumsum_buffer,
|
||||
size_t numel) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t experts_per_warp,
|
||||
int32_t block_size,
|
||||
size_t numel,
|
||||
int32_t* __restrict__ cumsum) {
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
for (int i = 0; i < experts_per_warp; ++i) {
|
||||
if (my_expert_start + i < num_experts) {
|
||||
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int expert_id = topk_ids[i];
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
int expert_count = 0;
|
||||
int warp_idx = (i - 1) / experts_per_warp;
|
||||
int expert_offset = (i - 1) % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
|
||||
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int64_t num_experts,
|
||||
int64_t block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor token_cnts_buffer,
|
||||
torch::Tensor cumsum_buffer) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
TORCH_CHECK(num_experts % WARP_SIZE == 0);
|
||||
int experts_per_warp = num_experts / WARP_SIZE;
|
||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
|
||||
align_kernel<<<1, 1024, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
|
||||
const int block_threads = 256;
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(),
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
||||
251
sgl-kernel/csrc/speculative/eagle_utils.cu
Normal file
251
sgl-kernel/csrc/speculative/eagle_utils.cu
Normal file
@@ -0,0 +1,251 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
|
||||
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
|
||||
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
|
||||
__global__ void build_tree_efficient(
|
||||
int64_t* parent_list,
|
||||
int64_t* selected_index,
|
||||
int32_t* verified_seq_len,
|
||||
bool* tree_mask,
|
||||
int64_t* positions,
|
||||
int64_t* retrive_index,
|
||||
int64_t* retrive_next_token,
|
||||
int64_t* retrive_next_sibling,
|
||||
int topk,
|
||||
int depth,
|
||||
int draft_token_num) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid >= draft_token_num) {
|
||||
return;
|
||||
}
|
||||
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
||||
for (int i = 0; i < bid; i++) {
|
||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
||||
tree_mask[token_tree_idx + i] = false;
|
||||
}
|
||||
|
||||
int position = 0;
|
||||
if (tid == 0) {
|
||||
positions[bid * draft_token_num] = seq_len;
|
||||
|
||||
int retrive_index_offset = bid * draft_token_num;
|
||||
for (int i = draft_token_num - 1; i > 0; --i) {
|
||||
int current_token_idx = retrive_index_offset + i;
|
||||
retrive_index[bid * draft_token_num + i] = current_token_idx;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
|
||||
int parent_position = 0;
|
||||
if (parent_tb_idx > 0) {
|
||||
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (; parent_position < draft_token_num; ++parent_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
|
||||
++parent_position;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent_position == draft_token_num) {
|
||||
printf(
|
||||
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
|
||||
"will be dropped.");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
} else {
|
||||
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
|
||||
}
|
||||
}
|
||||
retrive_index[bid * draft_token_num] = bid * draft_token_num;
|
||||
} else {
|
||||
int cur_position = tid - 1;
|
||||
while (true) {
|
||||
position += 1;
|
||||
tree_mask[token_tree_idx + cur_position] = true;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||
if (parent_tb_idx == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
at::Tensor verified_seq_len,
|
||||
at::Tensor tree_mask,
|
||||
at::Tensor positions,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num) {
|
||||
// TODO (ying) check shape
|
||||
// TODO (ying) check type
|
||||
int bs = parent_list.size(0);
|
||||
dim3 grid(bs);
|
||||
dim3 block(draft_token_num);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int32_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num));
|
||||
}
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
|
||||
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
|
||||
// draft_token, depth + 2]
|
||||
__global__ void build_tree(
|
||||
int64_t* parent_list,
|
||||
int64_t* selected_index,
|
||||
int32_t* verified_seq_len,
|
||||
bool* tree_mask,
|
||||
int64_t* positions,
|
||||
int64_t* retrive_index,
|
||||
int topk,
|
||||
int depth,
|
||||
int draft_token_num) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid >= draft_token_num) {
|
||||
return;
|
||||
}
|
||||
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
||||
for (int i = 0; i < bid; i++) {
|
||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
||||
tree_mask[token_tree_idx + i] = false;
|
||||
}
|
||||
|
||||
int position = 0;
|
||||
if (tid == 0) {
|
||||
positions[bid * draft_token_num] = seq_len;
|
||||
retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
|
||||
return;
|
||||
}
|
||||
|
||||
int depends_order[10];
|
||||
|
||||
int cur_position = tid - 1;
|
||||
while (true) {
|
||||
depends_order[position] = cur_position + 1;
|
||||
position += 1;
|
||||
tree_mask[token_tree_idx + cur_position] = true;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||
if (parent_tb_idx == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (cur_position == draft_token_num) {
|
||||
printf(
|
||||
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
|
||||
"will be dropped.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||
|
||||
int is_leaf = 0;
|
||||
for (int i = 1; i < draft_token_num; i++) {
|
||||
if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
|
||||
is_leaf++;
|
||||
}
|
||||
}
|
||||
if (is_leaf == 1) {
|
||||
for (int i = 0; i < position; i++) {
|
||||
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
|
||||
depends_order[i] + bid * draft_token_num;
|
||||
}
|
||||
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
|
||||
}
|
||||
}
|
||||
|
||||
void build_tree_kernel(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
at::Tensor verified_seq_len,
|
||||
at::Tensor tree_mask,
|
||||
at::Tensor positions,
|
||||
at::Tensor retrive_index,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num) {
|
||||
// TODO (ying) check shape
|
||||
// TODO (ying) check type
|
||||
int bs = parent_list.size(0);
|
||||
dim3 grid(bs);
|
||||
dim3 block(draft_token_num);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
build_tree<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int32_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num));
|
||||
}
|
||||
138
sgl-kernel/csrc/speculative/speculative_sampling.cu
Normal file
138
sgl-kernel/csrc/speculative/speculative_sampling.cu
Normal file
@@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
* Copyright (c) 2025 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
#include "speculative_sampling.cuh"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// predicts: [tot_num_draft_tokens]
|
||||
// accept_index: [bs, num_spec_step]
|
||||
// accept_token_num: [bs]
|
||||
// candidates: [bs, num_draft_tokens]
|
||||
// retrive_index: [bs, num_draft_tokens]
|
||||
// retrive_next_token: [bs, num_draft_tokens]
|
||||
// retrive_next_sibling: [bs, num_draft_tokens]
|
||||
// uniform_samples: [bs, num_draft_tokens]
|
||||
// target_probs: [bs, num_draft_tokens, vocab_size]
|
||||
void tree_speculative_sampling_target_only(
|
||||
at::Tensor predicts,
|
||||
at::Tensor accept_index,
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream = 0) {
|
||||
CHECK_INPUT(candidates);
|
||||
CHECK_INPUT(retrive_index);
|
||||
CHECK_INPUT(retrive_next_token);
|
||||
CHECK_INPUT(retrive_next_sibling);
|
||||
CHECK_INPUT(uniform_samples);
|
||||
CHECK_INPUT(target_probs);
|
||||
auto device = target_probs.device();
|
||||
CHECK_EQ(candidates.device(), device);
|
||||
CHECK_EQ(retrive_index.device(), device);
|
||||
CHECK_EQ(retrive_next_token.device(), device);
|
||||
CHECK_EQ(retrive_next_sibling.device(), device);
|
||||
CHECK_EQ(uniform_samples.device(), device);
|
||||
CHECK_EQ(target_probs.device(), device);
|
||||
CHECK_DIM(1, predicts);
|
||||
CHECK_DIM(2, accept_index);
|
||||
CHECK_DIM(1, accept_token_num);
|
||||
CHECK_DIM(2, candidates);
|
||||
CHECK_DIM(2, retrive_index);
|
||||
CHECK_DIM(2, retrive_next_token);
|
||||
CHECK_DIM(2, retrive_next_sibling);
|
||||
CHECK_DIM(2, uniform_samples);
|
||||
CHECK_DIM(3, target_probs);
|
||||
CHECK_DIM(3, draft_probs);
|
||||
unsigned int batch_size = uniform_samples.size(0);
|
||||
unsigned int num_spec_step = accept_index.size(1);
|
||||
unsigned int num_draft_tokens = candidates.size(1);
|
||||
unsigned int vocab_size = target_probs.size(2);
|
||||
CHECK_EQ(batch_size, candidates.size(0));
|
||||
CHECK_EQ(batch_size, retrive_index.size(0));
|
||||
CHECK_EQ(batch_size, retrive_next_token.size(0));
|
||||
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
|
||||
CHECK_EQ(batch_size, target_probs.size(0));
|
||||
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
|
||||
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
|
||||
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
|
||||
CHECK_EQ(num_draft_tokens, uniform_samples.size(1));
|
||||
CHECK_EQ(num_draft_tokens, target_probs.size(1));
|
||||
CHECK_EQ(vocab_size, target_probs.size(2));
|
||||
CHECK_EQ(batch_size, accept_index.size(0));
|
||||
CHECK_EQ(batch_size, accept_token_num.size(0));
|
||||
if (predicts.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
|
||||
}
|
||||
if (accept_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
|
||||
}
|
||||
if (accept_token_num.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||
}
|
||||
if (candidates.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
||||
}
|
||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
||||
}
|
||||
if (uniform_samples.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
|
||||
}
|
||||
if (target_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
if (draft_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
||||
static_cast<int*>(predicts.data_ptr()),
|
||||
static_cast<int*>(accept_index.data_ptr()),
|
||||
static_cast<int*>(accept_token_num.data_ptr()),
|
||||
static_cast<int*>(candidates.data_ptr()),
|
||||
static_cast<int*>(retrive_index.data_ptr()),
|
||||
static_cast<int*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
||||
static_cast<float*>(uniform_samples.data_ptr()),
|
||||
static_cast<float*>(target_probs.data_ptr()),
|
||||
static_cast<float*>(draft_probs.data_ptr()),
|
||||
batch_size,
|
||||
num_spec_step,
|
||||
num_draft_tokens,
|
||||
vocab_size,
|
||||
deterministic,
|
||||
stream);
|
||||
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess,
|
||||
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
}
|
||||
215
sgl-kernel/csrc/speculative/speculative_sampling.cuh
Normal file
215
sgl-kernel/csrc/speculative/speculative_sampling.cuh
Normal file
@@ -0,0 +1,215 @@
|
||||
/*
|
||||
* Copyright (c) 2025 by SGLang team.
|
||||
* Copyright (c) 2024-2025 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef SPECULATIVE_SAMPLING_CUH_
|
||||
#define SPECULATIVE_SAMPLING_CUH_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <flashinfer/sampling.cuh>
|
||||
|
||||
namespace flashinfer {
|
||||
|
||||
namespace sampling {
|
||||
|
||||
using namespace cub;
|
||||
|
||||
template <
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
IdType* predicts,
|
||||
IdType* accept_index,
|
||||
IdType* accept_token_num, // mutable
|
||||
IdType* candidates,
|
||||
IdType* retrive_index,
|
||||
IdType* retrive_next_token,
|
||||
IdType* retrive_next_sibling,
|
||||
DType* uniform_samples,
|
||||
DType* target_probs,
|
||||
DType* draft_probs,
|
||||
uint32_t batch_size,
|
||||
uint32_t num_speculative_tokens,
|
||||
uint32_t num_draft_tokens,
|
||||
uint32_t d) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
|
||||
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
|
||||
|
||||
DType prob_acc = 0.0;
|
||||
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
||||
DType coin = uniform_samples[bx * num_draft_tokens];
|
||||
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||
uint32_t num_accepted_tokens = 0;
|
||||
IdType cur_index = 0;
|
||||
|
||||
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||
while (cur_index != -1) {
|
||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
prob_acc += target_probs[cur_prob_offset + draft_token_id];
|
||||
|
||||
if (coin < prob_acc) {
|
||||
// accept token
|
||||
prob_acc = 0.;
|
||||
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
|
||||
coin = uniform_samples[bx * num_draft_tokens + cur_index];
|
||||
predicts[last_accepted_retrive_idx] = draft_token_id;
|
||||
++num_accepted_tokens;
|
||||
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
|
||||
last_accepted_retrive_idx = draft_index;
|
||||
break;
|
||||
} else {
|
||||
// FIXME: leverage draft probs
|
||||
draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
|
||||
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
|
||||
}
|
||||
}
|
||||
if (cur_index == -1) break;
|
||||
}
|
||||
accept_token_num[bx] = num_accepted_tokens;
|
||||
|
||||
// sample from relu(target_probs - draft_probs)
|
||||
DType sum_relu_q_minus_p(0);
|
||||
vec_t<DType, VEC_SIZE> q_vec, p_vec;
|
||||
DType relu_q_minus_p[VEC_SIZE];
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
q_vec.fill(DType(0));
|
||||
p_vec.fill(DType(0));
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
if (num_accepted_tokens != num_speculative_tokens - 1) {
|
||||
// there is no draft_probs for the bonus token
|
||||
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
|
||||
}
|
||||
sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(relu_q_minus_p);
|
||||
__syncthreads();
|
||||
}
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.value = sum_relu_q_minus_p;
|
||||
}
|
||||
// init the first rejected token to (d - 1)
|
||||
temp_storage.sampled_id = d - 1;
|
||||
__syncthreads();
|
||||
sum_relu_q_minus_p = temp_storage.block_aggregate.value;
|
||||
DType u = coin * sum_relu_q_minus_p;
|
||||
|
||||
DType aggregate_relu_q_minus_p(0);
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
q_vec.fill(DType(0));
|
||||
p_vec.fill(DType(0));
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
if (num_accepted_tokens != num_speculative_tokens - 1) {
|
||||
// there is no draft_probs for the bonus token
|
||||
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
|
||||
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
|
||||
if (aggregate_relu_q_minus_p > u) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// set the first rejected token
|
||||
predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
|
||||
// value at not used indices are undefined
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
IdType* predicts,
|
||||
IdType* output_token_ids,
|
||||
IdType* output_accepted_token_num, // mutable
|
||||
IdType* candidates,
|
||||
IdType* retrive_index,
|
||||
IdType* retrive_next_token,
|
||||
IdType* retrive_next_sibling,
|
||||
DType* uniform_samples,
|
||||
DType* target_probs,
|
||||
DType* draft_probs,
|
||||
uint32_t batch_size,
|
||||
uint32_t num_speculative_tokens,
|
||||
uint32_t num_draft_tokens,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {
|
||||
&predicts,
|
||||
&output_token_ids,
|
||||
&output_accepted_token_num,
|
||||
&candidates,
|
||||
&retrive_index,
|
||||
&retrive_next_token,
|
||||
&retrive_next_sibling,
|
||||
&uniform_samples,
|
||||
&target_probs,
|
||||
&draft_probs,
|
||||
&batch_size,
|
||||
&num_speculative_tokens,
|
||||
&num_draft_tokens,
|
||||
&d};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TreeSpeculativeSamplingTargetOnly<
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
DType,
|
||||
IdType>;
|
||||
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
} // namespace sampling
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // SPECULATIVE_SAMPLING_CUH_
|
||||
179
sgl-kernel/csrc/torch_extension.cc
Normal file
179
sgl-kernel/csrc/torch_extension.cc
Normal file
@@ -0,0 +1,179 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def(
|
||||
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
|
||||
"barrier_in, int[] barrier_out) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def("dispose", &dispose);
|
||||
|
||||
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
||||
m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
|
||||
|
||||
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
m.def(
|
||||
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
|
||||
"Tensor");
|
||||
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
|
||||
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
|
||||
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
|
||||
|
||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
|
||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
m.def(
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
m.def(
|
||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
||||
"bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
|
||||
"retrive_next_sibling, "
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
m.def(
|
||||
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||
"cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
|
||||
m.def(
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
66
sgl-kernel/csrc/torch_extension_rocm.cc
Normal file
66
sgl-kernel/csrc/torch_extension_rocm.cc
Normal file
@@ -0,0 +1,66 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def(
|
||||
"init_custom_ar(Tensor meta, Tensor rank_data, "
|
||||
"str[] handles, int[] offsets, int rank, "
|
||||
"bool full_nvlink) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
|
||||
|
||||
m.def(
|
||||
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
|
||||
"()");
|
||||
m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
|
||||
|
||||
m.def("dispose", &dispose);
|
||||
|
||||
m.def("meta_size", &meta_size);
|
||||
|
||||
m.def(
|
||||
"register_buffer(int fa, Tensor t, str[] handles, "
|
||||
"int[] offsets) -> ()");
|
||||
m.impl("register_buffer", torch::kCUDA, ®ister_buffer);
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
m.def("allocate_meta_buffer", &allocate_meta_buffer);
|
||||
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
|
||||
|
||||
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
|
||||
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
Reference in New Issue
Block a user