Add dsv3 fused a gemm to sgl-kernel (#7630)

This commit is contained in:
Ke Bao
2025-06-29 17:52:24 +08:00
committed by GitHub
parent 071a1f51ae
commit 04b35190e2
9 changed files with 800 additions and 0 deletions

View File

@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"