Files
sglang/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py

339 lines
11 KiB
Python
Raw Permalink Normal View History

import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import torch
from sgl_kernel import (
es_fp8_blockwise_scaled_grouped_mm,
fp8_blockwise_scaled_grouped_mm,
)
random.seed(28)
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2)
)
def create_unbalanced_expert_token_distribution(max_num_experts):
ratios = [random.random() for _ in range(max_num_experts)]
def convert_to_tokens(ratio: float):
if ratio <= 0.7:
return random.randint(1, 32)
elif ratio > 0.7 and ratio <= 0.85:
return random.randint(32, 64)
elif ratio > 0.85 and ratio <= 0.95:
return random.randint(64, 128)
elif ratio > 0.95:
return random.randint(128, 1024)
else:
return 128
group_ms = [convert_to_tokens(ratio) for ratio in ratios]
return group_ms
group_ms = create_unbalanced_expert_token_distribution(8192)
# group_ms = [128 for _ in range(8192)]
# group_ms = [128 if i % 2 == 0 else 64 for i in range(8192)]
def bench_es(
n: int,
k: int,
num_groups: int,
num_warmup: int,
num_run: int,
) -> Tuple[float, int]:
device = "cuda"
alignment = 128
n_g = ceil_div(n, alignment) * alignment
k_g = ceil_div(k, alignment) * alignment
out_dtype = torch.bfloat16
expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
a_tensors = []
b_tensors = []
a_scales_tensors = []
b_scales_tensors = []
if False:
print("Token Distributtion: ", group_ms[0:num_groups])
print("Token Count: ", sum(group_ms[0:num_groups]))
for g in range(num_groups):
m_g = group_ms[g]
expert_offsets[g + 1] = expert_offsets[g] + m_g
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
a_tensors.append(a_g)
b_tensors.append(b_g)
a_scales_tensors.append(a_scale)
b_scales_tensors.append(b_scale)
a_stack = torch.empty(
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
)
b_stack = torch.empty(
(num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
for g in range(num_groups):
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
b_stack[g] = b_tensors[g].t()
b_stack = b_stack.transpose(1, 2)
a_scale_stack = torch.empty(
(expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
)
b_scale_stack = torch.empty(
(num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
)
for g in range(num_groups):
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
b_scale_stack[g] = b_scales_tensors[g].t()
b_scale_stack = b_scale_stack.transpose(1, 2)
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
a_strides = torch.full(
(num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
)
d_strides = torch.full(
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
def run_cutlass():
es_fp8_blockwise_scaled_grouped_mm(
c_out,
a_stack,
b_stack,
a_scale_stack,
b_scale_stack,
a_strides,
a_strides,
d_strides,
problem_sizes,
expert_offsets[:-1],
workspace,
)
run_cutlass()
# warmup
for _ in range(num_warmup):
run_cutlass()
torch.cuda.synchronize()
# run
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_run):
run_cutlass()
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
return avg, expert_offsets[-1]
def bench_sgl(
n: int,
k: int,
num_groups: int,
num_warmup: int,
num_run: int,
) -> Tuple[float, int]:
device = "cuda"
alignment = 128
n_g = ceil_div(n, alignment) * alignment
k_g = ceil_div(k, alignment) * alignment
out_dtype = torch.bfloat16
expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
a_tensors = []
b_tensors = []
a_scales_tensors = []
b_scales_tensors = []
for g in range(num_groups):
m_g = group_ms[g]
expert_offsets[g + 1] = expert_offsets[g] + m_g
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
a_tensors.append(a_g)
b_tensors.append(b_g)
a_scales_tensors.append(a_scale)
b_scales_tensors.append(b_scale)
a_stack = torch.empty(
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
)
b_stack = torch.empty(
(num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
for g in range(num_groups):
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
b_stack[g] = b_tensors[g].t()
b_stack = b_stack.transpose(1, 2)
a_scale_stack = torch.empty(
(expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
)
b_scale_stack = torch.empty(
(num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
)
for g in range(num_groups):
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
b_scale_stack[g] = b_scales_tensors[g].t()
b_scale_stack = b_scale_stack.transpose(1, 2)
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
a_strides = torch.full(
(num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
)
c_strides = torch.full(
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
def run_cutlass():
fp8_blockwise_scaled_grouped_mm(
c_out,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_stack,
b_stack,
a_scale_stack,
b_scale_stack,
a_strides,
a_strides,
c_strides,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets[:-1],
workspace,
)
# warmup
for _ in range(num_warmup):
run_cutlass()
torch.cuda.synchronize()
# run
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_run):
run_cutlass()
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
return avg, expert_offsets[-1]
benchmark_kernels = {"es": bench_es, "sgl-kernel": bench_sgl}
@dataclass
class ShapeArg:
n: int
k: int
num_groups: int
def benchmark_one_shape(
shape_args: List[ShapeArg],
num_warmup: int,
num_run: int,
):
for shape in shape_args:
print(f"\nBenchmark: n={shape.n}, k={shape.k}, num_groups={shape.num_groups}")
for kernel_name, kernel_func in benchmark_kernels.items():
average_time, m = kernel_func(
shape.n,
shape.k,
shape.num_groups,
num_warmup,
num_run,
)
print(f"{kernel_name}: {average_time} us")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-warmup", type=int, default=3)
parser.add_argument("--num-run", type=int, default=20)
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, down, chunk_size = 4096, TP = 8
ShapeArg(n=7168, k=256, num_groups=256),
# Prefill, Qwen3-235B-A22B-FP8, gateup, TP = 4
ShapeArg(n=768, k=4096, num_groups=128),
# Prefill, Qwen3-235B-A22B-FP8, down, TP = 4
ShapeArg(n=4096, k=384, num_groups=128),
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
ShapeArg(n=4096, k=7168, num_groups=32),
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
ShapeArg(n=4096, k=7168, num_groups=16),
]
args = parser.parse_args()
benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
if __name__ == "__main__":
main()