reduce moe_align_block_size_kernel small batch mode overhead (#5086)
This commit is contained in:
@@ -702,7 +702,7 @@ def moe_align_block_size(
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
else:
|
||||
token_cnts_buffer = torch.zeros(
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
|
||||
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||
|
||||
|
||||
# Test range
|
||||
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
num_experts_range = [8, 32, 64, 128, 256]
|
||||
topk_range = [2, 4, 8]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
token_cnts_buffer = torch.zeros(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size(
|
||||
|
||||
def sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
):
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
num_tokens_post_pad.clone(),
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int expert_id = topk_ids[i];
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
int32_t* cumsum = shared_mem;
|
||||
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
tokens_cnts[threadIdx.x] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[threadIdx.x * num_experts + expert_id];
|
||||
}
|
||||
}
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int64_t num_experts,
|
||||
@@ -111,50 +170,58 @@ void moe_align_block_size(
|
||||
|
||||
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
int experts_per_warp;
|
||||
int threads;
|
||||
|
||||
if (num_experts <= 8) {
|
||||
experts_per_warp = 8;
|
||||
threads = 256;
|
||||
} else if (num_experts <= 16) {
|
||||
experts_per_warp = 16;
|
||||
threads = 512;
|
||||
} else {
|
||||
experts_per_warp = WARP_SIZE;
|
||||
threads = 1024;
|
||||
}
|
||||
int experts_per_warp = WARP_SIZE;
|
||||
int threads = 1024;
|
||||
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
|
||||
|
||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||
if (small_batch_expert_mode) {
|
||||
const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
|
||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
padded_num_experts,
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
|
||||
small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.numel());
|
||||
} else {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(),
|
||||
topk_ids.numel());
|
||||
cumsum_buffer.zero_();
|
||||
|
||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
padded_num_experts,
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(),
|
||||
topk_ids.numel());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
|
||||
def test_moe_align_block_size_compare_implementations(
|
||||
block_size, num_tokens, topk, num_experts
|
||||
):
|
||||
# For DeepSeek V3, we have 256 experts
|
||||
|
||||
topk_ids = torch.stack(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user