reduce moe_align_block_size_kernel small batch mode overhead (#5086)
@@ -702,7 +702,7 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.zeros(
+        token_cnts_buffer = torch.empty(
            (num_experts + 1) * num_experts,
            dtype=torch.int32,
            device=topk_ids.device,
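Note: the Python-side change above swaps torch.zeros for torch.empty when allocating the token-count scratch buffer; the kernel no longer relies on this buffer arriving zero-initialized, so the extra fill launch can be skipped. A minimal sketch of measuring that difference in isolation (assumes a CUDA device and Triton; the sizes and names here are illustrative, not taken from the PR):

import torch
import triton

# Hypothetical micro-benchmark: time allocation of the (num_experts + 1) * num_experts
# counter buffer with and without zero-filling. torch.zeros launches a fill kernel on
# every call; torch.empty only reserves memory from the caching allocator.
num_experts = 256
shape = ((num_experts + 1) * num_experts,)

zeros_ms = triton.testing.do_bench(
    lambda: torch.zeros(shape, dtype=torch.int32, device="cuda")
)
empty_ms = triton.testing.do_bench(
    lambda: torch.empty(shape, dtype=torch.int32, device="cuda")
)
print(f"torch.zeros: {zeros_ms:.5f} ms  torch.empty: {empty_ms:.5f} ms")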
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
 
 
 # Test range
-num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
 num_experts_range = [8, 32, 64, 128, 256]
-topk_range = [2, 4, 8]
+topk_range = [1, 2, 4, 8]
 
 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
 
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer = torch.zeros(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
-    )
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: sgl_moe_align_block_size(
+
+        def sgl_moe_align_block_size_with_empty(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        ):
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )
+
+            sgl_moe_align_block_size(
                 topk_ids,
                 num_experts,
                 block_size,
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
                 num_tokens_post_pad.clone(),
                 token_cnts_buffer,
                 cumsum_buffer,
+            )
+
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sgl_moe_align_block_size_with_empty(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
             ),
             quantiles=quantiles,
         )
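Note: the benchmark now creates its scratch buffers with torch.empty inside the timed wrapper, so the reported latency includes the allocation that the real call path pays. A stripped-down sketch of the same do_bench-with-quantiles pattern, using an invented dummy_op (all names here are placeholders, not the benchmark's API):

import torch
import triton

def dummy_op(x):
    # Scratch allocation happens inside the timed region, mirroring the wrapper above.
    scratch = torch.empty_like(x)
    return torch.add(x, 1, out=scratch)

x = torch.randn(1024, device="cuda")
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dummy_op(x), quantiles=quantiles)
print(ms, min_ms, max_ms)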
@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
 
   __syncthreads();
 
-  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-  const size_t start_idx = threadIdx.x * tokens_per_thread;
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
 
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+  for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
   }
 }
 
+template <typename scalar_t>
+__global__ void moe_align_block_size_small_batch_expert_kernel(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids,
+    int32_t* __restrict__ expert_ids,
+    int32_t* __restrict__ total_tokens_post_pad,
+    int32_t num_experts,
+    int32_t block_size,
+    size_t numel) {
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
+
+  extern __shared__ int32_t shared_mem[];
+  int32_t* cumsum = shared_mem;
+  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
+  }
+
+  for (size_t i = tid; i < numel; i += stride) {
+    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[threadIdx.x] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
+    }
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
+    }
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
+  }
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
+  }
+}
+
 void moe_align_block_size(
     torch::Tensor topk_ids,
     int64_t num_experts,
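Note: for readers following the new single-block kernel, the result it produces can be written as a plain-Python CPU reference (a sketch for illustration only, not the shipped implementation; the use of numel as the padding sentinel mirrors how the caller pre-fills sorted_ids and is an assumption of this sketch, and the ordering of tokens inside one expert's segment may differ from the GPU kernel):

import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    flat = topk_ids.flatten().tolist()

    # 1. Count how many (token, expert) pairs each expert received.
    counts = [0] * num_experts
    for e in flat:
        counts[e] += 1

    # 2. Exclusive prefix sum over per-expert counts rounded up to block_size.
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        padded = ((counts[e] + block_size - 1) // block_size) * block_size
        cumsum[e + 1] = cumsum[e] + padded
    num_tokens_post_pad = cumsum[num_experts]

    # 3. Every block of block_size slots belongs to exactly one expert.
    expert_ids = [0] * (num_tokens_post_pad // block_size)
    for e in range(num_experts):
        for start in range(cumsum[e], cumsum[e + 1], block_size):
            expert_ids[start // block_size] = e

    # 4. Scatter token indices into their expert's padded segment; unused
    #    padding slots keep the sentinel len(flat).
    sorted_token_ids = [len(flat)] * num_tokens_post_pad
    offsets = [0] * num_experts
    for i, e in enumerate(flat):
        sorted_token_ids[cumsum[e] + offsets[e]] = i
        offsets[e] += 1

    return sorted_token_ids, expert_ids, num_tokens_post_pad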
@@ -111,50 +170,58 @@ void moe_align_block_size(
 
   int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
-  int experts_per_warp;
-  int threads;
-
-  if (num_experts <= 8) {
-    experts_per_warp = 8;
-    threads = 256;
-  } else if (num_experts <= 16) {
-    experts_per_warp = 16;
-    threads = 512;
-  } else {
-    experts_per_warp = WARP_SIZE;
-    threads = 1024;
-  }
+  int experts_per_warp = WARP_SIZE;
+  int threads = 1024;
 
   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
+    bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
 
-    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
-
-    align_kernel<<<1, threads, shared_mem_size, stream>>>(
-        topk_ids.data_ptr<scalar_t>(),
-        sorted_token_ids.data_ptr<int32_t>(),
-        experts_ids.data_ptr<int32_t>(),
-        num_tokens_post_pad.data_ptr<int32_t>(),
-        num_experts,
-        padded_num_experts,
-        experts_per_warp,
-        block_size,
-        topk_ids.numel(),
-        cumsum_buffer.data_ptr<int32_t>());
-
-    const int block_threads = std::min(256, (int)threads);
-    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
-    const int max_blocks = 65535;
-    const int actual_blocks = std::min(num_blocks, max_blocks);
-
-    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
-    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
-        topk_ids.data_ptr<scalar_t>(),
-        sorted_token_ids.data_ptr<int32_t>(),
-        cumsum_buffer.data_ptr<int32_t>(),
-        topk_ids.numel());
+    if (small_batch_expert_mode) {
+      const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
+      const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+
+      auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
+      small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          experts_ids.data_ptr<int32_t>(),
+          num_tokens_post_pad.data_ptr<int32_t>(),
+          num_experts,
+          block_size,
+          topk_ids.numel());
+    } else {
+      auto align_kernel = moe_align_block_size_kernel<scalar_t>;
+
+      size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+      size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+
+      cumsum_buffer.zero_();
+
+      align_kernel<<<1, threads, shared_mem_size, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          experts_ids.data_ptr<int32_t>(),
+          num_tokens_post_pad.data_ptr<int32_t>(),
+          num_experts,
+          padded_num_experts,
+          experts_per_warp,
+          block_size,
+          topk_ids.numel(),
+          cumsum_buffer.data_ptr<int32_t>());
+
+      const int block_threads = std::min(256, (int)threads);
+      const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+      const int max_blocks = 65535;
+      const int actual_blocks = std::min(num_blocks, max_blocks);
+
+      auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
+      sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+          topk_ids.data_ptr<scalar_t>(),
+          sorted_token_ids.data_ptr<int32_t>(),
+          cumsum_buffer.data_ptr<int32_t>(),
+          topk_ids.numel());
+    }
   });
 }
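Note: the single-block small-batch kernel is only taken when there are fewer than 1024 topk ids and at most 64 experts; the general path now zeroes cumsum_buffer on the device (cumsum_buffer.zero_()), which is what lets callers allocate it with torch.empty. A small sketch of the heuristic and of the dynamic shared memory the small-batch launch requests (WARP_SIZE = 32 is an assumption of this sketch; the real value comes from the build):

WARP_SIZE = 32  # assumption for this sketch

def small_batch_expert_mode(num_topk_ids: int, num_experts: int) -> bool:
    return num_topk_ids < 1024 and num_experts <= 64

def small_batch_shared_mem_bytes(num_experts: int) -> int:
    threads = max(num_experts, WARP_SIZE)
    # (threads + 1) rows of per-expert counters plus the (num_experts + 1) cumsum, all int32.
    return ((threads + 1) * num_experts + (num_experts + 1)) * 4

# Example: 64 experts -> 64 threads and (65 * 64 + 65) * 4 = 16,900 bytes of shared memory.
print(small_batch_expert_mode(512, 64), small_batch_shared_mem_bytes(64))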
@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
 def test_moe_align_block_size_compare_implementations(
     block_size, num_tokens, topk, num_experts
 ):
-    # For DeepSeek V3, we have 256 experts
 
     topk_ids = torch.stack(
         [