Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)
This commit is contained in:
@@ -21,8 +21,17 @@ limitations under the License.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Fixed-size array of N elements of type T whose alignment requirement is
// raised to `Alignment` bytes via alignas.
//
// Template parameters:
//   T         - element type.
//   N         - number of elements held in `data`.
//   Alignment - alignment in bytes requested for the whole object; defaults
//               to sizeof(T) * N, i.e. the total size of the element storage.
//
// NOTE(review): alignas only accepts valid alignments (powers of two), so the
// default is usable only when sizeof(T) * N is one - confirm before adding
// instantiations with other element counts or types.
template <typename T, int N, int Alignment = sizeof(T) * N>
|
||||
class alignas(Alignment) AlignedArray {
|
||||
public:
|
||||
// Element storage; public so callers can index data[i] directly.
T data[N];
|
||||
};
|
||||
|
||||
// Number of threads per warp; used in this file to derive a warp index from
// threadIdx.x and to round the expert count up to a whole number of warps.
#define WARP_SIZE 32
|
||||
|
||||
// Number of int32_t elements grouped into one Vec.
#define VEC_SIZE 4
|
||||
// Pack of VEC_SIZE int32_t values. With AlignedArray's default alignment the
// type is aligned to sizeof(int32_t) * VEC_SIZE bytes (16 where int32_t is
// 4 bytes). The padding code in this file casts sorted_token_ids to Vec* and
// writes VEC_SIZE entries per assignment.
// NOTE(review): presumably intended to let the compiler emit one wide store
// per Vec; that requires the destination pointer to be aligned to
// alignof(Vec) - verify against the allocation of sorted_token_ids.
using Vec = AlignedArray<int32_t, VEC_SIZE>;
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
@@ -50,7 +59,8 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t experts_per_warp,
|
||||
int32_t block_size,
|
||||
size_t numel,
|
||||
int32_t* __restrict__ cumsum) {
|
||||
int32_t* __restrict__ cumsum,
|
||||
bool pad_sorted_token_ids) {
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
@@ -96,6 +106,24 @@ __global__ void moe_align_block_size_kernel(
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
if (pad_sorted_token_ids) {
|
||||
int32_t fill_val = static_cast<int32_t>(numel);
|
||||
int32_t total = *total_tokens_post_pad;
|
||||
|
||||
Vec fill_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
fill_vec.data[i] = fill_val;
|
||||
}
|
||||
|
||||
int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
|
||||
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
|
||||
|
||||
for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
|
||||
out_ptr[idx] = fill_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -106,7 +134,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
size_t numel,
|
||||
bool pad_sorted_token_ids) {
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
@@ -149,6 +178,26 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
if (pad_sorted_token_ids) {
|
||||
int32_t fill_val = static_cast<int32_t>(numel);
|
||||
int32_t total = *total_tokens_post_pad;
|
||||
|
||||
Vec fill_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
fill_vec.data[i] = fill_val;
|
||||
}
|
||||
|
||||
int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
|
||||
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
|
||||
|
||||
for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
|
||||
out_ptr[idx] = fill_vec;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
||||
@@ -165,7 +214,8 @@ void moe_align_block_size(
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor token_cnts_buffer,
|
||||
torch::Tensor cumsum_buffer) {
|
||||
torch::Tensor cumsum_buffer,
|
||||
bool pad_sorted_token_ids) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
@@ -190,7 +240,8 @@ void moe_align_block_size(
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.numel());
|
||||
topk_ids.numel(),
|
||||
pad_sorted_token_ids);
|
||||
} else {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
@@ -207,7 +258,8 @@ void moe_align_block_size(
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
cumsum_buffer.data_ptr<int32_t>(),
|
||||
pad_sorted_token_ids);
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
|
||||
Reference in New Issue
Block a user