update sgl-kernel for EP: kernel part (#8514)
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: Ke Bao <ispobaoke@gmail.com>
This commit is contained in:
@@ -157,7 +157,7 @@ def test_moe_align_block_size_compare_implementations(
|
||||
:, :topk
|
||||
]
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
|
||||
|
||||
sorted_ids_cuda = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
@@ -171,13 +171,8 @@ def test_moe_align_block_size_compare_implementations(
|
||||
num_tokens_post_pad_cuda = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
num_experts + 2, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||
@@ -187,19 +182,18 @@ def test_moe_align_block_size_compare_implementations(
|
||||
|
||||
moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
num_experts + 1,
|
||||
block_size,
|
||||
sorted_ids_cuda,
|
||||
expert_ids_cuda,
|
||||
num_tokens_post_pad_cuda,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
pad_sorted_token_ids,
|
||||
)
|
||||
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
num_experts + 1,
|
||||
block_size,
|
||||
sorted_ids_triton,
|
||||
expert_ids_triton,
|
||||
|
||||
Reference in New Issue
Block a user