Files
sglang/sgl-kernel/tests/test_moe_align.py
2025-01-13 23:09:00 +08:00

68 lines
2.5 KiB
Python

import torch
from sgl_kernel import moe_align_block_size
def test_moe_align_block_size():
    """Smoke-test moe_align_block_size over block sizes, token counts, and topk.

    For each configuration this allocates the output/workspace buffers with
    the sizes the kernel expects and invokes the kernel on random expert
    assignments. Any failure is reported with the exact configuration that
    triggered it, then re-raised so the test still fails loudly.
    """
    # For DeepSeek V3, we have 256 experts
    num_experts = 256
    # Test different combinations of block_size, num_tokens and topk
    for block_size in [32, 64, 128, 256]:
        print(f"\nTesting block_size={block_size}")
        for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            for topk in [1, 2, 4, 8, 16, 32, 64]:
                print(
                    f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
                )
                # Create random topk_ids with shape [num_tokens, topk]
                topk_ids = torch.randint(
                    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
                )
                # Worst case: every expert's token count is padded up to a
                # multiple of block_size, costing at most (block_size - 1)
                # extra slots per expert.
                max_num_tokens_padded = topk_ids.numel() + num_experts * (
                    block_size - 1
                )
                sorted_ids = torch.empty(
                    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
                )
                # topk_ids.numel() is one past the largest valid flattened token
                # index, so it serves as the padding sentinel.
                sorted_ids.fill_(topk_ids.numel())
                max_num_m_blocks = max_num_tokens_padded // block_size
                expert_ids = torch.empty(
                    (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
                )
                # Scalar output: total token count after padding.
                num_tokens_post_pad = torch.empty(
                    (1,), dtype=torch.int32, device=topk_ids.device
                )
                # Kernel workspace buffers; sizes follow the kernel's contract
                # (per-expert counts plus a cumulative-sum scratch row).
                token_cnts_buffer = torch.empty(
                    (num_experts + 1) * num_experts,
                    dtype=torch.int32,
                    device=topk_ids.device,
                )
                cumsum_buffer = torch.empty(
                    num_experts + 1, dtype=torch.int32, device=topk_ids.device
                )
                try:
                    moe_align_block_size(
                        topk_ids,
                        num_experts,
                        block_size,
                        sorted_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        token_cnts_buffer,
                        cumsum_buffer,
                    )
                except Exception as e:
                    print(
                        f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
                    )
                    print(f"Error message: {str(e)}")
                    # Bare raise preserves the original traceback (vs `raise e`).
                    raise
# Allow running this test directly as a script (outside pytest).
if __name__ == "__main__":
    test_moe_align_block_size()