[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

This commit is contained in:
Elfie Guo
2025-05-16 13:14:07 -07:00
committed by GitHub
parent 839fb31e5f
commit 6fc9357503
12 changed files with 896 additions and 41 deletions

View File

@@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
c_strides = torch.full(
(num_experts,), c_out.stride(0), device=device, dtype=torch.int64
)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
fp8_blockwise_scaled_grouped_mm(
c_out,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_stack,
b_stack,
a_scale_stack,
@@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
layout_sfb,
problem_sizes,
expert_offsets[:-1],
workspace,
)
for g in range(num_experts):