[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)
This commit is contained in:
@@ -47,6 +47,7 @@ from sgl_kernel.moe import (
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
moe_fused_gate,
|
||||
prepare_moe_input,
|
||||
topk_softmax,
|
||||
)
|
||||
from sgl_kernel.sampling import (
|
||||
|
||||
@@ -64,6 +64,11 @@ def moe_fused_gate(
|
||||
|
||||
def fp8_blockwise_scaled_grouped_mm(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -75,9 +80,15 @@ def fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
):
|
||||
torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
@@ -89,4 +100,29 @@ def fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
)
|
||||
|
||||
|
||||
def prepare_moe_input(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
):
    """Thin Python wrapper over the ``sgl_kernel.prepare_moe_input`` custom op.

    All arguments are forwarded unchanged to the registered torch op,
    which prepares the bookkeeping buffers (expert offsets, per-expert
    problem sizes, and input/output permutations) consumed by the
    grouped FP8 MoE GEMM.

    NOTE(review): the actual semantics of each buffer are defined by the
    C++/CUDA kernel behind the op; the descriptions above are inferred
    from the parameter names — confirm against the kernel source.
    """
    # Collect the arguments once and splat them into the op call; the
    # wrapper adds no logic of its own.
    op_args = (
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
    torch.ops.sgl_kernel.prepare_moe_input.default(*op_args)
|
||||
|
||||
Reference in New Issue
Block a user