Add fp8 gemm kernel for CPU in sgl-kernel and add gemm UT (#6216)
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
@@ -94,6 +94,16 @@ at::Tensor int8_scaled_mm_cpu(
     at::ScalarType out_dtype,
     bool is_vnni);
 
+// fp8 gemm
+at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
 // quant + igemm
 at::Tensor int8_scaled_mm_with_quant(
     at::Tensor& mat1,
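The new declaration points to a blockwise-quantized fp8 weight path: mat1 is the activation, mat2 the fp8 weight, scales2 the per-block dequantization scales, and block_size the quantization block shape. Below is a minimal call-site sketch in C++; the bf16 activation dtype, the e4m3 weight format, the [128, 128] block size, and the scale layout are illustrative assumptions, not taken from this diff, and the snippet links against the sgl-kernel extension that defines the kernel.

#include <ATen/ATen.h>
#include <optional>
#include <vector>

// Declaration added by this PR; the definition lives in sgl-kernel.
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

int main() {
  const int64_t M = 8, K = 256, N = 128;
  // Assumed: bf16 activations, weight quantized to fp8 e4m3
  // (at::kFloat8_e4m3fn requires a PyTorch build with fp8 support).
  at::Tensor mat1 = at::randn({M, K}, at::kBFloat16);
  at::Tensor mat2 = at::randn({N, K}).to(at::kFloat8_e4m3fn);
  // Assumed scale layout: one float scale per [128, 128] weight block.
  at::Tensor scales2 = at::ones({N / 128, K / 128}, at::kFloat);
  std::optional<at::Tensor> bias = std::nullopt;
  at::Tensor out = fp8_scaled_mm_cpu(
      mat1, mat2, scales2, /*block_size=*/{128, 128}, bias,
      /*out_dtype=*/at::kBFloat16, /*is_vnni=*/false);
  return 0;
}

Passing is_vnni as false here assumes the kernel packs the weight internally when needed; the flag itself mirrors the existing int8_scaled_mm_cpu signature shown above.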
@@ -198,6 +208,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // igemm
   m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
 
+  // fp8 gemm
+  m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
+
   // quant + igemm
   m.def(
       "int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
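The PR title also adds a gemm UT. One way such a test can build its reference is to dequantize the fp8 weight with its block scales and run the matmul in float; the sketch below assumes an [N, K] weight and an [N/bn, K/bk] scale grid, and the name ref_fp8_scaled_mm and the bn/bk parameters are hypothetical, not from the PR.

#include <ATen/ATen.h>

// Hypothetical reference path: dequantize a blockwise-scaled fp8 weight
// and run the matmul in float for comparison with the fused kernel output.
at::Tensor ref_fp8_scaled_mm(
    const at::Tensor& mat1,     // [M, K] activation
    const at::Tensor& mat2,     // [N, K] fp8 weight
    const at::Tensor& scales2,  // [N / bn, K / bk] per-block scales
    int64_t bn, int64_t bk) {
  at::Tensor w = mat2.to(at::kFloat);
  // Broadcast each block scale over its [bn, bk] weight block.
  at::Tensor s = scales2
      .repeat_interleave(bn, /*dim=*/0)
      .repeat_interleave(bk, /*dim=*/1);
  return at::matmul(mat1.to(at::kFloat), (w * s).t());
}

A test would then compare this reference against fp8_scaled_mm_cpu under a relatively loose tolerance, since fp8 rounding error dominates kernel-level differences.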