support cmake for sgl-kernel (#4706)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
@@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
@@ -178,9 +178,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
|
||||
Reference in New Issue
Block a user