[Quant Kernel] refactored per token group quant fp8 to support int8 up-to 2x faster (#4396)

This commit is contained in:
Chunan Zeng
2025-03-23 23:44:17 -07:00
committed by GitHub
parent 3980ff1be6
commit 65c24c28f9
8 changed files with 191 additions and 127 deletions

View File

@@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
" float eps, float fp8_min, float fp8_max) -> ()");
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
m.def(
"sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float int8_min, float int8_max) -> ()");
m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);