[NVIDA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)

This commit is contained in:
Kaixi Hou
2025-08-22 12:19:45 -07:00
committed by GitHub
parent f556ac8bd8
commit e5638573c1
7 changed files with 420 additions and 13 deletions

View File

@@ -157,6 +157,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor output_scale_offset_by_experts) -> ()");
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
m.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
m.def(
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"