[BUG] fix moe benchmark when bs*seq is small (#3382)
This commit is contained in:
committed by
GitHub
parent
4530136e61
commit
64480df495
@@ -157,7 +157,7 @@ def calculate_diff(batch_size, seq_len):
|
||||
)
|
||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids_cuda = torch.empty(
|
||||
expert_ids_cuda = torch.zeros(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad_cuda = torch.empty(
|
||||
@@ -172,7 +172,7 @@ def calculate_diff(batch_size, seq_len):
|
||||
|
||||
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||
sorted_ids_triton.fill_(topk_ids.numel())
|
||||
expert_ids_triton = torch.empty_like(expert_ids_cuda)
|
||||
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
||||
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||
|
||||
# compare the performance of cuda and triton implementation
|
||||
|
||||
Reference in New Issue
Block a user