Support FP4 gemm (1/2) (#3899)
This commit is contained in:
@@ -114,6 +114,17 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
|
||||
|
||||
// Declare the operator schema for cutlass_scaled_fp4_mm, then register its
// implementation under the torch::kCUDA dispatch key.
// Per the schema string, `out` is the only argument annotated as mutable
// (`Tensor!`) and the op returns nothing (`-> ()`).
m.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// Declare the operator schema for scaled_fp4_quant, then register its
// implementation under the torch::kCUDA dispatch key.
// Per the schema string, all four tensor arguments are annotated as mutable
// (`Tensor!`) and the op returns nothing (`-> ()`).
// NOTE(review): by name, `input` and `input_scale` look like read-only
// inputs -- confirm whether the kernel actually writes to them; if it does
// not, dropping the `!` on those two would make the alias annotation accurate.
m.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor! input,"
|
||||
" Tensor! output_scale, Tensor! input_scale) -> ()");
|
||||
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user