Reduce moe_align_block_size_kernel overhead in small-batch mode (#5086)
This commit is contained in:
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||
|
||||
|
||||
# Test range
|
||||
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
num_experts_range = [8, 32, 64, 128, 256]
|
||||
topk_range = [2, 4, 8]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
token_cnts_buffer = torch.zeros(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size(
|
||||
|
||||
def sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
):
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
num_tokens_post_pad.clone(),
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user