import torch

from sgl_kernel import moe_align_block_size
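
# Hedged summary of the kernel's contract (it follows the vLLM-style
# moe_align_block_size that sgl_kernel's implementation is ported from):
# the kernel sorts the flattened (token, expert) assignments by expert id and
# pads each expert's segment to a multiple of block_size, so that each block
# of the fused MoE GEMM only touches tokens routed to a single expert.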


def test_moe_align_block_size():
    # For DeepSeek V3, we have 256 experts.
    num_experts = 256

    # Test different combinations of block_size, num_tokens, and topk.
    for block_size in [32, 64, 128, 256]:
        print(f"\nTesting block_size={block_size}")
        for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            for topk in [1, 2, 4, 8, 16, 32, 64]:
                print(
                    f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
                )

                # Create random topk_ids with shape [num_tokens, topk].
                topk_ids = torch.randint(
                    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
                )

                # Worst case: every expert's segment gets padded by up to
                # (block_size - 1) slots, which bounds the padded length.
                max_num_tokens_padded = topk_ids.numel() + num_experts * (
                    block_size - 1
                )
                sorted_ids = torch.empty(
                    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
                )
                # topk_ids.numel() is one past the largest valid flattened
                # index, so it doubles as the padding sentinel.
                sorted_ids.fill_(topk_ids.numel())
                # The padded total is a multiple of block_size and at most
                # max_num_tokens_padded, so floor division bounds the block count.
                max_num_m_blocks = max_num_tokens_padded // block_size
                expert_ids = torch.empty(
                    (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
                )
                num_tokens_post_pad = torch.empty(
                    (1,), dtype=torch.int32, device=topk_ids.device
                )

                # Workspace used internally by the kernel for per-expert token
                # counting and the prefix sum over those counts (with one extra
                # slot per expert for the leading zero).
                token_cnts_buffer = torch.empty(
                    (num_experts + 1) * num_experts,
                    dtype=torch.int32,
                    device=topk_ids.device,
                )
                cumsum_buffer = torch.empty(
                    num_experts + 1, dtype=torch.int32, device=topk_ids.device
                )

                try:
                    moe_align_block_size(
                        topk_ids,
                        num_experts,
                        block_size,
                        sorted_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        token_cnts_buffer,
                        cumsum_buffer,
                    )
                except Exception as e:
                    print(
                        f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
                    )
                    print(f"Error message: {e}")
                    raise
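
                # Light sanity check on the outputs. This assumes the
                # vLLM-style contract described above (an assumption, not
                # stated in this file): the padded total is a multiple of
                # block_size, and every real token index appears exactly once
                # among the non-sentinel entries of sorted_ids.
                total_padded = num_tokens_post_pad.item()
                assert total_padded % block_size == 0
                real_ids = sorted_ids[:total_padded]
                real_ids = real_ids[real_ids < topk_ids.numel()]
                assert real_ids.numel() == topk_ids.numel()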


if __name__ == "__main__":
    test_moe_align_block_size()