[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)
@@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
    c_strides = torch.full(
        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
    )
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)

    fp8_blockwise_scaled_grouped_mm(
        c_out,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a_stack,
        b_stack,
        a_scale_stack,
@@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
        layout_sfb,
        problem_sizes,
        expert_offsets[:-1],
        workspace,
    )

    for g in range(num_experts):
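For context, the truncated "for g in range(num_experts):" loop above presumably compares the grouped-GEMM output against a plain PyTorch matmul on dequantized copies of each expert's operands. Below is a minimal sketch of such a per-expert reference check; the helper names, the b_scale_stack argument, and the one-scale-per-128x128-block layout are assumptions for illustration, not the repository's actual test code.

import torch

def dequant_blockwise(x_fp8, scales, block=128):
    # Expand per-block scales back to element granularity and apply them
    # (assumed layout: scales[i, j] covers a block x block tile of x_fp8).
    s = scales.repeat_interleave(block, dim=0)[: x_fp8.shape[0]]
    s = s.repeat_interleave(block, dim=1)[:, : x_fp8.shape[1]]
    return x_fp8.to(torch.float32) * s

def check_expert(g, a_stack, b_stack, a_scale_stack, b_scale_stack, c_out):
    # Dequantize the g-th expert's FP8 operands, run a float32 matmul,
    # and compare against the kernel output with loose tolerances.
    a = dequant_blockwise(a_stack[g], a_scale_stack[g])
    b = dequant_blockwise(b_stack[g], b_scale_stack[g])
    ref = a @ b.t()  # assumes B is stored row-major as (N, K)
    torch.testing.assert_close(c_out[g].to(torch.float32), ref, atol=1e-1, rtol=1e-2)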