[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
PGFLMG
2025-04-18 03:43:23 +08:00
committed by GitHub
parent f13d65a7ea
commit c08a717c77
8 changed files with 393 additions and 133 deletions

View File

@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/elementwise
*/
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
m.def(
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, int cuda_stream) -> ()");
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
m.def(
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
"cuda_stream) -> ()");
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
m.def(
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
"cuda_stream) -> ()");
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
m.def(
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
"cuda_stream) -> ()");
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
m.def(
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
/*