[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x) (#3613)

This commit is contained in:
yiakwy-xpu-ml-framework-team
2025-02-28 11:42:48 +08:00
committed by GitHub
parent bc20e93f2d
commit 1c96fa86cf
5 changed files with 384 additions and 97 deletions

View File

@@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to
num_tokens_post_pad_cuda = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
)
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)