Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)

This commit is contained in:
Ke Bao
2025-06-25 08:44:27 +08:00
committed by GitHub
parent 112b496a6c
commit 57ab776910
7 changed files with 163 additions and 70 deletions

View File

@@ -138,33 +138,32 @@ def moe_align_block_size_triton(
@pytest.mark.parametrize(
"block_size,num_tokens,topk,num_experts",
"block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
list(
itertools.product(
[32, 64, 128, 256], # block_size
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
[1, 2, 4, 8, 16, 32, 64], # topk
[64, 160, 256, 257, 260, 264], # num_experts
[True, False], # pad_sorted_token_ids
)
),
)
def test_moe_align_block_size_compare_implementations(
block_size, num_tokens, topk, num_experts
block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
):
topk_ids = torch.stack(
[
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
for _ in range(num_tokens)
]
)
topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
:, :topk
]
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids_cuda = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids_cuda.fill_(topk_ids.numel())
if not pad_sorted_token_ids:
sorted_ids_cuda.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids_cuda = torch.zeros(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations(
num_tokens_post_pad_cuda,
token_cnts_buffer,
cumsum_buffer,
pad_sorted_token_ids,
)
moe_align_block_size_triton(
@@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations(
num_tokens_post_pad_triton,
)
assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
f"Expert IDs mismatch for block_size={block_size}, "
f"num_tokens={num_tokens}, topk={topk}\n"
f"CUDA expert_ids: {expert_ids_cuda}\n"
f"Triton expert_ids: {expert_ids_triton}"
)
assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
assert torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
), (
f"Num tokens post pad mismatch for block_size={block_size}, "
f"num_tokens={num_tokens}, topk={topk}\n"
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
)
# Select an expert to check
expert_idx = expert_ids_cuda.max().item()
# Get the first and last block id where expert_ids_cuda == expert_idx
matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
block_sorted_start = matching_indices[0].item() * block_size
block_sorted_end = min(
(matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded
)
selected_sorted_ids_cuda = sorted_ids_cuda[
block_sorted_start:block_sorted_end
].sort()[0]
selected_sorted_ids_triton = sorted_ids_triton[
block_sorted_start:block_sorted_end
].sort()[0]
assert torch.allclose(
selected_sorted_ids_cuda,
selected_sorted_ids_triton,
atol=0,
rtol=0,
), (
f"Sorted IDs mismatch for block_size={block_size}, "
f"num_tokens={num_tokens}, topk={topk}\n"
f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
f"Triton sorted_ids: {selected_sorted_ids_triton}"
)
if __name__ == "__main__":
pytest.main([__file__])