// CUDA toolkit / CUB compatibility header.
//
// NOTE(review): this file reached review collapsed onto a single line with the
// target of every `#include` stripped. In that form `#pragma once` absorbed the
// following directives as extra pragma tokens, and the `// For C10_CUDA_CHECK`
// comment turned everything after it -- the CUDART_VERSION compatibility macros
// and the CubAddOp / CubMaxOp aliases -- into comment text, so none of them were
// actually defined. The layout below restores one directive per line.
//
// TODO(review): about 27 further #include directives lost their targets and
// cannot be reconstructed from the mangled text; restore them from version
// control. Only headers required by names used in this file are listed here.
#pragma once

// Provides CUDART_VERSION, which gates the compatibility blocks below.
#include <cuda_runtime_api.h>
// Provides CUB_VERSION and the legacy cub::Sum / cub::Max functors.
#include <cub/cub.cuh>

#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12050
// TODO(review): a header that is only available on CUDA >= 12.5 was included
// here; its name was lost. Restore it from version control.
#endif  // CUDART_VERSION >= 12050

// Toolkits older than 11.2: map newer CUDA driver / cuBLAS spellings onto the
// older equivalents so call sites can use the new names unconditionally.
// The defined() check matters: an undefined CUDART_VERSION evaluates to 0 in
// #if, which would otherwise force these aliases on for every toolkit.
// NOTE(review): cublasComputeType_t, CUBLAS_COMPUTE_* and
// CUBLAS_TF32_TENSOR_OP_MATH already exist in CUDA 11.0, so on 11.0/11.1 these
// macros shadow the real cuBLAS names -- confirm the 11020 threshold is intended
// for the cuBLAS aliases and not just for the device-attribute rename.
#if defined(CUDART_VERSION) && CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED \
  CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif  // CUDART_VERSION < 11020

// Binary reduction functors for CUB algorithms (sum and max).
// From CUB 2.8 (CUB_VERSION 200800) the libcu++ functors are used; older CUB
// falls back to cub::Sum / cub::Max (presumably because those are deprecated in
// newer CCCL releases -- see the CCCL 2.8 release notes).
#if CUB_VERSION >= 200800
#include <cuda/functional>      // cuda::maximum
#include <cuda/std/functional>  // cuda::std::plus
using CubAddOp = cuda::std::plus<>;
using CubMaxOp = cuda::maximum<>;
#else  // if CUB_VERSION < 200800
using CubAddOp = cub::Sum;
using CubMaxOp = cub::Max;
#endif  // CUB_VERSION