Add CUTLASS FP8 Blockscale MoE kernel for Hopper architecture (#7278)
Co-authored-by: HydraQYH <QYH820@Outlook.com> Co-authored-by: TianQiLin666666 <1834987979@qq.com>
This commit is contained in:
330
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
330
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import deep_gemm
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||||
|
|
||||||
|
|
||||||
|
def get_m_alignment_for_contiguous_layout():
    """Return the row-count alignment used for the contiguous grouped layout.

    Returns:
        The constant 128.
    """
    m_alignment = 128
    return m_alignment
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_div(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up (for a positive ``y``).

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        ``(x + y - 1) // y``.
    """
    bumped = x + y - 1
    return bumped // y
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2
|
||||||
|
m, n = x.shape
|
||||||
|
pad_size = (128 - (n % 128)) % 128
|
||||||
|
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||||
|
x_view = x.view(m, -1, 128)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||||
|
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||||
|
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to float8_e4m3fn with one scale per 128x128 block.

    Args:
        x: 2-D tensor of shape ``(m, n)``.

    Returns:
        A pair ``(quantized, scales)``. ``quantized`` is a contiguous
        ``(m, n)`` tensor of dtype ``torch.float8_e4m3fn``; ``scales`` has
        shape ``(ceil_div(m, 128), ceil_div(n, 128))``.
    """
    assert x.dim() == 2
    rows, cols = x.shape

    # Copy the input into a zero-filled buffer whose sides are multiples of 128.
    padded_rows = ceil_div(rows, 128) * 128
    padded_cols = ceil_div(cols, 128) * 128
    padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # View as (row_block, 128, col_block, 128) so dims 1 and 3 span one block.
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    # Scale each block so that its largest magnitude maps to 448.0.
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))

    # Strip the padding and return a dense copy of the quantized data.
    return quantized.view_as(padded)[:rows, :cols].contiguous(), scales
|
||||||
|
|
||||||
|
|
||||||
|
def construct_contiguous_grouped(
    num_groups: int, expected_m_per_group: int, k: int, n: int
) -> Tuple[
    int,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
    torch.Tensor,
]:
    """Build random FP8 inputs for a contiguous-layout grouped GEMM.

    Every group gets ``expected_m_per_group`` rows, padded up to the alignment
    returned by ``get_m_alignment_for_contiguous_layout()``.

    Args:
        num_groups: Number of groups.
        expected_m_per_group: Number of real rows per group.
        k: Reduction dimension.
        n: Output columns per group.

    Returns:
        ``(m, x_fp8, y_fp8, m_indices, out)`` where ``m`` is the total padded
        row count, ``x_fp8`` and ``y_fp8`` are ``(data, scales)`` pairs,
        ``m_indices`` holds the group index of each row (``-1`` for padding
        rows) and ``out`` is an uninitialised ``(m, n)`` bfloat16 buffer.
    """
    alignment = get_m_alignment_for_contiguous_layout()
    group_ms = [int(expected_m_per_group) for _ in range(num_groups)]
    m = sum(ceil_div(group_m, alignment) * alignment for group_m in group_ms)

    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
    m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    # Tag each row with its group index; rows added only for alignment get -1.
    start = 0
    for i, group_m in enumerate(group_ms):
        actual_end = start + group_m
        aligned_end = start + ceil_div(group_m, alignment) * alignment
        m_indices[start:actual_end] = i
        m_indices[actual_end:aligned_end] = -1
        start = aligned_end

    assert m % 4 == 0, f"TMA alignment error: {m}"

    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (
        torch.empty_like(y, dtype=torch.float8_e4m3fn),
        # per_block_cast_to_fp8 returns ceil_div(k, 128) scale columns, so the
        # buffer must be sized the same way (k // 128 is too small whenever k
        # is not a multiple of 128).
        torch.empty(
            (num_groups, ceil_div(n, 128), ceil_div(k, 128)),
            device="cuda",
            dtype=torch.float,
        ),
    )
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    return m, x_fp8, y_fp8, m_indices, out
|
||||||
|
|
||||||
|
|
||||||
|
def bench_deepgemm(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time DeepGEMM's contiguous-layout grouped FP8 GEMM.

    Args:
        expected_m_per_group: Number of rows per group before alignment.
        n: Output columns per group.
        k: Reduction dimension.
        num_groups: Number of groups.
        num_warmup: Untimed warm-up launches.
        num_run: Timed launches; the reported latency is their average.

    Returns:
        ``(avg_us, m)``: average latency per launch in microseconds and the
        total (padded) row count produced by ``construct_contiguous_grouped``.
    """
    # construct tensors
    m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped(
        num_groups, expected_m_per_group, k, n
    )

    def run_deepgemm():
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            x_fp8, y_fp8, out, m_indices
        )

    # warmup
    for _ in range(num_warmup):
        run_deepgemm()
    torch.cuda.synchronize()

    # run: bracket all timed launches with a single pair of CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_deepgemm()
    end_event.record()
    # Wait for all queued work (including end_event) before reading the timer.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert the per-run average to us.
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us

    return avg, m
|
||||||
|
|
||||||
|
|
||||||
|
def bench_cutlass(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Time sgl-kernel's ``fp8_blockwise_scaled_grouped_mm``.

    Args:
        expected_m_per_group: Number of rows per group before alignment.
        n: Output columns per group (rounded up to a multiple of 16).
        k: Reduction dimension (rounded up to a multiple of 16).
        num_groups: Number of groups (experts).
        num_warmup: Untimed warm-up launches.
        num_run: Timed launches; the reported latency is their average.

    Returns:
        ``(avg_us, m)``: average latency per launch in microseconds and the
        total row count over all groups, as a Python ``int``.
    """
    device = "cuda"
    alignment = 16
    n_g = ceil_div(n, alignment) * alignment
    k_g = ceil_div(k, alignment) * alignment
    out_dtype = torch.bfloat16

    # TODO(@TianQiLin666666): Unique group_ms in all bench function
    group_ms = [
        alignment * ceil_div(int(expected_m_per_group), alignment)
        for _ in range(num_groups)
    ]

    # Accumulate row offsets on the host. This keeps sizes and slice bounds as
    # plain ints (no implicit device->host reads) and lets the function return
    # an int as annotated.
    offsets = [0]
    for m_g in group_ms:
        offsets.append(offsets[-1] + m_g)
    total_m = offsets[-1]

    expert_offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    problem_sizes = torch.tensor(
        [[m_g, n_g, k_g] for m_g in group_ms], device=device, dtype=torch.int32
    ).reshape(num_groups, 3)
    layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)

    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    for m_g in group_ms:
        a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
        b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)

    a_stack = torch.empty((total_m, k_g), device=device, dtype=torch.float8_e4m3fn)
    b_stack = torch.empty(
        (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )
    a_scale_stack = torch.empty(
        (total_m, k_g // 128), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )

    for g in range(num_groups):
        rows = slice(offsets[g], offsets[g + 1])
        a_stack[rows] = a_tensors[g]
        a_scale_stack[rows] = a_scales_tensors[g]
        # b_g / b_scale were produced from the transposed weight; transpose
        # back so they fit the (n, k)-ordered storage buffers.
        b_stack[g] = b_tensors[g].t()
        b_scale_stack[g] = b_scales_tensors[g].t()
    # Hand the kernel (group, k, n)-shaped views over the (group, n, k) storage.
    b_stack = b_stack.transpose(1, 2)
    b_scale_stack = b_scale_stack.transpose(1, 2)

    c_out = torch.empty((total_m, n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_groups,), c_out.stride(0), device=device, dtype=torch.int64
    )
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)

    def run_cutlass():
        fp8_blockwise_scaled_grouped_mm(
            c_out,
            a_ptrs,
            b_ptrs,
            out_ptrs,
            a_scales_ptrs,
            b_scales_ptrs,
            a_stack,
            b_stack,
            a_scale_stack,
            b_scale_stack,
            a_strides,
            # The same stride tensor is passed for the second stride argument;
            # both a_stack and the transposed b_stack have a stride of k_g.
            a_strides,
            c_strides,
            layout_sfa,
            layout_sfb,
            problem_sizes,
            expert_offsets[:-1],
            workspace,
        )

    # warmup
    for _ in range(num_warmup):
        run_cutlass()
    torch.cuda.synchronize()

    # run: bracket all timed launches with a single pair of CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_run):
        run_cutlass()
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert the per-run average to us.
    avg = start_event.elapsed_time(end_event) / num_run * 1000  # us

    return avg, total_m
|
||||||
|
|
||||||
|
|
||||||
|
def bench_sglang_triton(
    expected_m_per_group: int,
    n: int,
    k: int,
    num_groups: int,
    num_warmup: int,
    num_run: int,
) -> Tuple[float, int]:
    """Placeholder for a Triton benchmark; not implemented.

    Takes the same parameters as the other ``bench_*`` functions but does no
    work and currently returns ``None``.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# Benchmark entry points keyed by the label printed next to each result.
benchmark_kernels = dict(
    deepgemm=bench_deepgemm,
    cutlass=bench_cutlass,
    # triton=bench_sglang_triton,
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ShapeArg:
    """One grouped-GEMM problem shape to benchmark."""

    # Rows per group before any alignment padding.
    expected_m_per_group: int
    # Output columns per group.
    n: int
    # Reduction dimension.
    k: int
    # Number of groups.
    num_groups: int
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_one_shape(
    shape_args: List[ShapeArg],
    num_warmup: int,
    num_run: int,
):
    """Run every registered kernel on every shape and print the latencies.

    Args:
        shape_args: Shapes to benchmark, processed in order.
        num_warmup: Warm-up launch count forwarded to each kernel benchmark.
        num_run: Timed launch count forwarded to each kernel benchmark.
    """
    for shape in shape_args:
        banner = f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}"
        print(banner)

        # Every kernel benchmark takes the same positional arguments.
        call_args = (
            shape.expected_m_per_group,
            shape.n,
            shape.k,
            shape.num_groups,
            num_warmup,
            num_run,
        )
        for kernel_name, kernel_func in benchmark_kernels.items():
            # The second returned value (total row count) is not reported.
            average_time, _ = kernel_func(*call_args)
            print(f"{kernel_name}: {average_time} us")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse ``--num-warmup`` / ``--num-run`` and benchmark the built-in shapes."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-warmup", type=int, default=3)
    parser.add_argument("--num-run", type=int, default=10)
    args = parser.parse_args()

    # Each row is (expected_m_per_group, n, k, num_groups).
    shape_table = [
        # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
        (128, 512, 7168, 256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
        (256, 512, 7168, 256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
        (256, 256, 7168, 256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
        (512, 256, 7168, 256),
        # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
        (1, 512, 7168, 256),
        # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
        (2, 256, 7168, 256),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
        (256, 4096, 7168, 32),
        # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
        (512, 4096, 7168, 16),
        # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
        (4, 4096, 7168, 32),
        # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
        (8, 4096, 7168, 16),
        # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
        (1024, 768, 4096, 128),
        # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
        (1024, 4096, 384, 128),
        # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
        (16, 768, 4096, 128),
        # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
        (16, 4096, 384, 128),
    ]
    shape_args = [ShapeArg(*row) for row in shape_table]

    benchmark_one_shape(shape_args, args.num_warmup, args.num_run)


if __name__ == "__main__":
    main()
|
||||||
245
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Executable file → Normal file
245
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Executable file → Normal file
@@ -30,6 +30,126 @@
|
|||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||||
|
|
||||||
|
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||||
|
void launch_sm90_fp8_blockwise_scaled_group_mm(
|
||||||
|
torch::Tensor& out_ptrs,
|
||||||
|
const torch::Tensor& a_ptrs,
|
||||||
|
const torch::Tensor& b_ptrs,
|
||||||
|
const torch::Tensor& a_scales_ptrs,
|
||||||
|
const torch::Tensor& b_scales_ptrs,
|
||||||
|
const torch::Tensor& stride_a,
|
||||||
|
const torch::Tensor& stride_b,
|
||||||
|
const torch::Tensor& stride_c,
|
||||||
|
const torch::Tensor& layout_sfa,
|
||||||
|
const torch::Tensor& layout_sfb,
|
||||||
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& workspace) {
|
||||||
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
|
using ElementB = cutlass::float_e4m3_t;
|
||||||
|
using ElementC = void;
|
||||||
|
using ElementD = OutType;
|
||||||
|
using ElementAccumulator = float;
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = LayoutD;
|
||||||
|
|
||||||
|
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||||
|
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||||
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
|
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
ArchTag,
|
||||||
|
OperatorClass,
|
||||||
|
typename ScheduleConfig::MmaTileShape,
|
||||||
|
typename ScheduleConfig::ClusterShape,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementC, // Use void to avoid load Matrix C
|
||||||
|
LayoutC*,
|
||||||
|
AlignmentC,
|
||||||
|
ElementD,
|
||||||
|
LayoutC*,
|
||||||
|
AlignmentC,
|
||||||
|
typename ScheduleConfig::EpilogueSchedule,
|
||||||
|
FusionOperation>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
ArchTag,
|
||||||
|
OperatorClass,
|
||||||
|
ElementA,
|
||||||
|
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||||
|
AlignmentA,
|
||||||
|
ElementB,
|
||||||
|
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||||
|
AlignmentB,
|
||||||
|
ElementAccumulator,
|
||||||
|
typename ScheduleConfig::MmaTileShape,
|
||||||
|
typename ScheduleConfig::ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||||
|
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||||
|
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||||
|
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||||
|
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||||
|
|
||||||
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
|
Gemm gemm_op;
|
||||||
|
|
||||||
|
typename GemmKernel::MainloopArguments mainloop_args{
|
||||||
|
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||||
|
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||||
|
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(layout_sfa.data_ptr()),
|
||||||
|
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};
|
||||||
|
|
||||||
|
cutlass::KernelHardwareInfo hw_info;
|
||||||
|
hw_info.device_id = c10::cuda::current_device();
|
||||||
|
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||||
|
|
||||||
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
|
{},
|
||||||
|
nullptr,
|
||||||
|
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||||
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||||
|
|
||||||
|
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||||
|
typename GemmKernel::Arguments args{
|
||||||
|
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||||
|
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||||
|
mainloop_args,
|
||||||
|
epilogue_args,
|
||||||
|
hw_info};
|
||||||
|
|
||||||
|
at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||||
|
|
||||||
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
||||||
|
|
||||||
|
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||||
|
|
||||||
|
status = gemm_op.run(stream);
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
|
}
|
||||||
|
|
||||||
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||||
void launch_sm100_fp8_blockwise_scaled_group_mm(
|
void launch_sm100_fp8_blockwise_scaled_group_mm(
|
||||||
torch::Tensor& out_ptrs,
|
torch::Tensor& out_ptrs,
|
||||||
@@ -312,6 +432,74 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||||
|
torch::Tensor& output,
|
||||||
|
torch::Tensor& a_ptrs,
|
||||||
|
torch::Tensor& b_ptrs,
|
||||||
|
torch::Tensor& out_ptrs,
|
||||||
|
torch::Tensor& a_scales_ptrs,
|
||||||
|
torch::Tensor& b_scales_ptrs,
|
||||||
|
const torch::Tensor& a,
|
||||||
|
const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a,
|
||||||
|
const torch::Tensor& scales_b,
|
||||||
|
const torch::Tensor& stride_a,
|
||||||
|
const torch::Tensor& stride_b,
|
||||||
|
const torch::Tensor& stride_c,
|
||||||
|
const torch::Tensor& layout_sfa,
|
||||||
|
const torch::Tensor& layout_sfb,
|
||||||
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& workspace) {
|
||||||
|
struct MmaConfig {
|
||||||
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
|
using MmaTileShape = Shape<_64, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||||
|
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
|
};
|
||||||
|
|
||||||
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
|
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
|
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||||
|
|
||||||
|
run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
|
||||||
|
expert_offsets,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
output,
|
||||||
|
scales_a,
|
||||||
|
scales_b,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
problem_sizes_transpose);
|
||||||
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
|
||||||
|
out_ptrs,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
stride_c,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs,
|
* @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs,
|
||||||
* with per-block scaling.
|
* with per-block scaling.
|
||||||
@@ -397,11 +585,6 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
|
TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
|
||||||
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
|
TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
|
||||||
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
|
TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
|
||||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
|
||||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
|
||||||
TORCH_CHECK(
|
|
||||||
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
|
|
||||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
|
|
||||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||||
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
|
TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
|
||||||
|
|
||||||
@@ -455,7 +638,57 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
can_implement = true;
|
can_implement = true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||||
|
if (sm_version == 90 && a.size(1) > 256) {
|
||||||
|
if (output.scalar_type() == torch::kBFloat16) {
|
||||||
|
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||||
|
output,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scales_a,
|
||||||
|
scales_b,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
stride_c,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
} else {
|
||||||
|
sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
|
||||||
|
output,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scales_a,
|
||||||
|
scales_b,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
stride_c,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
}
|
||||||
|
can_implement = true;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
|
can_implement,
|
||||||
|
"No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ",
|
||||||
|
sm_version,
|
||||||
|
a.size(1));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,9 +53,15 @@ def is_sm100_supported(device=None) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_sm90_supported(device=None) -> bool:
    """Return True when ``device`` is compute capability 9.x and CUDA is >= 12.8.

    The CUDA version is compared numerically as ``(major, minor)``. A plain
    string comparison would order "12.10" before "12.8" and "9.2" after it.

    Args:
        device: Device passed to ``torch.cuda.get_device_capability``;
            ``None`` is forwarded unchanged.

    Returns:
        False if the capability major is not 9, if ``torch.version.cuda`` is
        ``None``, or if the version cannot be parsed; otherwise whether the
        version is at least 12.8.
    """
    if torch.cuda.get_device_capability(device)[0] != 9:
        return False

    cuda_version = torch.version.cuda
    if cuda_version is None:
        return False

    try:
        major, minor = (int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        # Fewer than two components, or a non-numeric component.
        return False
    return (major, minor) >= (12, 8)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not is_sm100_supported(),
|
not (is_sm100_supported() or is_sm90_supported()),
|
||||||
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
|
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||||
@@ -162,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
baseline = baseline_tensors[g]
|
baseline = baseline_tensors[g]
|
||||||
actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
|
actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
|
||||||
torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=5e-4)
|
torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=1e-3)
|
||||||
print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")
|
print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user