support cmake for sgl-kernel (#4706)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -15,8 +15,11 @@ limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
@@ -253,23 +256,12 @@ void min_p_sampling_from_probs(
|
||||
double min_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
// top k renorm probs
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
void top_k_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val,
|
||||
int64_t cuda_stream);
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
// Wrapper around top_k_renorm_probs that accepts top_k_val as int64_t.
// Patch here, cause flashinfer uses unsigned int, but torch must use int64_t
// for extension registration, so we narrow at the boundary.
inline void top_k_renorm_probs_wrapper(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    int64_t top_k_val,
    int64_t cuda_stream) {
  // Narrowing cast: assumes top_k_val is non-negative and fits in unsigned int
  // — TODO(review) confirm callers never pass values outside that range.
  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
}
|
||||
int64_t cuda_stream);
|
||||
void top_p_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
|
||||
@@ -15,14 +15,190 @@ limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <cuda_runtime.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <pytorch_extension_utils.h>
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Adapt from FlashInfer
|
||||
// Per-dtype switch cases for dtype dispatch. Each case is compiled in only
// when the matching FLASHINFER_ENABLE_* flag is defined; otherwise the macro
// expands to nothing and that dtype falls through to the dispatcher's
// `default:` error path.
#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
  case at::ScalarType::Half: {          \
    using c_type = nv_half;             \
    return __VA_ARGS__();               \
  }
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
  case at::ScalarType::BFloat16: {       \
    using c_type = nv_bfloat16;          \
    return __VA_ARGS__();                \
  }
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
  case at::ScalarType::Float8_e4m3fn: {      \
    using c_type = __nv_fp8_e4m3;            \
    return __VA_ARGS__();                    \
  }
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
  case at::ScalarType::Float8_e5m2: {        \
    using c_type = __nv_fp8_e5m2;            \
    return __VA_ARGS__();                    \
  }
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
|
||||
|
||||
// Dispatches a runtime torch dtype to a compile-time C type bound as `c_type`,
// then invokes the trailing callback. Returns the callback's bool; any dtype
// without an enabled case raises via TORCH_CHECK. FP16 variant: half/bf16 only.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)                 \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()

// FP8 variant: accepts only the e4m3fn / e5m2 float8 dtypes.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...)                      \
  [&]() -> bool {                                                                            \
    switch (pytorch_dtype) {                                                                 \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                           \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                           \
      default:                                                                               \
        std::ostringstream oss;                                                              \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                       \
        return false;                                                                        \
    }                                                                                        \
  }()

// Combined variant: accepts half, bf16, and both fp8 dtypes.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)                      \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                       \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                       \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()
|
||||
|
||||
// Generic switch-based dispatch: executes the case matching `cond` (cases are
// built with _DISPATCH_CASE below) and returns its bool result; unmatched
// values raise via TORCH_CHECK naming the dispatched variable.
#define _DISPATCH_SWITCH(var_name, cond, ...)                                              \
  [&]() -> bool {                                                                          \
    switch (cond) {                                                                        \
      __VA_ARGS__                                                                          \
      default:                                                                             \
        std::ostringstream oss;                                                            \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond);    \
        TORCH_CHECK(false, oss.str());                                                     \
        return false;                                                                      \
    }                                                                                      \
  }()

// Two-value dispatch: packs (cond1, cond2) into one 32-bit key via pack_u16
// and switches on it; pair cases are built with _DISPATCH_CASE_U16x2.
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...)                                     \
  [&]() -> bool {                                                                                           \
    switch (pack_u16(cond1, cond2)) {                                                                       \
      __VA_ARGS__                                                                                           \
      default:                                                                                              \
        std::ostringstream oss;                                                                             \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) \
            << ", " << int(cond2) << ")";                                                                   \
        TORCH_CHECK(false, oss.str());                                                                      \
        return false;                                                                                       \
    }                                                                                                       \
  }()

// One case of _DISPATCH_SWITCH: rebinds the matched value as constexpr
// `case_var` so the callback can use it as a template/array parameter.
#define _DISPATCH_CASE(case_expr, case_var, ...) \
  case case_expr: {                              \
    constexpr auto case_var = case_expr;         \
    return __VA_ARGS__();                        \
  }

// One case of _DISPATCH_SWITCH_U16x2, keyed on the packed value pair.
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
  case pack_u16(case_expr1, case_expr2): {                                      \
    constexpr auto case_var1 = case_expr1;                                      \
    constexpr auto case_var2 = case_expr2;                                      \
    return __VA_ARGS__();                                                       \
  }

// Lifts the runtime bool `expr` into a constexpr `const_expr` visible to the
// callback, so each truth value gets its own compiled instantiation.
#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool const_expr = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()
|
||||
|
||||
// Verifies that tensors `a` and `b` agree in rank and in every dimension size.
// `a_name`/`b_name` are the caller-side expression strings (supplied by the
// CHECK_SHAPE macro) used to build the failure message; raises via TORCH_CHECK.
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
  }
}
|
||||
|
||||
// Packs two 16-bit values into a single 32-bit key (`a` in the high half,
// `b` in the low half); used by the U16x2 dispatch macros to switch on a pair.
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}
|
||||
|
||||
// Asserts num_qo_heads is an integer multiple of num_kv_heads (GQA requirement).
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
  TORCH_CHECK(                                               \
      num_qo_heads % num_kv_heads == 0,                      \
      "num_qo_heads(",                                       \
      num_qo_heads,                                          \
      ") must be divisible by num_kv_heads(",                \
      num_kv_heads,                                          \
      ")")

// Tensor-precondition checks used at extension entry points.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Only the innermost dimension must be dense (stride 1); outer dims may be strided.
// Note: fixed missing space in the message (was `#x "must be..."`).
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")

#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CUDA(x);                           \
  CHECK_LAST_DIM_CONTIGUOUS(x)

#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

// Rank + per-dimension size equality; delegates to check_shape above.
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// Returns true when `tensor` holds one of the two fp8 dtypes
// (Float8_e4m3fn or Float8_e5m2).
inline bool is_float8_tensor(const at::Tensor& tensor) {
  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
|
||||
#endif
|
||||
|
||||
struct cuda_error : public std::runtime_error {
|
||||
/**
|
||||
* @brief Constructs a `cuda_error` object with the given `message`.
|
||||
|
||||
Reference in New Issue
Block a user