Remove cumsum_buffer initilization (#7439)
This commit is contained in:
@@ -750,9 +750,11 @@ def moe_align_block_size(
|
|||||||
by block_size for proper block matrix operations.
|
by block_size for proper block matrix operations.
|
||||||
"""
|
"""
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
|
sorted_ids = torch.empty(
|
||||||
max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
|
|
||||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
expert_ids = torch.empty(
|
expert_ids = torch.empty(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
@@ -768,6 +770,9 @@ def moe_align_block_size(
|
|||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
(num_experts + 1,), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
token_cnts_buffer = torch.empty(
|
token_cnts_buffer = torch.empty(
|
||||||
(num_experts + 1) * num_experts,
|
(num_experts + 1) * num_experts,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
|
|||||||
Reference in New Issue
Block a user