fix moe align blocks benchmark (#3003)
This commit is contained in:
committed by
GitHub
parent
583697cd71
commit
10bfce71b3
@@ -7,6 +7,8 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel import moe_align_block_size
|
from sgl_kernel import moe_align_block_size
|
||||||
|
|
||||||
|
USE_RANDOM_PERM = False
|
||||||
|
|
||||||
|
|
||||||
def ceil_div(a, b):
|
def ceil_div(a, b):
|
||||||
return (a + b - 1) // b
|
return (a + b - 1) // b
|
||||||
@@ -141,8 +143,13 @@ def moe_align_block_size_triton(
|
|||||||
def calculate_diff(batch_size, seq_len):
|
def calculate_diff(batch_size, seq_len):
|
||||||
num_experts = 256
|
num_experts = 256
|
||||||
block_size = 128
|
block_size = 128
|
||||||
topk_ids = torch.randint(
|
topk = 8
|
||||||
0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
|
|
||||||
|
topk_ids = torch.stack(
|
||||||
|
[
|
||||||
|
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||||
|
for _ in range(batch_size * seq_len)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
@@ -169,7 +176,7 @@ def calculate_diff(batch_size, seq_len):
|
|||||||
expert_ids_triton = torch.empty_like(expert_ids_cuda)
|
expert_ids_triton = torch.empty_like(expert_ids_cuda)
|
||||||
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||||
|
|
||||||
# 运行两个实现
|
# compare the performance of cuda and triton implementation
|
||||||
moe_align_block_size(
|
moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -206,6 +213,15 @@ seq_length_range = [2**i for i in range(0, 16)]
|
|||||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
|
||||||
|
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||||
|
topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
|
||||||
|
for i in range(num_tokens):
|
||||||
|
topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
|
||||||
|
:topk
|
||||||
|
]
|
||||||
|
return topk_ids
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["batch_size", "seq_len"],
|
x_names=["batch_size", "seq_len"],
|
||||||
@@ -223,9 +239,17 @@ def benchmark(batch_size, seq_len, provider):
|
|||||||
num_experts = 256
|
num_experts = 256
|
||||||
block_size = 128
|
block_size = 128
|
||||||
topk = 8
|
topk = 8
|
||||||
topk_ids = torch.randint(
|
|
||||||
0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
|
if USE_RANDOM_PERM:
|
||||||
)
|
topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk)
|
||||||
|
else:
|
||||||
|
topk_ids = torch.randint(
|
||||||
|
0,
|
||||||
|
num_experts,
|
||||||
|
(batch_size * seq_len, topk),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
sorted_ids = torch.empty(
|
sorted_ids = torch.empty(
|
||||||
|
|||||||
Reference in New Issue
Block a user