New clang format for sgl kernel (#4194)
@@ -1,6 +0,0 @@
-cp ../README.md ../LICENSE .
-rm -rf dist
-python3 -m build
-python3 -m twine upload dist/*
-
-rm -rf README.md LICENSE

@@ -6,3 +6,10 @@ DerivePointerAlignment: false
 PointerAlignment: Left
 NamespaceIndentation: None
 SortIncludes: true
+AllowShortLoopsOnASingleLine: false
+BinPackParameters: false # Prevents packing parameters in declarations
+BinPackArguments: false # Prevents packing arguments in function calls
+AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
+AlignOperands: Align # Aligns operands of wrapped expressions vertically
+PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
+PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
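
For reference, the combined effect of the added options on a long declaration can be sketched with a hypothetical helper (illustrative only; this function is not part of the commit):

// Before: BinPackParameters/BinPackArguments let clang-format pack up to the column limit.
void launch_helper(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                   const torch::Tensor& scales_a, const torch::Tensor& scales_b);

// After: BinPackParameters/BinPackArguments: false plus AlignAfterOpenBracket: AlwaysBreak
// break after the opening parenthesis and place one parameter per line.
void launch_helper(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b);
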
@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
   // support float16, bfloat16 and float32
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::FusedAddRMSNorm(
-        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
-        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
-    TORCH_CHECK(status == cudaSuccess,
-                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+        static_cast<c_type*>(input.data_ptr()),
+        static_cast<c_type*>(residual.data_ptr()),
+        static_cast<c_type*>(weight.data_ptr()),
+        batch_size,
+        hidden_size,
+        eps,
+        torch_current_stream);
+    TORCH_CHECK(
+        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;
   });
 }

@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // prior memory accesses. Note: volatile writes will not be reordered against
 // other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg,
+DINLINE void start_sync(
+    const RankSignals& sg,
 #ifndef USE_ROCM
-                        volatile
+    volatile
 #endif
-                        Signal* self_sg,
-                        int rank) {
+    Signal* self_sg,
+    int rank) {
 #ifdef USE_ROCM
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED,
-                            __MEMORY_SCOPE_SYSTEM);
+    __scoped_atomic_store_n(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
     while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
            flag)
@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg,
+DINLINE void end_sync(
+    const RankSignals& sg,
 #ifndef USE_ROCM
-                      volatile
+    volatile
 #endif
-                      Signal* self_sg,
-                      int rank) {
+    Signal* self_sg,
+    int rank) {
 #ifdef USE_ROCM
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
-                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
+    __scoped_atomic_store_n(
+        &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
+        flag,
+        final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
+        __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                                  final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag)
+    while (__scoped_atomic_load_n(
+               &self_sg->end[blockIdx.x][threadIdx.x],
+               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
+               __MEMORY_SCOPE_DEVICE) < flag)
       ;
   }
   __syncthreads();
@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 }
 
 template <typename T, int ngpus>
-__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
+__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
+    RankData* _dp,
+    RankSignals sg,
 #ifndef USE_ROCM
-                                                                     volatile
+    volatile
 #endif
-                                                                     Signal* self_sg,
-                                                                     T* __restrict__ result, int rank, int size) {
+    Signal* self_sg,
+    T* __restrict__ result,
+    int rank,
+    int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
 }
 
 template <typename T, int ngpus>
-__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
+__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
+    RankData* _dp,
+    RankSignals sg,
 #ifndef USE_ROCM
-                                                                     volatile
+    volatile
 #endif
-                                                                     Signal* self_sg,
-                                                                     T* __restrict__ result, int rank, int size) {
+    Signal* self_sg,
+    T* __restrict__ result,
+    int rank,
+    int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -357,8 +372,14 @@ class CustomAllreduce {
    * note: this class does not own any device memory. Any required buffers
    * are passed in from the constructor
    */
-  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles,
-                  const std::vector<int64_t>& offsets, int rank, bool full_nvlink = true)
+  CustomAllreduce(
+      Signal* meta,
+      void* rank_data,
+      size_t rank_data_sz,
+      const hipIpcMemHandle_t* handles,
+      const std::vector<int64_t>& offsets,
+      int rank,
+      bool full_nvlink = true)
       : rank_(rank),
         world_size_(offsets.size()),
         full_nvlink_(full_nvlink),
@@ -382,8 +403,8 @@ class CustomAllreduce {
     auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
     if (new_handle) {
       char* ipc_ptr;
-      CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle),
-                                    hipIpcMemLazyEnablePeerAccess));
+      CUDACHECK(hipIpcOpenMemHandle(
+          (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
       it->second = ipc_ptr;
     }
     return it->second;
@@ -399,13 +420,14 @@ class CustomAllreduce {
       void* base_ptr;
       // note: must share the base address of each allocation, or we get wrong
      // address
-      if (hipPointerGetAttribute(&base_ptr,
+      if (hipPointerGetAttribute(
+              &base_ptr,
 #ifdef USE_ROCM
-                                 HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
 #else
-                                 CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+              CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
 #endif
-                                 (hipDeviceptr_t)ptr) != hipSuccess)
+              (hipDeviceptr_t)ptr) != hipSuccess)
         throw std::runtime_error("failed to get pointer attr");
       CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
       offsets[i] = ((char*)ptr) - ((char*)base_ptr);
@@ -415,8 +437,8 @@ class CustomAllreduce {
 
   void check_rank_data_capacity(size_t num = 1) {
     if (d_rank_data_base_ + num > d_rank_data_end_)
-      throw std::runtime_error("Rank data buffer is overflowed by " +
-                               std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
+      throw std::runtime_error(
+          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
   }
 
   void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
@@ -443,8 +465,8 @@ class CustomAllreduce {
   // rank 1 may get the same input address for the second allreduce, but rank 2
   // got a different address. IPC handles have internal reference counting
   // mechanism so overhead should be small.
-  void register_graph_buffers(const std::vector<std::string>& handles,
-                              const std::vector<std::vector<int64_t>>& offsets) {
+  void
+  register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
     auto num_buffers = graph_unreg_buffers_.size();
     check_rank_data_capacity(num_buffers);
     std::vector<RankData> rank_data(num_buffers);
@@ -474,11 +496,17 @@ class CustomAllreduce {
    * will cause contention on NVLink bus.
    */
  template <typename T>
-  void allreduce(hipStream_t stream, T* input, T* output, int size,
+  void allreduce(
+      hipStream_t stream,
+      T* input,
+      T* output,
+      int size,
 #ifndef USE_ROCM
-                 int threads = 512, int block_limit = 36){
+      int threads = 512,
+      int block_limit = 36){
 #else
-                 int threads = 512, int block_limit = 16) {
+      int threads = 512,
+      int block_limit = 16) {
 #endif
     auto d = packed_t<T>::P::size;
     if (size % d != 0)
@@ -487,8 +515,8 @@ class CustomAllreduce {
           "of " +
           std::to_string(d));
     if (block_limit > kMaxBlocks)
-      throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " +
-                               std::to_string(block_limit));
+      throw std::runtime_error(
+          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
 
     RankData* ptrs;
     hipStreamCaptureStatus status;
@@ -499,17 +527,17 @@ class CustomAllreduce {
     } else {
       auto it = buffers_.find(input);
       if (it == buffers_.end())
-        throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) +
-                                 " is not registered!");
+        throw std::runtime_error(
+            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
       ptrs = it->second;
     }
 
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = ::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name) \
-  hipLaunchKernelGGL((name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \
-                     size);
+#define KL(ngpus, name) \
+  hipLaunchKernelGGL(   \
+      (name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
 #define REDUCE_CASE(ngpus) \
   case ngpus: {            \
     if (world_size_ == 2) { \

@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
   return c.packed;
 }
 
-__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                             size_t const world_size, int const tidx, int const bidx) {
+__inline__ __device__ void multi_gpu_barrier(
+    uint32_t** signals,
+    uint32_t const flag,
+    size_t const local_rank,
+    size_t const world_size,
+    int const tidx,
+    int const bidx) {
   // After this function, at least one block in each GPU has reached the barrier
   if (tidx < world_size) {
     // we can think of signals having the shape [world_size, world_size]
@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
 }
 
 template <bool start, bool need_fence = false>
-__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+__inline__ __device__ void block_barrier(
+    uint32_t** signals,
+    uint32_t const flag,
+    size_t const local_rank,
+    size_t const world_size,
+    int const tidx,
+    int const bidx,
+    int const grid_size) {
   if constexpr (!start) {
     __syncthreads();
   }
@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
     }
   }
   // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
-  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                      grid_size);
+  block_barrier<true>(
+      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       }
     }
   }
-  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                      grid_size);
+  block_barrier<true>(
+      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     }
   }
 
-  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
-                             bidx, grid_size);
+  block_barrier<false, true>(
+      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
-void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
-                       cudaStream_t stream) {
+void dispatchARKernels(
+    AllReduceStrategyType algo,
+    AllReduceParams& param,
+    int blocks_per_grid,
+    int threads_per_block,
+    cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
       oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
 
-void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
-                        cudaStream_t stream) {
+void trtCustomAllReduce(
+    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
   if (params.elts_total == 0) {
     return;
   }

@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
 
 class AllReduceMeta {
  public:
-  AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
-                const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
-                const std::vector<fptr_t>& barrier_out) {
+  AllReduceMeta(
+      int64_t rank_id,
+      int64_t world_size,
+      torch::Tensor& rank_data,
+      const std::vector<fptr_t>& buffers,
+      const std::vector<fptr_t>& tmp_result_buffers,
+      const std::vector<fptr_t>& barrier_in,
+      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;
@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
   return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
 }
 
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
-                      const std::vector<fptr_t>& barrier_out) {
+fptr_t init_custom_ar(
+    int64_t rank_id,
+    int64_t world_size,
+    torch::Tensor& rank_data,
+    const std::vector<fptr_t>& buffers,
+    const std::vector<fptr_t>& tmp_result_buffers,
+    const std::vector<fptr_t>& barrier_in,
+    const std::vector<fptr_t>& barrier_out) {
   auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
   return (fptr_t)m;
 }
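
As context for the CanApplyCustomAllReduce check above: the expression requires num_elements to be divisible by however many elements fill one 16-byte access. A small sketch of the arithmetic (illustrative only; it assumes get_bits returns the dtype's width in bits, as the code suggests):

// Illustrative restatement of the divisibility rule in CanApplyCustomAllReduce.
bool divisible_for_16_byte_access(long long num_elements, int bits) {
  int bytes_per_elem = (bits + 7) / 8;      // fp16/bf16: 2 bytes, fp32: 4 bytes
  int elems_per_16b = 16 / bytes_per_elem;  // 8 for fp16/bf16, 4 for fp32
  return num_elements % elems_per_16b == 0;
}
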
@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
   auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
   if (new_handle) {
     char* ipc_ptr;
-    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
-                                            cudaIpcMemLazyEnablePeerAccess));
+    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
+        (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
     it->second = ipc_ptr;
   }
   return it->second;
@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
 // rank 1 may get the same input address for the second allreduce, but rank 2
 // got a different address. IPC handles have internal reference counting
 // mechanism so overhead should be small.
-void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
-                            const std::vector<std::vector<int64_t>>& offsets) {
+void register_graph_buffers(
+    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
   AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
   std::vector<std::string> handle_bytes;
   handle_bytes.reserve(handles.size());

@@ -23,15 +23,18 @@ limitations under the License.
 #define THREADS_PER_BLOCK 128
 
 template <typename T>
-__global__ void lightning_attention_decode_kernel(const T* __restrict__ q,            // [b, h, 1, d]
-                                                  const T* __restrict__ k,            // [b, h, 1, d]
-                                                  const T* __restrict__ v,            // [b, h, 1, e]
-                                                  const float* __restrict__ past_kv,  // [b, h, d, e]
-                                                  const float* __restrict__ slope,    // [h, 1, 1]
-                                                  T* __restrict__ output,             // [b, h, 1, e]
-                                                  float* __restrict__ new_kv,         // [b, h, d, e]
-                                                  const int batch_size, const int num_heads, const int qk_dim,
-                                                  const int v_dim) {
+__global__ void lightning_attention_decode_kernel(
+    const T* __restrict__ q,            // [b, h, 1, d]
+    const T* __restrict__ k,            // [b, h, 1, d]
+    const T* __restrict__ v,            // [b, h, 1, e]
+    const float* __restrict__ past_kv,  // [b, h, d, e]
+    const float* __restrict__ slope,    // [h, 1, 1]
+    T* __restrict__ output,             // [b, h, 1, e]
+    float* __restrict__ new_kv,         // [b, h, d, e]
+    const int batch_size,
+    const int num_heads,
+    const int qk_dim,
+    const int v_dim) {
   extern __shared__ char smem[];
   T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
   T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   }
 }
 
-void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
-                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
-                                torch::Tensor new_kv) {
+void lightning_attention_decode(
+    const torch::Tensor& q,
+    const torch::Tensor& k,
+    const torch::Tensor& v,
+    const torch::Tensor& past_kv,
+    const torch::Tensor& slope,
+    torch::Tensor output,
+    torch::Tensor new_kv) {
   TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
   TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
   TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
       at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
         size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
         lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
-            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
-            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
-            qk_dim, v_dim);
+            q.data_ptr<scalar_t>(),
+            k.data_ptr<scalar_t>(),
+            v.data_ptr<scalar_t>(),
+            past_kv.data_ptr<float>(),
+            slope.data_ptr<float>(),
+            output.data_ptr<scalar_t>(),
+            new_kv.data_ptr<float>(),
+            batch_size,
+            num_heads,
+            qk_dim,
+            v_dim);
       }));
 }

@@ -25,9 +25,15 @@ namespace cutlass {
 namespace epilogue {
 namespace threadblock {
 
-template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
-          typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_,
-          bool UseMasking_ = false>
+template <
+    typename ThreadblockShape_,
+    int ThreadCount,
+    typename ScaleTileIterator_,
+    typename OutputTileIterator_,
+    typename ElementAccumulator_,
+    typename ElementCompute_,
+    typename ElementwiseFunctor_,
+    bool UseMasking_ = false>
 class EpilogueVisitorPerRowPerCol {
  public:
   using ThreadblockShape = ThreadblockShape_;
@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
     Arguments(typename ElementwiseFunctor::Params elementwise_)
         : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
 
-    Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
-              int64_t batch_stride_D_)
+    Arguments(
+        typename ElementwiseFunctor::Params elementwise_,
+        int64_t batch_stride_alpha_,
+        int64_t batch_stride_C_,
+        int64_t batch_stride_D_)
         : elementwise(elementwise_),
           batch_stride_alpha(batch_stride_alpha_),
           batch_stride_C(batch_stride_C_),
@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
 
  public:
   CUTLASS_DEVICE
-  EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
-                              cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
-                              typename ScaleTileIterator::Params params_alpha_col,
-                              typename OutputTileIterator::Params params_C,
-                              typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant,
-                              bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row,
-                              AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
-                              typename OutputTileIterator::Element* ptr_D,
-                              cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
-                              int column_offset = 0,
-                              cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
+  EpilogueVisitorPerRowPerCol(
+      Params const& params,
+      SharedStorage& shared_storage,
+      cutlass::MatrixCoord const& problem_size,
+      int thread_idx,
+      int warp_idx,
+      int lane_idx,
+      typename ScaleTileIterator::Params params_alpha_col,
+      typename OutputTileIterator::Params params_C,
+      typename OutputTileIterator::Params params_D,
+      bool with_bias,
+      bool per_token_quant,
+      bool per_channel_quant,
+      AlphaScaleElementType* ptr_alpha_row,
+      AlphaScaleElementType* ptr_alpha_col,
+      typename OutputTileIterator::Element* ptr_C,
+      typename OutputTileIterator::Element* ptr_D,
+      cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
+      int column_offset = 0,
+      cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
       : params_(params),
         shared_storage_(shared_storage),
         extent_(problem_size),
@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
 
   /// Helper to indicate split-K behavior
   CUTLASS_DEVICE
-  void set_k_partition(int split_k_index,     ///< Index of this threadblock within split-K partitioned scheme
-                       int split_k_slices) {  ///< Total number of split-K slices
+  void set_k_partition(
+      int split_k_index,     ///< Index of this threadblock within split-K partitioned scheme
+      int split_k_slices) {  ///< Total number of split-K slices
   }
 
   /// Called to set the batch index
@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
 
  private:
   CUTLASS_DEVICE
-  ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
-                                                       AlphaScaleElementType const& scale_row) {
+  ComputeFragment per_token_channel_scale_accumulator_(
+      ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
   }
 
   CUTLASS_DEVICE
-  ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
-                                               AlphaScaleElementType const& scale_row) {
+  ComputeFragment per_token_scale_accumulator_(
+      ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {

@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
 
 // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
 // specialized dynamic schedule For FP8 kernels with Block Scaling
-template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>, class KernelSchedule = KernelTmaWarpSpecialized,
-          int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
-                                     // while zero-value `ScaleGranularityM` indicates that scaling
-                                     // granularity is `size<0>(TileShape_MNK{})` along M.
-          >
+template <
+    int Stages_,
+    class ClusterShape_ = Shape<_1, _1, _1>,
+    class KernelSchedule = KernelTmaWarpSpecialized,
+    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
+                               // while zero-value `ScaleGranularityM` indicates that scaling
+                               // granularity is `size<0>(TileShape_MNK{})` along M.
+    >
 struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
     : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
-  static_assert(cute::is_same_v<KernelSchedule,
-                                KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
-                "KernelSchedule must be one of the warp specialized policies");
+  static_assert(
+      cute::
+          is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
+      "KernelSchedule must be one of the warp specialized policies");
 };
 
 //////////////////////////////////////////////////////////////////////////////

@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
     get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
     dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
 
-    CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
-                       << "  result = {" << result << "}");
+    CUTLASS_TRACE_HOST(
+        "  grid_tiled_shape: " << grid_tiled_shape << "\n"
+                               << "  result = {" << result << "}");
 
     return result;
   }
@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
     CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");
 
     if (smem_size <= (48 << 10)) {
-      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
-                                                                         GemmKernel::kThreadCount, smem_size);
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
 
       if (result == cudaSuccess) {
         CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
       }
     } else {
       // Query assuming zero shared memory then compute occupancy limit based on SMEM
-      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
-                                                                         GemmKernel::kThreadCount, 0);
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
 
       if (result != cudaSuccess) {
-        CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
-                           << cudaGetErrorString(result));
+        CUTLASS_TRACE_HOST(
+            "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
 
         return -1;
       }
@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
 
   /// Initializes GEMM state from arguments.
   Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
-    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
-                       << workspace << ", stream: " << (stream ? "non-null" : "null"));
+    CUTLASS_TRACE_HOST(
+        "GemmUniversalBaseCompat::initialize() - workspace " << workspace
+                                                             << ", stream: " << (stream ? "non-null" : "null"));
 
     size_t workspace_bytes = get_workspace_size(args);
 

@@ -32,10 +32,11 @@ namespace kernel {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
-          typename Epilogue_,           ///! Epilogue
-          typename ThreadblockSwizzle_  ///! Threadblock swizzling function
-          >
+template <
+    typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
+    typename Epilogue_,           ///! Epilogue
+    typename ThreadblockSwizzle_  ///! Threadblock swizzling function
+    >
 struct GemmWithEpilogueVisitor {
  public:
   using Mma = Mma_;
@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
     Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
 
     /// constructs an arguments structure
-    Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_,
-              TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_,
-              typename EpilogueVisitor::Arguments epilogue_visitor_)
+    Arguments(
+        GemmCoord problem_size_,
+        TensorRefA ref_A_,
+        TensorRefB ref_B_,
+        TensorRefAlphaCol ref_alpha_col_,
+        TensorRefAlphaRow ref_alpha_row_,
+        TensorRefC ref_C_,
+        TensorRefC ref_D_,
+        typename EpilogueVisitor::Arguments epilogue_visitor_)
         : mode(GemmUniversalMode::kGemm),
           problem_size(problem_size_),
           batch_count(1),
@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
       isAMisaligned = problem_size.k() % kAlignmentA;
     } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
       isAMisaligned = problem_size.m() % kAlignmentA;
-    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
-               platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
+    } else if (
+        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
+        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
       isAMisaligned = problem_size.k() % kAlignmentA;
     }
 
@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
       isBMisaligned = problem_size.n() % kAlignmentB;
     } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
       isBMisaligned = problem_size.k() % kAlignmentB;
-    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
-               platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
+    } else if (
+        platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
+        platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
       isBMisaligned = problem_size.k() % kAlignmentB;
     }
 
@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
       isCMisaligned = problem_size.n() % kAlignmentC;
     } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
       isCMisaligned = problem_size.m() % kAlignmentC;
-    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
-               platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
+    } else if (
+        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
+        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
       isCMisaligned = problem_size.n() % kAlignmentC;
     }
 
@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
     int thread_idx = threadIdx.x;
 
     // Construct iterators to A and B operands
-    typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx,
-                                       tb_offset_A);
+    typename Mma::IteratorA iterator_A(
+        params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
 
-    typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx,
-                                       tb_offset_B);
+    typename Mma::IteratorB iterator_B(
+        params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
 
     // Broadcast the warp_id computed by lane 0 to ensure dependent code
     // is compiled as warp-uniform.
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
     threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
 
     // assume identity swizzle
-    MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
-                                   threadblock_tile_offset.n() * Mma::Shape::kN);
+    MatrixCoord threadblock_offset(
+        threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
 
     int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
 
@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
       with_bias = false;
     }
 
-    EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(),
-                                     thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
-                                     params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col,
-                                     params.ptr_C, params.ptr_D, threadblock_offset,
-                                     blockIdx.y * params.problem_size.m());
+    EpilogueVisitor epilogue_visitor(
+        params.epilogue_visitor,
+        shared_storage.epilogue.visitor,
+        params.problem_size.mn(),
+        thread_idx,
+        warp_idx,
+        lane_idx,
+        params.params_alpha_col,
+        params.params_C,
+        params.params_D,
+        with_bias,
+        true,
+        true,
+        params.ptr_alpha_row,
+        params.ptr_alpha_col,
+        params.ptr_C,
+        params.ptr_D,
+        threadblock_offset,
+        blockIdx.y * params.problem_size.m());
 
     if (params.mode == GemmUniversalMode::kGemm) {
       // Indicate which position in a serial reduction the output operator is currently updating

@@ -21,10 +21,13 @@
 
 #include "utils.h"
 
-static void check_group_count(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
-                              const std::vector<torch::Tensor>& outputs) {
-  TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
-              "The group count of inputs, weights and outputs should be the same.");
+static void check_group_count(
+    const std::vector<torch::Tensor>& inputs,
+    const std::vector<torch::Tensor>& weights,
+    const std::vector<torch::Tensor>& outputs) {
+  TORCH_CHECK(
+      ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
+      "The group count of inputs, weights and outputs should be the same.");
 }
 
 static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
 static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
   auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
   torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
-  TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice,
-              stream) == CUBLAS_STATUS_SUCCESS);
+  TORCH_CHECK(
+      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
+      cudaSuccess);
   return gpu_ptrs;
 }
 
 // We want compute input @ weight^T in row major
 // This is equivalent to computing weight @ input^T in col major
 // Cublas only accepts matrix in column major, so this arrangement is needed
-void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
-                         const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
-                         const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
-                         const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) {
-  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
-              "cublas grouped_gemm can"
-              "only be applied to float16 and bfloat16 dtype");
+void cublas_grouped_gemm(
+    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
+    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
+    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
+    const torch::Dtype& out_dtype,
+    int64_t cublas_handle,
+    int64_t cuda_stream) {
+  TORCH_CHECK(
+      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
+      "cublas grouped_gemm can "
+      "only be applied to float16 and bfloat16 dtype");
 
   int group_count = inputs.size();
   check_group_count(inputs, weights, outputs);
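
The row-major/column-major comment in the hunk above rests on the transpose identity: a row-major (m, k) buffer is bit-identical to a column-major (k, m) buffer holding the transposed matrix. So to compute output = input * weight^T in row major, the code hands the same buffers to cuBLAS (which assumes column major) arranged so that it computes weight * input^T; the column-major (n, m) result it writes is exactly the row-major (m, n) output. This is the standard layout trick, stated here for intuition rather than taken from the commit itself.
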
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs,  // b: (m, k
   torch::Tensor d_c = create_ptr_pointer(c_array, stream);
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12050
-  auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(),
-                                           n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(),
-                                           cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type,
-                                           ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type,
-                                           ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F);
+  auto status = cublasGemmGroupedBatchedEx(
+      handle,
+      transa_array.data(),
+      transb_array.data(),
+      m_array.data(),
+      n_array.data(),
+      k_array.data(),
+      alpha_array.data(),
+      (void**)d_a.data_ptr(),
+      cuda_data_type,
+      lda_array.data(),
+      (void**)d_b.data_ptr(),
+      cuda_data_type,
+      ldb_array.data(),
+      beta_array.data(),
+      (void**)d_c.data_ptr(),
+      cuda_data_type,
+      ldc_array.data(),
+      group_count,
+      group_size.data(),
+      CUBLAS_COMPUTE_32F);
   TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
   TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
   return;
 #endif
 
-  TORCH_CHECK_NOT_IMPLEMENTED(false,
-                              "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
 }

@@ -35,8 +35,12 @@
 using namespace cute;
 
 template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
-void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
-                                         const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
+void launch_sm90_fp8_blockwise_scaled_mm(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
-      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC,
-      LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
+      ArchTag,
+      OperatorClass,
+      TileShape,
+      ClusterShape,
+      EpilogueTileType,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueSchedule,
+      StoreEpilogueCompute>::CollectiveOp;
 
   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
-      TileShape, ClusterShape,
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      LayoutA,
+      AlignmentA,
+      ElementB,
+      LayoutB,
+      AlignmentB,
+      ElementAccumulator,
+      TileShape,
+      ClusterShape,
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
           sizeof(typename CollectiveEpilogue::SharedStorage))>,
       KernelSchedule>::CollectiveOp;
 
-  using GemmKernel =
-      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
-                                           CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>;
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>,  // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      cutlass::gemm::PersistentScheduler>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
   Gemm gemm_op;
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
 }
 
 template <typename OutType>
-void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
-                                       const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
+void sm90_fp8_blockwise_dispatch_shape(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_1, _1, _1>;
   launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
 }
 
-torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
-                                      const torch::Tensor& scales_a, const torch::Tensor& scales_b,
-                                      const torch::Dtype& out_dtype) {
+torch::Tensor fp8_blockwise_scaled_mm(
+    const torch::Tensor& mat_a,
+    const torch::Tensor& mat_b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const torch::Dtype& out_dtype) {
   TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
   TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
   TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
 
-  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
-              "mat_a must be multiple of 16 bytes for memory alignment");
-  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
-              "mat_b must be multiple of 16 bytes for memory alignment");
+  TORCH_CHECK(
+      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
+  TORCH_CHECK(
+      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
   TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
   TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
   TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
 #endif
 #endif
 
-  TORCH_CHECK_NOT_IMPLEMENTED(false,
-                              "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "fp8_blockwise_scaled_mm is not implemented for current compute capability: ", sm_version);
 }

@@ -53,10 +53,17 @@ limitations under the License.
 using namespace cute;
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12040
-template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
-          typename WarpShape, int Stages, bool WithBias, typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
-          template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
-          typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
+template <
+    typename ElementType,
+    typename OutElementType,
+    typename AccumElementType,
+    typename CtaShape,
+    typename WarpShape,
+    int Stages,
+    bool WithBias,
+    typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
+    template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
+    typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
 struct DeviceGemmFp8RowwiseSm89 {
   static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
 
@@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 {
   // Number of epilogue stages in EVT
   static constexpr int EVTEpilogueStages = 1;
 
-  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC,
-                                                                                     AlignmentC, EVTEpilogueStages>;
+  using OutputTileThreadMap = cutlass::epilogue::threadblock::
+      OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;
 
   // Definition of EVT
   using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
 
   using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
-  using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
-                                                                        Stride<_0, _1, _0>>;
+      cutlass::multiplies,
+      ElementComputeEpilogue,
+      ElementComputeEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+  using bScaleSrc = cutlass::epilogue::threadblock::
+      VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
   using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
 
-  using ComputeAScale =
-      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue,
-                                                     cutlass::FloatRoundStyle::round_to_nearest>;
-  using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
-                                                                        Stride<_1, _0, _0>>;
+  using ComputeAScale = cutlass::epilogue::threadblock::
+      VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
+  using aScaleSrc = cutlass::epilogue::threadblock::
+      VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
   using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
 
   // With bias
   using biasSrc =
       cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
-  using ComputeAScaleWithBias =
-      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue,
-                                                     cutlass::FloatRoundStyle::round_to_nearest>;
+  using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add,
+      ElementC,
+      ElementComputeEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest>;
   using EpilogueAScaleWithBias =
       cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
 
   using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
-      OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride<int64_t, _1, _0>>;
-  using EpilogueStore =
-      typename cutlass::platform::conditional<WithBias,
-                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
-                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
+      OutputTileThreadMap,
+      ElementC,
+      cutlass::FloatRoundStyle::round_to_nearest,
+      Stride<int64_t, _1, _0>>;
+  using EpilogueStore = typename cutlass::platform::conditional<
+      WithBias,
+      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
+      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
 
   using EpilogueOp = EpilogueStore;
 
   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
-      cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
-      ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
-      ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;
+      ElementA,
+      LayoutA,
+      cutlass::ComplexTransform::kNone,
+      AlignmentA,
+      ElementB,
+      LayoutB,
+      cutlass::ComplexTransform::kNone,
+      AlignmentB,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementAccumulator,
+      ElementComputeEpilogue,
+      OperatorClass,
+      ArchTag,
+      CtaShape,
+      WarpShape,
+      InstructionShape,
+      EpilogueOp,
+      ThreadblockSwizzle,
+      Stages,
+      FP8MathOperator,
+      EVTEpilogueStages>::GemmKernel;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 };
 
 template <typename Gemm, bool WithBias>
-typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
-                                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
-                                               const c10::optional<torch::Tensor>& bias) {
+typename Gemm::Arguments prepare_sm89_fp8_args(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const c10::optional<torch::Tensor>& bias) {
   using ElementT = typename Gemm::ElementA;
   using ElementOutput = typename Gemm::ElementD;
   using ElementComputeEpilogue = float;
@@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::
   ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
   ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
 
-  typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
-                                {m, n, k},                                // Problem size
-                                1,                                        // Split-k factor
-                                {},                                       // Epilogue args
-                                ptr_a,                                    // a pointer
-                                ptr_b,                                    // b pointer
-                                nullptr,                                  // c pointer (unused)
-                                nullptr,                                  // d pointer (unused)
-                                m * k,                                    // batch stride a (unused)
-                                n * k,                                    // batch stride b (unused)
-                                m * n,                                    // batch stride c (unused)
-                                m * n,                                    // batch stride d (unused)
-                                lda,                                      // stride a
-                                ldb,                                      // stride b
-                                ldc,                                      // stride c (unused)
-                                ldc);                                     // stride d (unused)
+  typename Gemm::Arguments args(
+      cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
+      {m, n, k},                                // Problem size
+      1,                                        // Split-k factor
+      {},                                       // Epilogue args
+      ptr_a,                                    // a pointer
+      ptr_b,                                    // b pointer
+      nullptr,                                  // c pointer (unused)
+      nullptr,                                  // d pointer (unused)
+      m * k,                                    // batch stride a (unused)
+      n * k,                                    // batch stride b (unused)
+      m * n,                                    // batch stride c (unused)
+      m * n,                                    // batch stride d (unused)
+      lda,                                      // stride a
+      ldb,                                      // stride b
+      ldc,                                      // stride c (unused)
+      ldc);                                     // stride d (unused)
   if constexpr (WithBias) {
-    args.epilogue = {{
-                         {
-                             {},  // Accumulator
-                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
-                             {}  // Multiplies
-                         },
-                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
-                         {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
-                         {}  // Multiplies
-                     },
-                     {ptr_d, {n, _1{}, _0{}}}};
+    args.epilogue = {
+        {
+            {
+                {},  // Accumulator
+                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
+                {}  // Multiplies
+            },
+            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
+            {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
+            {}  // Multiplies
+        },
+        {ptr_d, {n, _1{}, _0{}}}};
   } else {
-    args.epilogue = {{
-                         {
-                             {},  // Accumulator
-                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
-                             {}  // Multiplies
-                         },
-                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
-                         {}  // Multiplies
-                     },
-                     {ptr_d, {n, _1{}, _0{}}}};
+    args.epilogue = {
+        {
+            {
+                {},  // Accumulator
+                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
+                {}  // Multiplies
+            },
+            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
+            {}  // Multiplies
+        },
+        {ptr_d, {n, _1{}, _0{}}}};
   }
 
   return args;
 }
 
 template <typename Gemm, bool WithBias>
-void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
-                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
-                               const c10::optional<torch::Tensor>& bias) {
+void launch_sm89_fp8_scaled_mm(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b,
+    const c10::optional<torch::Tensor>& bias) {
   auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
   Gemm gemm_op;
 
@@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
|
||||
}
|
||||
|
||||
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
|
||||
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
void sm89_fp8_dispatch_bias(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using ElementInput = cutlass::float_e4m3_t;
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
if (bias) {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
|
||||
Stages, true>::Gemm;
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CtaShape,
|
||||
WarpShape,
|
||||
Stages,
|
||||
true>::Gemm;
|
||||
return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
|
||||
Stages, false>::Gemm;
|
||||
using Gemm = typename DeviceGemmFp8RowwiseSm89<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CtaShape,
|
||||
WarpShape,
|
||||
Stages,
|
||||
false>::Gemm;
|
||||
return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
void sm89_fp8_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const n = out.size(1);
|
||||
|
||||
if (m == 1) {
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 16) {
|
||||
// M in (1, 16]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 64) {
|
||||
// M in (16, 64]
|
||||
if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 128) {
|
||||
// M in (64, 128]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 256) {
|
||||
// M in (128, 256]
|
||||
if (n <= 8192) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
5>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (n <= 16384) {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>,
|
||||
7>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>,
|
||||
cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||
return sm89_fp8_dispatch_bias<
|
||||
OutType,
|
||||
cutlass::gemm::GemmShape<128, 64, 128>,
|
||||
cutlass::gemm::GemmShape<64, 32, 128>,
|
||||
4>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else if (m <= 512) {
|
||||
// M in (256, 512]
if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
return sm89_fp8_dispatch_bias<
OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
2>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
return sm89_fp8_dispatch_bias<
OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
4>(out, a, b, scales_a, scales_b, bias);
}
} else {
// M in (512, inf)
if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
return sm89_fp8_dispatch_bias<
OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
3>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
return sm89_fp8_dispatch_bias<
OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
2>(out, a, b, scales_a, scales_b, bias);
}
}
}
#endif
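
The paired old/new hunks above all illustrate the same rule set: parameter and argument bin-packing is disabled and a break is forced after the opening bracket, so any signature, call, or template argument list that overflows the column limit is laid out one item per line. A minimal sketch of the before/after on a hypothetical declaration (scaled_mm_example is illustrative only, not a function in this diff):

// Before: parameters packed up to the column limit, continuation aligned to the '('.
void scaled_mm_example(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                       const torch::Tensor& scales_a, const torch::Tensor& scales_b);

// After: break immediately after '(' and give each parameter its own line.
void scaled_mm_example(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b);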

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
typename TileSchedulerType = void, bool WithBias = false>
template <
typename ElementType,
typename OutElementType,
typename AccumElementType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

@@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
// setting in the Collective Builder
// Implement rowwise scaling epilogue.
using XScale =
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using WScale =
cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementOutput,
ElementOutput,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
ElementComputeEpilogue, // First stage output type.
ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementComputeEpilogue, // First stage output type.
ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementOutput,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

// With bias
using ComputeWithBias =
cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add,
ElementOutput,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementC,
LayoutC,
AlignmentC,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp;

using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
@@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 {
using FastAccum = FastPongSchedule; // Default apply Pingpong

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
TileShape, ClusterShape,
ArchTag,
OperatorClass,
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
typename Gemm::Arguments prepare_sm90_fp8_args(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
@@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{ptr_a, stride_a, ptr_b, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
ptr_d,
stride_d}};
typename Gemm::Arguments args = {
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{ptr_a, stride_a, ptr_b, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
ptr_d,
stride_d}};
if constexpr (WithBias) {
args.epilogue.thread = {
{ptr_scales_a},
@@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
}

template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void launch_sm90_fp8_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op;

@@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
TORCH_CHECK(status == cutlass::Status::kSuccess)
}

template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
typename TileSchedulerType>
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias, bool fast_accum = true,
bool use_persistent = false) {
template <
typename OutType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename TileSchedulerType>
void sm90_fp8_dispatch_bias(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias,
bool fast_accum = true,
bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

if (bias) {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
using Gemm = typename DeviceGemmFp8RowwiseSm90<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
using Gemm = typename DeviceGemmFp8RowwiseSm90<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
}

template <typename OutType>
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void sm90_fp8_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void;
if (m <= 1) {
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_fp8_dispatch_bias<
OutType,
Shape<_64, _64, _128>,
Shape<_1, _8, _1>,
FastBasicScheduler,
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
if (m <= 64) {
// m in (1, 64]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_fp8_dispatch_bias<
OutType,
Shape<_64, _64, _128>,
Shape<_1, _4, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 256) {
// m in (64, 256]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_fp8_dispatch_bias<
OutType,
Shape<_64, _64, _128>,
Shape<_1, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 1024) {
// m in (256, 1024]
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_fp8_dispatch_bias<
OutType,
Shape<_128, _128, _128>,
Shape<_1, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (1024, inf)
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_fp8_dispatch_bias<
OutType,
Shape<_128, _128, _128>,
Shape<_2, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
}
#endif
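
sm90_fp8_dispatch_shape above is a pure m-to-configuration lookup: m == 1 takes the basic fast-accum schedule, everything else takes the pingpong schedule with a persistent tile scheduler and progressively larger tiles. A simplified host-side sketch of that bucketing, with a hypothetical TileConfig struct standing in for the CUTLASS shape and cluster template parameters:

#include <cstdint>

// Hypothetical stand-in for the Shape<_M, _N, _K> tile and cluster parameters.
struct TileConfig {
  int tile_m, tile_n, tile_k;
  int cluster_m, cluster_n, cluster_k;
  bool persistent;  // true -> PersistentTileScheduler, false -> basic scheduler
};

// Mirrors the m thresholds used by sm90_fp8_dispatch_shape (1, 64, 256, 1024).
inline TileConfig pick_sm90_fp8_tile(uint32_t m) {
  if (m <= 1) return {64, 64, 128, 1, 8, 1, false};
  if (m <= 64) return {64, 64, 128, 1, 4, 1, true};
  if (m <= 256) return {64, 64, 128, 1, 1, 1, true};
  if (m <= 1024) return {128, 128, 128, 1, 1, 1, true};
  return {128, 128, 128, 2, 1, 1, true};
}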

torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
torch::Tensor fp8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
"mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
"mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(
(mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(
(mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

@@ -35,11 +35,20 @@ limitations under the License.

using namespace cute;

template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape,
typename InstructionShape, int NumStages>
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
template <
typename ElementOutput,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
int NumStages>
void cutlass_int8_scaled_mm(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementAccumulator = int32_t;
using ElementCompute = float;
using ElementInputA = int8_t;
@@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA,
ElementInputB, ElementOutput, ElementCompute>;
using DefaultGemmConf = cutlass::gemm::device::
DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB,
cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel;
ElementInputA,
cutlass::layout::RowMajor,
DefaultGemmConf::kAlignmentA,
ElementInputB,
cutlass::layout::ColumnMajor,
DefaultGemmConf::kAlignmentB,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
NumStages,
true,
typename DefaultGemmConf::Operator>::GemmKernel;

using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits<ElementOutput>::value>,
GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
cutlass::sizeof_bits<ElementOutput>::value>,
ElementCompute>;

using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator,
typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>;
ThreadblockShape,
GemmKernel_::kThreadCount,
AlphaColTileIterator,
typename GemmKernel_::Epilogue::OutputTileIterator,
ElementAccumulator,
ElementCompute,
EpilogueOutputOp>;

using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
using Epilogue = typename cutlass::epilogue::threadblock::
EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

using GemmKernel =
cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
@@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
typename EpilogueOutputOp::Params linearScalingParams;
typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0},
{a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
typename Gemm::Arguments args{
{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

auto workspace = torch::empty(gemm_op.get_workspace_size(args),
torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
auto workspace = torch::empty(
gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
TORCH_CHECK(
can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ",
cutlassGetStatusString(can_implement));

auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void sm75_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
int m = mat_a.size(0);
if (m <= 32) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 64) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 256) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void sm80_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
int m = mat_a.size(0);
int n = mat_b.size(1);
if (m <= 16) {
if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (m <= 32) {
if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (m <= 64) {
if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (m <= 128 && n < 8192) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
}

template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType,
bool WithBias>
void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
template <
typename ElementOutput,
typename TileShape,
typename ClusterShape,
typename MainloopScheduleType,
bool WithBias>
void cutlass_int8_scaled_mm_sm90(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ArchTag = cutlass::arch::Sm90;

using ElementAccumulator = int32_t;
@@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
using TileSchedulerType = cutlass::gemm::PersistentScheduler;

using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute,
Stride<Int<1>, Int<0>, Int<0>>>;
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;

using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute,
Stride<Int<0>, Int<1>, Int<0>>>;
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;

using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
Stride<Int<0>, Int<1>, Int<0>>>;
using Bias = cutlass::epilogue::fusion::
Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

// Scale
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest>;
using Compute0 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest>;
using Compute1 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

// With bias
using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest>;
using ComputeWithBias = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput,
cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp;
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementOutput,
cutlass::layout::RowMajor,
AlignmentC,
ElementOutput,
cutlass::layout::RowMajor,
AlignmentOutput,
EpilogueScheduleType,
EpilogueEVT>::CollectiveOp;

using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB,
cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages,
ArchTag,
OperatorClass,
ElementInputA,
cutlass::layout::RowMajor,
AlignmentA,
ElementInputB,
cutlass::layout::ColumnMajor,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
Stages,
MainloopScheduleType>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

@@ -283,14 +404,15 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{a_ptr, stride_a, b_ptr, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
o_ptr,
stride_d}};
typename Gemm::Arguments args = {
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{a_ptr, stride_a, b_ptr, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
o_ptr,
stride_d}};

if constexpr (WithBias) {
ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
@@ -308,23 +430,29 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
};
}

auto workspace = torch::empty(gemm_op.get_workspace_size(args),
torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
auto workspace = torch::empty(
gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
TORCH_CHECK(
can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ",
cutlassGetStatusString(can_implement));

auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
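
Both int8 epilogues above encode the same element-wise tree: Compute0 multiplies the int32 accumulator by the row-broadcast weight scales, then Compute1 applies the column-broadcast activation scales, and the multiply_add variant folds the bias into that second step. A plain-C++ reference of that math, with float output for clarity; the helper name is illustrative:

#include <cstdint>
#include <vector>

// Reference: D[i][j] = scales_a[i] * (scales_b[j] * acc[i][j]) (+ bias[j] when present).
inline void int8_scaled_mm_epilogue_ref(
    const std::vector<int32_t>& acc,     // m x n accumulators from the int8 GEMM
    const std::vector<float>& scales_a,  // length m, broadcast across columns (XScale)
    const std::vector<float>& scales_b,  // length n, broadcast across rows (WScale)
    const float* bias,                   // length n, or nullptr when WithBias is false
    int m,
    int n,
    std::vector<float>& out) {
  out.resize(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      const size_t idx = static_cast<size_t>(i) * n + j;
      float v = scales_b[j] * static_cast<float>(acc[idx]);  // Compute0: WScale * Accum
      v = scales_a[i] * v;                                   // Compute1: XScale * (...)
      if (bias != nullptr) v += bias[j];                     // multiply_add path adds Bias
      out[idx] = v;
    }
  }
}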

template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void sm90_dispatch_bias(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
if (bias) {
cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
out, mat_a, mat_b, scales_a, scales_b, bias);
@@ -335,45 +463,73 @@ void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const to
}

template <typename ElementOutput>
void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
void sm90_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
int m = mat_a.size(0);
int n = mat_b.size(1);
if (m <= 32) {
if (n < 8192) {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _128, _128>,
Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (m <= 64) {
if (n < 8192) {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _4, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_1, _4, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _256>, Shape<_1, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _256>,
Shape<_1, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (m <= 128) {
if (n <= 4096) {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _128, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else {
return sm90_dispatch_bias<ElementOutput, Shape<_128, _128, _128>, Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b,
bias);
return sm90_dispatch_bias<
ElementOutput,
Shape<_128, _128, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
}

torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");

@@ -8,8 +8,8 @@
#include "utils.h"

template <typename T>
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
const int64_t num_elements) {
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
float max_value = 0.0f;
unsigned int tid = threadIdx.x;
unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}

template <typename T>
__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output,
const float* __restrict__ scale, const int64_t num_elements) {
__global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output,
const float* __restrict__ scale,
const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x;
const float scale_val = 1.0f / (*scale);
@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
}

per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), num_elements);
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
num_elements);
return true;
});
}
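
sgl_per_tensor_quant_fp8 chains the two kernels above: per_tensor_absmax_kernel reduces the tensor to one scale, then per_tensor_quant_fp8_kernel multiplies every element by scale_val = 1/scale. A CPU sketch of the same math, assuming the usual e4m3 convention scale = absmax / 448 (448 being the float8_e4m3fn finite maximum); the helper is illustrative, not part of the kernel sources:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of per-tensor FP8(e4m3) quantization; returns the scale written to output_s.
inline float per_tensor_quant_fp8_ref(const std::vector<float>& input, std::vector<float>& output_q) {
  constexpr float kFp8Max = 448.0f;  // assumption: float8_e4m3fn finite max
  float max_value = 0.0f;
  for (float x : input) max_value = std::max(max_value, std::fabs(x));  // absmax pass
  const float scale = std::max(max_value, 1e-10f) / kFp8Max;  // tiny floor guards all-zero input
  const float scale_val = 1.0f / scale;                       // matches scale_val in the kernel
  output_q.resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Clamp to the representable range; the real kernel then casts to FP8_TYPE.
    output_q[i] = std::clamp(input[i] * scale_val, -kFp8Max, kFp8Max);
  }
  return scale;
}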

@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}

template <typename T>
__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q,
float* __restrict__ output_s, const int group_size,
const int num_groups, const float eps, const float fp8_min,
const float fp8_max) {
__global__ void per_token_group_quant_fp8_kernel(
const T* __restrict__ input,
void* __restrict__ output_q,
float* __restrict__ output_s,
const int group_size,
const int num_groups,
const float eps,
const float fp8_min,
const float fp8_max) {
const int groups_per_block = 16;
const int local_group_id = threadIdx.x / 16;
const int lane_id = threadIdx.x % 16;
@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
}
}

void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s,
int64_t group_size, double eps, double fp8_min, double fp8_max) {
void sgl_per_token_group_quant_fp8(
torch::Tensor input,
torch::Tensor output_q,
torch::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
CHECK_INPUT(output_s);
@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), output_q.data_ptr(), static_cast<float*>(output_s.data_ptr()),
group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max);
static_cast<scalar_t*>(input.data_ptr()),
output_q.data_ptr(),
static_cast<float*>(output_s.data_ptr()),
group_size,
num_groups,
(float)eps,
(float)fp8_min,
(float)fp8_max);
return true;
});
}
|
||||
|
||||
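The per-group variant applies the same idea at group_size granularity: each contiguous group of group_size values gets its own scale, floored by eps to avoid division by zero, and the quantized values are clamped to [fp8_min, fp8_max]. A hypothetical serial reference assuming that contract (the kernel itself reduces each group with 16 lanes):

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical reference: one scale per contiguous group of `group_size` inputs.
void per_token_group_quant_ref(
    const std::vector<float>& input,
    std::vector<float>& output_q,  // stands in for the FP8 payload
    std::vector<float>& output_s,
    int group_size,
    float eps,
    float fp8_min,
    float fp8_max) {
  const int num_groups = static_cast<int>(input.size()) / group_size;
  for (int g = 0; g < num_groups; ++g) {
    float absmax = 0.0f;
    for (int i = 0; i < group_size; ++i)
      absmax = std::max(absmax, std::fabs(input[g * group_size + i]));
    const float scale = std::max(absmax, eps) / fp8_max;  // assumed eps flooring
    output_s[g] = scale;
    for (int i = 0; i < group_size; ++i) {
      const float v = input[g * group_size + i] / scale;
      output_q[g * group_size + i] = std::min(std::max(v, fp8_min), fp8_max);
    }
  }
}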
@@ -7,9 +7,12 @@
#include "utils.h"

template <typename T>
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s, const int64_t hidden_dim,
const int64_t num_tokens) {
__global__ void per_token_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s,
const int64_t hidden_dim,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;

if (token_idx >= num_tokens) return;
@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), hidden_dim, num_tokens);
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
return true;
});
}
@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32

template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer, size_t numel) {
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;

@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel,
int32_t* __restrict__ cumsum) {
__shared__ int32_t shared_counts[WARP_SIZE][8];

const int warp_id = threadIdx.x / WARP_SIZE;
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
}
}

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
align_kernel<<<1, 1024, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel(),
cumsum_buffer.data_ptr<int32_t>());

const int block_threads = 256;
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
const int actual_blocks = std::min(num_blocks, max_blocks);

auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(),
topk_ids.numel());
});
}
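What the pair of kernels above computes can be stated as a small CPU reference: count tokens per expert, pad each count up to a multiple of block_size, prefix-sum the padded counts, and stamp one expert id per block. The sketch below assumes that contract plus a conventional choice of numel as the padding sentinel; the CUDA version does the counting and scattering in parallel.

#include <cstdint>
#include <vector>

// Hypothetical CPU reference for moe_align_block_size's output contract.
void moe_align_block_size_ref(
    const std::vector<int32_t>& topk_ids,
    int32_t num_experts,
    int32_t block_size,
    std::vector<int32_t>& sorted_token_ids,
    std::vector<int32_t>& expert_ids,
    int32_t& total_tokens_post_pad) {
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  std::vector<int32_t> count(num_experts, 0);
  for (int32_t id : topk_ids) count[id]++;
  for (int32_t e = 0; e < num_experts; ++e) {
    const int32_t padded = (count[e] + block_size - 1) / block_size * block_size;
    cumsum[e + 1] = cumsum[e] + padded;
    for (int32_t b = cumsum[e]; b < cumsum[e + 1]; b += block_size)
      expert_ids[b / block_size] = e;  // one expert id per block
  }
  total_tokens_post_pad = cumsum[num_experts];
  const int32_t pad = static_cast<int32_t>(topk_ids.size());  // assumed sentinel
  sorted_token_ids.assign(total_tokens_post_pad, pad);
  std::vector<int32_t> offset = cumsum;  // next free slot per expert
  for (int32_t i = 0; i < static_cast<int32_t>(topk_ids.size()); ++i)
    sorted_token_ids[offset[topk_ids[i]]++] = i;
}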
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len,
bool* tree_mask, int64_t* positions, int64_t* retrive_index,
int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth,
int draft_token_num) {
__global__ void build_tree_efficient(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;

@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
}
}

void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
int64_t depth, int64_t draft_token_num) {
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()), static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk), int32_t(depth), int32_t(draft_token_num));
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}
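Per the layout comment above, tree_mask is a flat buffer in which batch b owns a slab of draft_token_num * (verified_seq_len[b] + draft_token_num) booleans. A hypothetical index helper that follows directly from that comment:

#include <cstdint>

// Hypothetical helper: start offset of batch `b`'s slab inside the flat tree_mask,
// derived from the per-batch slab size draft_token_num * (seq_len[b] + draft_token_num).
int64_t tree_mask_offset(const int32_t* verified_seq_len, int64_t b, int64_t draft_token_num) {
  int64_t cum_seq_len = 0;
  for (int64_t i = 0; i < b; ++i) cum_seq_len += verified_seq_len[i];
  return draft_token_num * (cum_seq_len + b * draft_token_num);
}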
// parent_list [bs, topk * (depth - 1) + 1)]
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask,
int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) {
__global__ void build_tree(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;

@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
}
}

void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
int64_t depth, int64_t draft_token_num) {
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

build_tree<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()), int32_t(topk),
int32_t(depth), int32_t(draft_token_num));
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}
@@ -29,12 +29,19 @@ using namespace flashinfer;
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
bool deterministic, int64_t cuda_stream = 0) {
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
@@ -108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
static_cast<int*>(predicts.data_ptr()), static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()), static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()), static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(target_probs.data_ptr()), static_cast<float*>(draft_probs.data_ptr()), batch_size,
num_spec_step, num_draft_tokens, vocab_size, deterministic, stream);
static_cast<int*>(predicts.data_ptr()),
static_cast<int*>(accept_index.data_ptr()),
static_cast<int*>(accept_token_num.data_ptr()),
static_cast<int*>(candidates.data_ptr()),
static_cast<int*>(retrive_index.data_ptr()),
static_cast<int*>(retrive_next_token.data_ptr()),
static_cast<int*>(retrive_next_sibling.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(target_probs.data_ptr()),
static_cast<float*>(draft_probs.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens,
vocab_size,
deterministic,
stream);

TORCH_CHECK(status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
TORCH_CHECK(
status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
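As background for the call above: in standard speculative sampling, a drafted token x with draft probability q(x) and target probability p(x) is accepted when u < min(1, p(x) / q(x)) for u drawn uniformly from [0, 1). The kernel invoked here walks the draft tree via retrive_next_token and retrive_next_sibling and applies a test of this family along each path; the sketch below shows only the classic single-token rule, as orientation rather than the kernel's exact logic.

// Hypothetical single-token acceptance test from standard speculative sampling.
bool accept_draft_token(float target_prob, float draft_prob, float uniform_sample) {
  const float ratio = draft_prob > 0.0f ? target_prob / draft_prob : 1.0f;
  return uniform_sample < (ratio < 1.0f ? ratio : 1.0f);  // u < min(1, p/q)
}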
@@ -27,15 +27,29 @@ namespace sampling {

using namespace cub;

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM, BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index,
IdType* accept_token_num, // mutable
IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
IdType* retrive_next_sibling, DType* uniform_samples,
DType* target_probs, DType* draft_probs, uint32_t batch_size,
uint32_t num_speculative_tokens, uint32_t num_draft_tokens,
uint32_t d) {
template <
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* accept_index,
IdType* accept_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;

extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
@@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce
}

template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids,
IdType* output_accepted_token_num, // mutable
IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs,
DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens,
uint32_t num_draft_tokens, uint32_t d, bool deterministic,
cudaStream_t stream = 0) {
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* output_token_ids,
IdType* output_accepted_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
DType* uniform_samples,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
bool deterministic,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
void* args[] = {
&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DETERMINISTIC,
DType, IdType>;
auto kernel = TreeSpeculativeSamplingTargetOnly<
BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
DType,
IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
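The void* args array above follows the cudaLaunchKernel contract: the launcher takes kernel parameters as an array of pointers to each argument, in declaration order. A minimal standalone illustration of the same packing pattern, with a hypothetical kernel:

#include <cuda_runtime.h>

__global__ void axpy(float a, const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = a * x[i] + y[i];
}

cudaError_t launch_axpy(float a, const float* x, float* y, int n, cudaStream_t stream) {
  void* args[] = {&a, &x, &y, &n};  // pointers to each parameter, in order
  dim3 nblks((n + 255) / 256), nthrs(256);
  return cudaLaunchKernel((void*)axpy, nblks, nthrs, args, /*sharedMem=*/0, stream);
}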
@@ -42,8 +42,8 @@ using fptr_t = int64_t;
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
*/
#ifdef USE_ROCM
// ROCM custom allreduce
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
fptr_t init_custom_ar(
torch::Tensor& meta,
torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets,
int64_t rank,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
void register_buffer(
fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif
/*
* From csrc/gemm
*/
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
double eps, double fp8_min, double fp8_max);
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
int64_t cublas_handle, int64_t cuda_stream);
void cublas_grouped_gemm(
const std::vector<torch::Tensor>& inputs,
const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs,
const torch::Dtype& out_dtype,
int64_t cublas_handle,
int64_t cuda_stream);
/*
* From csrc/moe
*/
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer);
/*
* From csrc/speculative
*/
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
bool deterministic = true, int64_t cuda_stream = 0);
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic = true,
int64_t cuda_stream = 0);

void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
int64_t depth, int64_t draft_token_num);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num);

void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
int64_t depth, int64_t draft_token_num);
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
/*
* From FlashInfer
*/
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
int64_t cuda_stream);
void bmm_fp8(
at::Tensor A,
at::Tensor B,
at::Tensor D,
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val,
bool deterministic,
int64_t cuda_stream);
// top k renorm probs
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val, int64_t cuda_stream);
void top_k_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val,
int64_t cuda_stream);
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
int64_t cuda_stream) {
inline void top_k_renorm_probs_wrapper(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
int64_t top_k_val,
int64_t cuda_stream) {
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}
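The wrapper above exists because torch extension schemas have no unsigned integer type, so the binding must expose int64_t and narrow before calling into FlashInfer. A hypothetical registration showing where the wrapper would plug in; the fragment name sgl_kernel is an assumption, not taken from this diff:

// Hypothetical op registration; schema inference requires the int64_t wrapper,
// since the extension schema cannot express `unsigned int`.
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("top_k_renorm_probs", top_k_renorm_probs_wrapper);
}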
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val, int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
int64_t cuda_stream);
void top_p_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
int64_t cuda_stream);
void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
std::optional<at::Tensor> maybe_top_k_arr,
double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(
at::Tensor q,
at::Tensor k,
at::Tensor q_rope,
at::Tensor k_rope,
at::Tensor cos_sin_cache,
at::Tensor pos_ids,
bool interleave,
int64_t cuda_stream);

/*
* Other
*/
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
torch::Tensor new_kv);
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv);

// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream);
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

} // namespace trt_llm

@@ -95,7 +95,6 @@ inline int getSMVersion() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define CEILDIV(x, y) (((x) + (y)-1) / (y))

#define WARP_SIZE 32

#ifndef USE_ROCM