feat: support ep size < 32 for sgl kernel (#4348)
This commit is contained in:
@@ -47,6 +47,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t padded_num_experts,
|
||||
int32_t experts_per_warp,
|
||||
int32_t block_size,
|
||||
size_t numel,
|
||||
@@ -57,7 +58,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
for (int i = 0; i < experts_per_warp; ++i) {
|
||||
if (my_expert_start + i < num_experts) {
|
||||
if (my_expert_start + i < padded_num_experts) {
|
||||
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||
}
|
||||
}
|
||||
@@ -108,23 +109,44 @@ void moe_align_block_size(
|
||||
torch::Tensor token_cnts_buffer,
|
||||
torch::Tensor cumsum_buffer) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
TORCH_CHECK(num_experts % WARP_SIZE == 0);
|
||||
int experts_per_warp = num_experts / WARP_SIZE;
|
||||
|
||||
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
int experts_per_warp;
|
||||
int threads;
|
||||
|
||||
if (num_experts <= 8) {
|
||||
experts_per_warp = 8;
|
||||
threads = 256;
|
||||
} else if (num_experts <= 16) {
|
||||
experts_per_warp = 16;
|
||||
threads = 512;
|
||||
} else {
|
||||
experts_per_warp = WARP_SIZE;
|
||||
threads = 1024;
|
||||
}
|
||||
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
|
||||
align_kernel<<<1, 1024, shared_mem_size, stream>>>(
|
||||
|
||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||
|
||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
padded_num_experts,
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
|
||||
const int block_threads = 256;
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
Reference in New Issue
Block a user