Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)

This commit is contained in:
Ke Bao
2025-06-25 08:44:27 +08:00
committed by GitHub
parent 112b496a6c
commit 57ab776910
7 changed files with 163 additions and 70 deletions

View File

@@ -5,7 +5,11 @@ import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
from vllm import _custom_ops as ops
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
USE_RANDOM_PERM = False
@@ -208,7 +212,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except RuntimeError as e:
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
@@ -257,13 +261,47 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
return topk_ids
def sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
pad_sorted_token_ids=False,
):
if not pad_sorted_token_ids:
sorted_ids.fill_(topk_ids.numel())
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
)
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
token_cnts_buffer,
cumsum_buffer,
pad_sorted_token_ids,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sgl", "triton", "vllm"],
line_names=["SGL", "Triton", "VLLM"],
line_vals=["sgl", "sgl_fusion", "triton"],
line_names=["SGL", "SGL Fusion", "Triton"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="us",
plot_name="moe-align-block-size-performance",
@@ -288,7 +326,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -297,35 +334,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
def sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
):
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
)
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
token_cnts_buffer,
cumsum_buffer,
)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
@@ -337,7 +345,21 @@ def benchmark(num_tokens, num_experts, topk, provider):
),
quantiles=quantiles,
)
elif provider == "sgl_fusion":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
pad_sorted_token_ids=True,
),
quantiles=quantiles,
)
elif provider == "triton":
sorted_ids.fill_(topk_ids.numel())
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
@@ -349,23 +371,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
),
quantiles=quantiles,
)
else: # vllm
try:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
except RuntimeError as e:
print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
# Return extreme values to indicate failure in the chart
return float("inf"), float("inf"), float("inf")
return 1000 * ms, 1000 * max_ms, 1000 * min_ms