Sampling penalties memory interface (#2870)
This commit is contained in:
@@ -3,38 +3,65 @@ from sgl_kernel import moe_align_block_size
|
||||
|
||||
|
||||
def test_moe_align_block_size():
|
||||
# For DeepSeek V3, we have 256 experts
|
||||
num_experts = 256
|
||||
block_size = 128
|
||||
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
# Test different combinations of block_size, num_tokens and topk
|
||||
for block_size in [32, 64, 128, 256]:
|
||||
print(f"\nTesting block_size={block_size}")
|
||||
for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
for topk in [1, 2, 4, 8, 16, 32, 64]:
|
||||
print(
|
||||
f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
||||
)
|
||||
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
# Create random topk_ids with shape [num_tokens, topk]
|
||||
topk_ids = torch.randint(
|
||||
0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (
|
||||
block_size - 1
|
||||
)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
try:
|
||||
moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
||||
)
|
||||
print(f"Error message: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
test_moe_align_block_size()
|
||||
if __name__ == "__main__":
|
||||
test_moe_align_block_size()
|
||||
|
||||
Reference in New Issue
Block a user