fix moe_align_block_size (#2615)
@@ -18,8 +18,22 @@ def test_moe_align_block_size():
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+    )
+
     moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
     )
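For callers updating their own code, the sketch below shows the new calling convention end to end. Only the two extra scratch buffers (token_cnts_buffer, sized (num_experts + 1) * num_experts, and cumsum_buffer, sized num_experts + 1) and the expanded argument list come from this patch; the import path, the tensor shapes, and the max_num_tokens_padded sizing are assumptions added for illustration.

# Minimal sketch of the updated call site. The import location and the sizing of
# sorted_ids / expert_ids are assumptions, not part of this diff.
import torch

from sgl_kernel import moe_align_block_size  # assumed import path; adjust to your build

num_experts = 256
block_size = 128
topk_ids = torch.randint(0, num_experts, (16, 2), dtype=torch.int32, device="cuda")

# Assumed worst-case padded length: each expert's slice may round up to a full block.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device=topk_ids.device)
expert_ids = torch.empty(
    max_num_tokens_padded // block_size, dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

# Scratch buffers added by this change, sized exactly as in the updated test.
token_cnts_buffer = torch.empty(
    (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
)
cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=topk_ids.device)

moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
)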