optimize moe_align_kernel cuda (#3347)

This commit is contained in:
Xiaoyu Zhang
2025-02-07 00:53:46 +08:00
committed by GitHub
parent adeee15204
commit cdae77b03d
3 changed files with 29 additions and 21 deletions

View File

@@ -163,10 +163,10 @@ def calculate_diff(batch_size, seq_len):
num_tokens_post_pad_cuda = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
-token_cnts_buffer = torch.empty(
+token_cnts_buffer = torch.zeros(
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
)
-cumsum_buffer = torch.empty(
+cumsum_buffer = torch.zeros(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
@@ -260,10 +260,10 @@ def benchmark(batch_size, seq_len, provider):
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-token_cnts_buffer = torch.empty(
+token_cnts_buffer = torch.zeros(
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
)
-cumsum_buffer = torch.empty(
+cumsum_buffer = torch.zeros(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)