support cmake for sgl-kernel (#4706)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -15,8 +15,11 @@ limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
@@ -253,23 +256,12 @@ void min_p_sampling_from_probs(
|
||||
double min_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
// top k renorm probs
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
void top_k_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val,
|
||||
int64_t cuda_stream);
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
inline void top_k_renorm_probs_wrapper(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
int64_t top_k_val,
|
||||
int64_t cuda_stream) {
|
||||
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||
}
|
||||
int64_t cuda_stream);
|
||||
void top_p_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
|
||||
Reference in New Issue
Block a user