From 195a09f57c2c966d13473e0d99714f3880e3138b Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 30 Mar 2025 12:15:20 -0700 Subject: [PATCH] fix bmm fp8 (#4926) --- sgl-kernel/csrc/torch_extension.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 3633c9f40..263a9d15c 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -82,7 +82,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From FlashInfer */ - m.def("bmm_fp8", bmm_fp8); + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); m.def("top_k_renorm_probs", top_k_renorm_probs); m.def("top_p_renorm_probs", top_p_renorm_probs);