[BUG] fix moe benchmark when bs*seq is small (#3382)
This commit is contained in:
committed by
GitHub
parent
4530136e61
commit
64480df495
@@ -157,7 +157,7 @@ def calculate_diff(batch_size, seq_len):
|
|||||||
)
|
)
|
||||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||||
expert_ids_cuda = torch.empty(
|
expert_ids_cuda = torch.zeros(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad_cuda = torch.empty(
|
num_tokens_post_pad_cuda = torch.empty(
|
||||||
@@ -172,7 +172,7 @@ def calculate_diff(batch_size, seq_len):
|
|||||||
|
|
||||||
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||||
sorted_ids_triton.fill_(topk_ids.numel())
|
sorted_ids_triton.fill_(topk_ids.numel())
|
||||||
expert_ids_triton = torch.empty_like(expert_ids_cuda)
|
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
||||||
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||||
|
|
||||||
# compare the performance of cuda and triton implementation
|
# compare the performance of cuda and triton implementation
|
||||||
|
|||||||
Reference in New Issue
Block a user