diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index a9deeb9a7..c83cf49ad 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -3,11 +3,8 @@ #pragma once -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" +#include +#include namespace cutlass { namespace epilogue { diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 10be552a8..33e82decc 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -2,16 +2,9 @@ // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h #pragma once -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/numeric_types.h" -#include "cutlass/trace.h" +#include +#include +#include //////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index cf0b9cfa3..674e191a0 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -3,14 +3,11 @@ #pragma once -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" -#include "cutlass/trace.h" -#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" +#include +#include +#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index b9879b114..99d0326cf 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -1,3 +1,5 @@ +#include + #include "utils.hpp" // trt_reduce diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index d80beedec..d647c3496 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -3,9 +3,6 @@ #include #include -#include -#include -#include #include "trt_reduce_internal.cuh"