Revert "[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x)" (#3982)

This commit is contained in:
Chayenne
2025-02-28 23:57:17 -08:00
committed by GitHub
parent 6b859e7ddd
commit 18bb216c28
5 changed files with 94 additions and 381 deletions

View File

@@ -99,12 +99,13 @@ def moe_align_block_size_triton(
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
tokens_cnts: torch.Tensor,
cumsum: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts,)
tokens_cnts = torch.zeros(
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
)
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
@@ -138,18 +139,11 @@ def moe_align_block_size_triton(
)
def calculate_diff(batch_size, seq_len, num_experts):
num_experts = num_experts
def calculate_diff(batch_size, seq_len):
num_experts = 256
block_size = 128
topk = 8
assert batch_size >= 1
assert seq_len >= 1
assert num_experts >= 4
if topk > num_experts:
topk = num_experts
topk_ids = torch.stack(
[
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
@@ -181,13 +175,6 @@ def calculate_diff(batch_size, seq_len, num_experts):
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
token_cnts_buffer_triton = torch.zeros(
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
)
cumsum_buffer_triton = torch.zeros(
(num_experts + 1,), dtype=torch.int32, device=topk_ids.device
)
# compare the performance of cuda and triton implementation
moe_align_block_size(
topk_ids,
@@ -206,27 +193,14 @@ def calculate_diff(batch_size, seq_len, num_experts):
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
token_cnts_buffer_triton,
cumsum_buffer_triton,
)
sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values
sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values
if (
torch.allclose(expert_ids_cuda, expert_ids_triton)
and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot)
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
):
print(
"✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format(
batch_size * seq_len, num_experts
)
)
print("✅ CUDA and Triton implementations match")
else:
print("❌ CUDA and Triton implementations do not match")
print("CUDA sorted ids:", sorted_ids_cuda_snapshot)
print("Triton sorted ids:", sorted_ids_triton_snapshot)
print("CUDA expert_ids:", expert_ids_cuda)
print("Triton expert_ids:", expert_ids_triton)
print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
@@ -282,7 +256,7 @@ def benchmark(batch_size, seq_len, provider):
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.zeros(
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
@@ -293,37 +267,34 @@ def benchmark(batch_size, seq_len, provider):
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
# Warm up
api_func = (
moe_align_block_size if provider == "cuda" else moe_align_block_size_triton
)
for _ in range(10):
api_func(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
token_cnts_buffer.clone(),
cumsum_buffer.clone(),
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
token_cnts_buffer.clone(),
cumsum_buffer.clone(),
),
quantiles=quantiles,
)
if provider == "cuda":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
token_cnts_buffer,
cumsum_buffer,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@@ -335,22 +306,8 @@ if __name__ == "__main__":
default="./configs/benchmark_ops/moe_align_blocks/",
help="Path to save moe align benchmark results",
)
parser.add_argument(
"--verify",
action="store_true",
help="verify kernel",
)
args = parser.parse_args()
if args.verify:
num_experts_range = [2**i for i in range(3, 9)]
calculate_diff(batch_size=4, seq_len=1024)
configs = list(
itertools.product(batch_size_range, seq_length_range, num_experts_range)
)
for bs, seq, num_experts in configs:
calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts)
else:
benchmark.run(print_data=True, save_path=args.save_path)
benchmark.run(print_data=True)