2025-10-13 11:19:21 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def es_fp8_blockwise_scaled_grouped_mm(
    output,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_d,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """Dispatch to the ``sgl_kernel`` ``es_fp8_blockwise_scaled_grouped_mm`` op.

    Every argument is forwarded positionally, unchanged and in signature
    order, to ``torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default``.
    Whatever the op returns is discarded, so this wrapper returns ``None``.

    NOTE(review): because nothing is returned, the kernel presumably writes
    its result into ``output`` in place. The expected dtypes, shapes and
    layouts of the tensor arguments are not visible here. Confirm them
    against the op's registered schema.
    """
    # Resolve the registered overload once, then invoke it. The attribute
    # lookup still happens before any argument is evaluated, matching a
    # direct ``torch.ops...default(...)`` call.
    kernel = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    kernel(
        output, a, b,
        scales_a, scales_b,
        stride_a, stride_b, stride_d,
        problem_sizes, expert_offsets,
        workspace,
    )
|