[sgl-kernel] fix: fix cu118 compile error (#6123)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -25,7 +25,21 @@ limitations under the License.
|
|||||||
#include <device/sm100_mla.hpp>
|
#include <device/sm100_mla.hpp>
|
||||||
#include <kernel/sm100_mla_tile_scheduler.hpp>
|
#include <kernel/sm100_mla_tile_scheduler.hpp>
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
// clang-format off
|
||||||
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||||
|
void cutlass_mla_decode(
|
||||||
|
torch::Tensor const& out,
|
||||||
|
torch::Tensor const& q_nope_and_q_pe,
|
||||||
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
|
torch::Tensor const& seq_lens,
|
||||||
|
torch::Tensor const& page_table,
|
||||||
|
torch::Tensor const& workspace) {
|
||||||
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||||
|
}
|
||||||
|
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
|
||||||
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
#define CUTLASS_CHECK(status) \
|
#define CUTLASS_CHECK(status) \
|
||||||
{ \
|
{ \
|
||||||
@@ -209,3 +223,4 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
// clang-format on
|
||||||
|
|||||||
@@ -24,9 +24,12 @@
|
|||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||||
|
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
|
||||||
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace");
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
#ifndef CUDART_INF_FP16
|
#ifndef CUDART_INF_FP16
|
||||||
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
|
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
|
||||||
@@ -252,3 +255,4 @@ void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optiona
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
// clang-format on
|
||||||
|
|||||||
Reference in New Issue
Block a user