Add dsv3 fused a gemm to sgl-kernel (#7630)
This commit is contained in:
@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
" Tensor! output_scale, Tensor! input_scale) -> ()");
|
||||
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
m.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
|
||||
Reference in New Issue
Block a user