refine sgl_moe_align_block_size_benchmark (#4327)
This commit is contained in:
@@ -4,7 +4,8 @@ import itertools
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel import moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
USE_RANDOM_PERM = False
|
USE_RANDOM_PERM = False
|
||||||
|
|
||||||
@@ -139,15 +140,11 @@ def moe_align_block_size_triton(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def calculate_diff(batch_size, seq_len):
|
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||||
num_experts = 256
|
|
||||||
block_size = 128
|
|
||||||
topk = 8
|
|
||||||
|
|
||||||
topk_ids = torch.stack(
|
topk_ids = torch.stack(
|
||||||
[
|
[
|
||||||
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||||
for _ in range(batch_size * seq_len)
|
for _ in range(num_tokens)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -175,8 +172,13 @@ def calculate_diff(batch_size, seq_len):
|
|||||||
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
||||||
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||||
|
|
||||||
# compare the performance of cuda and triton implementation
|
sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
|
||||||
moe_align_block_size(
|
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||||
|
expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
|
||||||
|
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
|
||||||
|
|
||||||
|
# compare the performance of cuda, triton and vllm implementation
|
||||||
|
sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -194,22 +196,43 @@ def calculate_diff(batch_size, seq_len):
|
|||||||
expert_ids_triton,
|
expert_ids_triton,
|
||||||
num_tokens_post_pad_triton,
|
num_tokens_post_pad_triton,
|
||||||
)
|
)
|
||||||
|
ops.moe_align_block_size(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids_vllm,
|
||||||
|
expert_ids_vllm,
|
||||||
|
num_tokens_post_pad_vllm,
|
||||||
|
)
|
||||||
|
|
||||||
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
||||||
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
|
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
|
||||||
):
|
):
|
||||||
print("✅ CUDA and Triton implementations match")
|
print("✅ SGL and Triton implementations match")
|
||||||
else:
|
else:
|
||||||
print("❌ CUDA and Triton implementations do not match")
|
print("❌ SGL and Triton implementations do not match")
|
||||||
print("CUDA expert_ids:", expert_ids_cuda)
|
print("SGL expert_ids:", expert_ids_cuda)
|
||||||
print("Triton expert_ids:", expert_ids_triton)
|
print("Triton expert_ids:", expert_ids_triton)
|
||||||
print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||||
|
|
||||||
|
if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
|
||||||
|
num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
|
||||||
|
):
|
||||||
|
print("✅ SGL and VLLM implementations match")
|
||||||
|
else:
|
||||||
|
print("❌ SGL and VLLM implementations do not match")
|
||||||
|
print("SGL expert_ids:", expert_ids_cuda)
|
||||||
|
print("VLLM expert_ids:", expert_ids_vllm)
|
||||||
|
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||||
|
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||||
|
|
||||||
batch_size_range = [2**i for i in range(0, 8)]
|
|
||||||
seq_length_range = [2**i for i in range(0, 16)]
|
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
num_experts_range = [32, 64, 128, 256]
|
||||||
|
topk_range = [2, 4, 8]
|
||||||
|
|
||||||
|
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||||
|
|
||||||
|
|
||||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||||
@@ -223,29 +246,27 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
|||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["batch_size", "seq_len"],
|
x_names=["num_tokens", "num_experts", "topk"],
|
||||||
x_vals=[list(_) for _ in configs],
|
x_vals=configs,
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["cuda", "triton"],
|
line_vals=["sgl", "triton", "vllm"],
|
||||||
line_names=["CUDA", "Triton"],
|
line_names=["SGL", "Triton", "VLLM"],
|
||||||
styles=[("blue", "-"), ("red", "-")],
|
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||||
ylabel="us",
|
ylabel="us",
|
||||||
plot_name="moe-align-block-size-performance",
|
plot_name="moe-align-block-size-performance",
|
||||||
args={},
|
args={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(batch_size, seq_len, provider):
|
def benchmark(num_tokens, num_experts, topk, provider):
|
||||||
num_experts = 256
|
|
||||||
block_size = 128
|
block_size = 128
|
||||||
topk = 8
|
|
||||||
|
|
||||||
if USE_RANDOM_PERM:
|
if USE_RANDOM_PERM:
|
||||||
topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk)
|
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||||
else:
|
else:
|
||||||
topk_ids = torch.randint(
|
topk_ids = torch.randint(
|
||||||
0,
|
0,
|
||||||
num_experts,
|
num_experts,
|
||||||
(batch_size * seq_len, topk),
|
(num_tokens, topk),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
@@ -268,9 +289,9 @@ def benchmark(batch_size, seq_len, provider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "cuda":
|
if provider == "sgl":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: moe_align_block_size(
|
lambda: sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -282,7 +303,7 @@ def benchmark(batch_size, seq_len, provider):
|
|||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
else:
|
elif provider == "triton":
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: moe_align_block_size_triton(
|
lambda: moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -294,6 +315,18 @@ def benchmark(batch_size, seq_len, provider):
|
|||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
|
else: # vllm
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: ops.moe_align_block_size(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids.clone(),
|
||||||
|
expert_ids.clone(),
|
||||||
|
num_tokens_post_pad.clone(),
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
@@ -306,8 +339,22 @@ if __name__ == "__main__":
|
|||||||
default="./configs/benchmark_ops/moe_align_blocks/",
|
default="./configs/benchmark_ops/moe_align_blocks/",
|
||||||
help="Path to save moe align benchmark results",
|
help="Path to save moe align benchmark results",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_experts",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
choices=[8, 64, 128, 256],
|
||||||
|
help="Number of experts for benchmark",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
choices=[2, 4, 8],
|
||||||
|
help="Top-k value for benchmark",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
calculate_diff(batch_size=4, seq_len=1024)
|
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||||
|
|
||||||
benchmark.run(print_data=True)
|
benchmark.run(print_data=True)
|
||||||
Reference in New Issue
Block a user