Support FP4 gemm (1/2) (#3899)

This commit is contained in:
Trevor Morris
2025-03-24 19:50:23 -07:00
committed by GitHub
parent 22c3702e1e
commit e9f8e42318
11 changed files with 1245 additions and 5 deletions

View File

@@ -114,6 +114,17 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
m.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
m.def(
"scaled_fp4_quant(Tensor! output, Tensor! input,"
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
/*
* From csrc/moe
*/