diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index 129402112..a342dee10 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -16,7 +16,9 @@ limitations under the License. #pragma once #include +#ifndef USE_ROCM #include +#endif #include #include @@ -63,6 +65,7 @@ inline int getSMVersion() { return sm_major * 10 + sm_minor; } +#ifndef USE_ROCM #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ @@ -79,6 +82,7 @@ inline int getSMVersion() { return false; \ } \ }() +#endif #define DISPATCH_CASE_INTEGRAL_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \