255 lines
8.6 KiB
Plaintext
255 lines
8.6 KiB
Plaintext
#include <algorithm>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
|
|
#include <cudaTypedefs.h>
|
|
#include <torch/all.h>
|
|
|
|
#include <iostream>
|
|
|
|
#include "cutlass/array.h"
|
|
|
|
// Block size requested for the per-expert kernels (one block per expert).
// The host launcher clamps it to the number of topk entries.
constexpr uint64_t THREADS_PER_EXPERT{512};
|
|
|
|
// Counts how many entries of `topk_ids` select each expert and records the
// per-expert GEMM problem shapes.
//
// Launch layout: one block per expert (expert id == blockIdx.x). Any
// blockDim.x >= 1 is valid because the scan strides by blockDim.x.
//
// For expert e with m = |{ i : topk_ids[i] == e }|:
//   problem_sizes1[3e], [3e+1], [3e+2] = m, 2 * n, k
//   problem_sizes2[3e], [3e+1], [3e+2] = m, k, n
//
// atomic_buffer[e] is the block's accumulator. It must be zero on entry for m
// to be exact, and it holds m on exit.
__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t n,
    const int64_t k) {
  const int expert_id = blockIdx.x;

  // Per-thread partial count. The stride is the real block size, not a fixed
  // constant, so no entry is skipped or double-counted whatever the launch
  // configuration. int64_t matches topk_length and cannot overflow first.
  int occurrences = 0;
  for (int64_t i = threadIdx.x; i < topk_length; i += blockDim.x) {
    occurrences += (topk_ids[i] == expert_id);
  }

  // Adding zero is a no-op, so only threads that found hits touch the counter.
  if (occurrences != 0) {
    atomicAdd(&atomic_buffer[expert_id], occurrences);
  }
  // Reached by every thread: makes the block's global atomics visible to
  // thread 0 before it reads the total.
  __syncthreads();

  if (threadIdx.x == 0) {
    const int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
    problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
  }
}
|
|
|
|
// Serial prefix sum over the per-expert row counts stored in
// problem_sizes1[3 * e]. Launched as <<<1, 1>>> by the host caller.
//
// Outputs:
//   expert_offsets[0]     = 0
//   expert_offsets[e + 1] = sum of counts for experts 0..e
//   atomic_buffer[e]      = sum of counts for experts 0..e-1. This is the
//                           starting slot that compute_arg_sorts advances as
//                           its per-expert write cursor.
__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t running = 0;
  expert_offsets[0] = running;

  // m_dim walks the first (row-count) entry of each 3-wide problem triple.
  const int32_t* m_dim = problem_sizes1;
  for (int e = 0; e < num_experts; ++e, m_dim += 3) {
    atomic_buffer[e] = running;
    running += *m_dim;
    expert_offsets[e + 1] = running;
  }
}
|
|
|
|
// Serial prefix sums over the per-expert row counts in problem_sizes1[3 * e].
// Launched as <<<1, 1>>> by the host caller.
//
// Produces two offset tables:
//   expert_offsets[e + 1]     : running total of the raw row counts
//   blockscale_offsets[e + 1] : running total of row counts rounded up to a
//                               multiple of 128
// Both tables start at 0. atomic_buffer[e] receives expert e's starting raw
// offset, which compute_arg_sorts uses as its write cursor.
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  constexpr int kRowAlign = 128;

  int32_t raw_total = 0;
  int32_t padded_total = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;

  for (int e = 0; e < num_experts; ++e) {
    atomic_buffer[e] = raw_total;

    const int rows = problem_sizes1[e * 3];
    const int padded_rows = ((rows + kRowAlign - 1) / kRowAlign) * kRowAlign;

    raw_total += rows;
    padded_total += padded_rows;

    expert_offsets[e + 1] = raw_total;
    blockscale_offsets[e + 1] = padded_total;
  }
}
|
|
|
|
// Builds the gather/scatter permutations that group topk entries by expert.
//
// Launch layout: one block per expert (expert id == blockIdx.x). Any
// blockDim.x >= 1 is valid because the scan strides by blockDim.x.
//
// On entry atomic_buffer[e] must hold the first output slot for expert e (the
// offset kernels in this file write it). Each matching entry i claims the next
// slot s of its expert and records:
//   input_permutation[s]  = i / topk   (row of topk_ids that i belongs to)
//   output_permutation[i] = s
// Slot order within one expert follows atomic arrival order, so it is not
// deterministic. The two permutations always stay consistent with each other.
__global__ void compute_arg_sorts(
    const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t topk) {
  const int expert_id = blockIdx.x;

  // The stride follows the real block size, so coverage does not depend on the
  // launch configuration. int64_t avoids index overflow for large inputs.
  for (int64_t i = threadIdx.x; i < topk_length; i += blockDim.x) {
    if (topk_ids[i] == expert_id) {
      const int32_t slot = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[slot] = static_cast<int32_t>(i / topk);
      output_permutation[i] = slot;
    }
  }
}
|
|
|
|
// Host launcher for MoE input preparation. It runs three kernels in order on
// the current stream of topk_ids' device:
//   1. compute_problem_sizes: per-expert row counts and GEMM shapes.
//   2. compute_expert_offsets, or compute_expert_blockscale_offsets when
//      blockscale_offsets is provided: prefix sums of those counts.
//   3. compute_arg_sorts: input/output permutations grouped by expert.
//
// Buffers are accessed as raw int32 pointers and are NOT validated here.
// Assumed (TODO confirm at call sites):
//   - all tensors are int32 and contiguous;
//   - problem_sizes1/2 hold >= 3 * num_experts elements;
//   - expert_offsets (and blockscale_offsets) hold >= num_experts + 1;
//   - input/output_permutation hold >= topk_ids.numel().
// topk_ids.size(1) is used as the number of entries per input row.
void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  // A zero-block launch is an invalid configuration.
  TORCH_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts);

  // Launch on the device that owns topk_ids, not whichever device is current.
  const c10::cuda::CUDAGuard device_guard(topk_ids.device());
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  // Zero-initialised: compute_problem_sizes accumulates into it.
  torch::Tensor atomic_buffer = torch::zeros({num_experts}, options_int32);

  const int64_t topk_length = topk_ids.numel();
  // Always launch at least one thread. Empty topk_ids would otherwise request
  // a zero-thread launch, which fails and leaves problem_sizes* unwritten.
  const uint32_t num_threads = static_cast<uint32_t>(std::max<int64_t>(
      1, std::min<int64_t>(static_cast<int64_t>(THREADS_PER_EXPERT), topk_length)));
  const uint32_t num_blocks = static_cast<uint32_t>(num_experts);

  compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_length,
      n,
      k);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (blockscale_offsets.has_value()) {
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // atomic_buffer now holds each expert's starting slot and is consumed as the
  // write cursor.
  compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_length,
      topk_ids.size(1));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
|
|
|
// Public entry point for MoE input preparation.
//
// Validates topk_ids and forwards everything to get_moe_prepare_input_caller.
// topk_ids must be a 2-D int32 CUDA tensor. The downstream kernels index it
// linearly, so a non-contiguous view is first materialised as contiguous. For
// an already-contiguous tensor this is a no-op.
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32, "topk_ids must be int32, got ", topk_ids.scalar_type());
  TORCH_CHECK(topk_ids.is_cuda(), "topk_ids must be a CUDA tensor");
  TORCH_CHECK(topk_ids.dim() == 2, "topk_ids must be a 2-D tensor, got ", topk_ids.dim(), " dims");

  const torch::Tensor topk_ids_contig = topk_ids.contiguous();

  get_moe_prepare_input_caller(
      topk_ids_contig,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
}
|
|
|
|
// Gathers rows: output[r, :] = input[dst2src_map[r], :] for r in [0, num_dst_rows).
// A source row may be copied to several destination rows.
//
// Launch layout: 1-D grid of any size. Blocks take destination rows in a
// grid-stride loop, and threads of a block split the columns of a row.
// Both matrices are addressed as contiguous row-major with num_cols columns.
// num_src_rows is not read by this kernel, and dst2src_map entries are not
// range-checked (presumably valid source rows; confirm at caller).
//
// Rows are copied 16 bytes at a time when every row starts 16-byte aligned and
// num_cols is a whole number of 16-byte vectors. Otherwise the kernel falls
// back to element-wise copies, so no tail columns are dropped and no
// misaligned vector access occurs.
template <typename T>
__global__ void shuffleRowsKernel(
    const T* input,
    const int32_t* dst2src_map,
    T* output,
    int64_t num_src_rows,
    int64_t num_dst_rows,
    int64_t num_cols) {
  static_assert(sizeof(T) <= 16 && 16 % sizeof(T) == 0, "element size must divide 16 bytes");

  // 128 bits per vector access.
  constexpr int64_t kElemsPerVec = static_cast<int64_t>(128 / sizeof(T) / 8);
  using DataElem = cutlass::Array<T, static_cast<int>(kElemsPerVec)>;

  // Grid-uniform decision, so there is no divergence between blocks or threads.
  const bool can_vectorize = (num_cols % kElemsPerVec == 0) &&
      (reinterpret_cast<uintptr_t>(input) % 16 == 0) &&
      (reinterpret_cast<uintptr_t>(output) % 16 == 0);

  for (int64_t dest_row_idx = blockIdx.x; dest_row_idx < num_dst_rows; dest_row_idx += gridDim.x) {
    // Read the mapping only once dest_row_idx is known to be in range.
    const int64_t source_row_idx = dst2src_map[dest_row_idx];
    const T* src_row = input + source_row_idx * num_cols;
    T* dst_row = output + dest_row_idx * num_cols;

    if (can_vectorize) {
      auto const* src_vec = reinterpret_cast<DataElem const*>(src_row);
      auto* dst_vec = reinterpret_cast<DataElem*>(dst_row);
      const int64_t num_vecs = num_cols / kElemsPerVec;
      for (int64_t v = threadIdx.x; v < num_vecs; v += blockDim.x) {
        dst_vec[v] = src_vec[v];
      }
    } else {
      for (int64_t c = threadIdx.x; c < num_cols; c += blockDim.x) {
        dst_row[c] = src_row[c];
      }
    }
  }
}
|
|
|
|
// Explicitly instantiates shuffleRowsKernel<T> for each element type that
// shuffle_rows_caller dispatches to. This must be an explicit instantiation,
// not a plain declaration. A plain `__global__ void shuffleRowsKernel(const T*, ...)`
// would declare a separate, never-defined non-template overload. An
// unqualified call would then prefer that overload and fail at link time.
// The macro omits the trailing semicolon; each use supplies it.
#define DECLARE_SHUFFLE_ROWS(T)                     \
  template __global__ void shuffleRowsKernel<T>(    \
      const T*, const int32_t*, T*, int64_t, int64_t, int64_t)

DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);
|
|
|
|
// Launches shuffleRowsKernel<ELEM_T> on `stream`. The arguments are captured
// by name from the expanding scope: blocks, threads, stream, input,
// dst2src_map, output, num_src_rows, num_dst_rows and num_cols must all be
// in scope at the use site.
#define SHUFFLE_ROWS(ELEM_T)                                   \
  shuffleRowsKernel<ELEM_T><<<blocks, threads, 0, stream>>>(   \
      reinterpret_cast<const ELEM_T*>(input),                  \
      static_cast<const int32_t*>(dst2src_map.data_ptr()),     \
      reinterpret_cast<ELEM_T*>(output),                       \
      num_src_rows,                                            \
      num_dst_rows,                                            \
      num_cols)
|
|
|
|
// Expands to one `case` of a switch over torch scalar types. When the switched
// value equals TORCH_DTYPE, it launches the row shuffle with CUDA_T as the
// element type and leaves the switch.
#define DTYPE_DISPATCH_CASE(TORCH_DTYPE, CUDA_T) \
  case TORCH_DTYPE: {                            \
    SHUFFLE_ROWS(CUDA_T);                        \
    break;                                       \
  }
|
|
|
|
// Host launcher for the row gather:
//   output_tensor[r, :] = input_tensor[dst2src_map[r], :] for every output row r.
//
// Requirements, checked here:
//   - input/output are 2-D, contiguous, share dtype and column count, and live
//     on the same CUDA device as dst2src_map;
//   - dst2src_map is a contiguous int32 tensor with >= output_tensor.size(0)
//     entries.
// The kernel is launched with one block per output row on the current stream
// of the input's device. An empty output is a no-op.
// Throws for dtypes other than fp16, bf16, fp32, fp8_e4m3fn and uint8.
void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");
  TORCH_CHECK(
      input_tensor.dim() == 2 && output_tensor.dim() == 2, "shuffle_rows expects 2-D input and output tensors");
  TORCH_CHECK(
      output_tensor.size(1) == input_tensor.size(1), "Input and output tensors must have the same number of columns");
  TORCH_CHECK(
      input_tensor.is_contiguous() && output_tensor.is_contiguous() && dst2src_map.is_contiguous(),
      "shuffle_rows expects contiguous tensors");
  TORCH_CHECK(dst2src_map.scalar_type() == torch::kInt32, "dst2src_map must be int32");
  TORCH_CHECK(
      dst2src_map.numel() >= output_tensor.size(0), "dst2src_map must have an entry for every output row");
  TORCH_CHECK(
      input_tensor.device() == output_tensor.device() && input_tensor.device() == dst2src_map.device(),
      "shuffle_rows tensors must be on the same device");

  int64_t num_dst_rows = output_tensor.size(0);
  int64_t num_src_rows = input_tensor.size(0);
  int64_t num_cols = input_tensor.size(1);

  // A zero-block launch is an invalid configuration, and there is nothing to copy.
  if (num_dst_rows == 0) {
    return;
  }

  // Make the tensors' device current so the stream and launch target match it.
  const c10::cuda::CUDAGuard device_guard(input_tensor.device());
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  uint32_t blocks = static_cast<uint32_t>(num_dst_rows);
  uint32_t threads = 256;
  const void* input = input_tensor.data_ptr();
  void* output = output_tensor.data_ptr();

  // The dispatch macros refer to the locals above by name.
  switch (input_tensor.scalar_type()) {
    DTYPE_DISPATCH_CASE(torch::kFloat16, half);
    DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
    DTYPE_DISPATCH_CASE(torch::kFloat32, float);
    DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
    DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
|
|
|
// Public entry point for the row gather. It copies rows of input_tensor into
// output_tensor as selected by dst2src_map. Validation and the kernel launch
// are done by shuffle_rows_caller.
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
}
|