Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)
This commit is contained in:
@@ -138,33 +138,32 @@ def moe_align_block_size_triton(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_size,num_tokens,topk,num_experts",
|
||||
"block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
|
||||
list(
|
||||
itertools.product(
|
||||
[32, 64, 128, 256], # block_size
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
||||
[1, 2, 4, 8, 16, 32, 64], # topk
|
||||
[64, 160, 256, 257, 260, 264], # num_experts
|
||||
[True, False], # pad_sorted_token_ids
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_moe_align_block_size_compare_implementations(
|
||||
block_size, num_tokens, topk, num_experts
|
||||
block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
|
||||
):
|
||||
|
||||
topk_ids = torch.stack(
|
||||
[
|
||||
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||
for _ in range(num_tokens)
|
||||
]
|
||||
)
|
||||
topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
|
||||
:, :topk
|
||||
]
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
|
||||
sorted_ids_cuda = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||
if not pad_sorted_token_ids:
|
||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids_cuda = torch.zeros(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
@@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations(
|
||||
num_tokens_post_pad_cuda,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
pad_sorted_token_ids,
|
||||
)
|
||||
|
||||
moe_align_block_size_triton(
|
||||
@@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations(
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
|
||||
assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
|
||||
f"Expert IDs mismatch for block_size={block_size}, "
|
||||
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||
f"CUDA expert_ids: {expert_ids_cuda}\n"
|
||||
f"Triton expert_ids: {expert_ids_triton}"
|
||||
)
|
||||
|
||||
assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
|
||||
assert torch.allclose(
|
||||
num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
|
||||
), (
|
||||
f"Num tokens post pad mismatch for block_size={block_size}, "
|
||||
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
|
||||
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
|
||||
)
|
||||
|
||||
# Select an expert to check
|
||||
expert_idx = expert_ids_cuda.max().item()
|
||||
|
||||
# Get the first and last block id where expert_ids_cuda == expert_idx
|
||||
matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
|
||||
block_sorted_start = matching_indices[0].item() * block_size
|
||||
block_sorted_end = min(
|
||||
(matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded
|
||||
)
|
||||
|
||||
selected_sorted_ids_cuda = sorted_ids_cuda[
|
||||
block_sorted_start:block_sorted_end
|
||||
].sort()[0]
|
||||
selected_sorted_ids_triton = sorted_ids_triton[
|
||||
block_sorted_start:block_sorted_end
|
||||
].sort()[0]
|
||||
|
||||
assert torch.allclose(
|
||||
selected_sorted_ids_cuda,
|
||||
selected_sorted_ids_triton,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
), (
|
||||
f"Sorted IDs mismatch for block_size={block_size}, "
|
||||
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||
f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
|
||||
f"Triton sorted_ids: {selected_sorted_ids_triton}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user