# File: sglang/sgl-kernel/python/sgl_kernel/expert_specialization.py
# (28 lines, 443 B, Python)


import torch
def es_fp8_blockwise_scaled_grouped_mm(
    output, a, b,
    scales_a, scales_b,
    stride_a, stride_b, stride_d,
    problem_sizes, expert_offsets,
):
    """Dispatch to the ``sgl_kernel`` FP8 blockwise-scaled grouped-matmul op.

    Thin Python entry point: every argument is forwarded positionally,
    unchanged and in the declared order, to
    ``torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default``.

    The wrapper itself returns ``None``; anything the underlying op returns
    is discarded.

    NOTE(review): presumably the op writes its result into ``output`` in
    place and the remaining arguments are tensors describing the operands,
    their scales, strides, per-problem sizes and per-expert offsets --
    confirm against the op's registration / kernel source.
    """
    # Resolve the registered overload first, then invoke it positionally.
    grouped_mm_op = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    grouped_mm_op(
        output, a, b,
        scales_a, scales_b,
        stride_a, stride_b, stride_d,
        problem_sizes, expert_offsets,
    )