support cmake for sgl-kernel (#4706)

Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
Yineng Zhang
2025-03-27 01:42:28 -07:00
committed by GitHub
parent 1b9175cb23
commit 8bf6d7f406
18 changed files with 426 additions and 36 deletions

View File

@@ -15,8 +15,11 @@ limitations under the License.
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <vector>
@@ -253,23 +256,12 @@ void min_p_sampling_from_probs(
double min_p_val,
bool deterministic,
int64_t cuda_stream);
// Top-k renormalization of a probability distribution.
// Declaration of the flashinfer-backed kernel launcher; presumably zeroes out all but the
// top-k entries of `probs` and renormalizes into `renorm_probs` — confirm against the kernel body.
// NOTE: flashinfer's API takes `unsigned int` for the k value, while torch extension
// registration requires `int64_t`; the wrapper below performs that conversion.
void top_k_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    unsigned int top_k_val,
    int64_t cuda_stream);
// Adapter between the torch-extension signature (int64_t k) and flashinfer's
// native signature (unsigned int k): torch extension registration only accepts
// int64_t integer arguments, so the narrowing cast happens here.
inline void top_k_renorm_probs_wrapper(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    int64_t top_k_val,
    int64_t cuda_stream) {
  const auto top_k_native = static_cast<unsigned int>(top_k_val);
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, top_k_native, cuda_stream);
}
int64_t cuda_stream);
void top_p_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,