fix moe_align_block_size (#2615)

This commit is contained in:
HandH1998
2024-12-27 23:32:53 +08:00
committed by GitHub
parent 70dc2fbe2d
commit 77d1210b36
4 changed files with 24 additions and 18 deletions

View File

@@ -18,8 +18,22 @@ def test_moe_align_block_size():
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+token_cnts_buffer = torch.empty(
+(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+)
+cumsum_buffer = torch.empty(
+num_experts + 1, dtype=torch.int32, device=topk_ids.device
+)
moe_align_block_size(
-topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+topk_ids,
+num_experts,
+block_size,
+sorted_ids,
+expert_ids,
+num_tokens_post_pad,
+token_cnts_buffer,
+cumsum_buffer,
)