Remove annoying warnings in sgl kernel build (#9905)
This commit is contained in:
@@ -3,6 +3,7 @@ project(sgl-kernel LANGUAGES CXX CUDA)
|
|||||||
|
|
||||||
# CMake
|
# CMake
|
||||||
cmake_policy(SET CMP0169 OLD)
|
cmake_policy(SET CMP0169 OLD)
|
||||||
|
cmake_policy(SET CMP0177 NEW)
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||||
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
||||||
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
||||||
@@ -50,14 +51,7 @@ FetchContent_Declare(
|
|||||||
)
|
)
|
||||||
FetchContent_Populate(repo-cutlass)
|
FetchContent_Populate(repo-cutlass)
|
||||||
|
|
||||||
FetchContent_Declare(
|
# DeepGEMM
|
||||||
repo-fmt
|
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
|
||||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
|
||||||
GIT_SHALLOW OFF
|
|
||||||
)
|
|
||||||
FetchContent_Populate(repo-fmt)
|
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-deepgemm
|
repo-deepgemm
|
||||||
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
||||||
@@ -66,6 +60,14 @@ FetchContent_Declare(
|
|||||||
)
|
)
|
||||||
FetchContent_Populate(repo-deepgemm)
|
FetchContent_Populate(repo-deepgemm)
|
||||||
|
|
||||||
|
FetchContent_Declare(
|
||||||
|
repo-fmt
|
||||||
|
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||||
|
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||||
|
GIT_SHALLOW OFF
|
||||||
|
)
|
||||||
|
FetchContent_Populate(repo-fmt)
|
||||||
|
|
||||||
# Triton
|
# Triton
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-triton
|
repo-triton
|
||||||
@@ -148,21 +150,40 @@ set(SGL_KERNEL_CUDA_FLAGS
|
|||||||
"--expt-extended-lambda"
|
"--expt-extended-lambda"
|
||||||
"--threads=32"
|
"--threads=32"
|
||||||
|
|
||||||
# Suppress warnings
|
# Suppress warnings
|
||||||
"-Xcompiler=-Wconversion"
|
"-Xcompiler=-Wno-clang-format-violations"
|
||||||
"-Xcompiler=-fno-strict-aliasing"
|
"-Xcompiler=-Wno-conversion"
|
||||||
|
"-Xcompiler=-Wno-deprecated-declarations"
|
||||||
|
"-Xcompiler=-Wno-terminate"
|
||||||
|
"-Xcompiler=-Wfatal-errors"
|
||||||
|
"-Xcompiler=-ftemplate-backtrace-limit=1"
|
||||||
|
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
|
||||||
|
|
||||||
# uncomment to debug
|
# uncomment to debug
|
||||||
# "--ptxas-options=-v"
|
# "--ptxas-options=-v"
|
||||||
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
|
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
|
||||||
)
|
)
|
||||||
|
|
||||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
|
||||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
|
||||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||||
|
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||||
|
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||||
|
|
||||||
|
if (SGL_KERNEL_ENABLE_BF16)
|
||||||
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
|
"-DFLASHINFER_ENABLE_BF16"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (SGL_KERNEL_ENABLE_FP8)
|
||||||
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
|
"-DFLASHINFER_ENABLE_FP8"
|
||||||
|
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
||||||
|
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if (ENABLE_BELOW_SM90)
|
if (ENABLE_BELOW_SM90)
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
@@ -210,31 +231,12 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
|||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (SGL_KERNEL_ENABLE_BF16)
|
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
|
||||||
"-DFLASHINFER_ENABLE_BF16"
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (SGL_KERNEL_ENABLE_FP8)
|
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
|
||||||
"-DFLASHINFER_ENABLE_FP8"
|
|
||||||
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
|
||||||
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
|
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
"-DENABLE_NVFP4=1"
|
"-DENABLE_NVFP4=1"
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
||||||
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
||||||
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
||||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
"csrc/allreduce/custom_all_reduce.cu"
|
"csrc/allreduce/custom_all_reduce.cu"
|
||||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||||
|
|||||||
@@ -21,12 +21,11 @@ submodule: ## Initialize and update git submodules
|
|||||||
ln: submodule ## Create compilation database
|
ln: submodule ## Create compilation database
|
||||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||||
|
|
||||||
|
|
||||||
install: submodule ## Install package in development mode
|
install: submodule ## Install package in development mode
|
||||||
@pip install -e . --no-build-isolation
|
@pip install -e . --no-build-isolation
|
||||||
|
|
||||||
build: install-deps submodule ## Build and install wheel package
|
build: install-deps submodule ## Build and install wheel package
|
||||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
@rm -rf dist/* || true && CMAKE_POLICY_VERSION_MINIMUM=3.5 MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||||
|
|
||||||
clean: ## Remove build artifacts
|
clean: ## Remove build artifacts
|
||||||
@rm -rf build dist *.egg-info
|
@rm -rf build dist *.egg-info
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
// TODO(trevor-m): Change split_kv back to -1 when
|
// TODO(trevor-m): Change split_kv back to -1 when
|
||||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||||
// perform worse with larger context length and smaller batch sizes.
|
// perform worse with larger context length and smaller batch sizes.
|
||||||
num_kv_splits, // split_kv
|
static_cast<int>(num_kv_splits), // split_kv
|
||||||
nullptr, // is_var_split_kv
|
nullptr, // is_var_split_kv
|
||||||
};
|
};
|
||||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||||
@@ -259,7 +259,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
|
|||||||
// Assumes device 0 when getting sm_count.
|
// Assumes device 0 when getting sm_count.
|
||||||
arguments.hw_info.sm_count =
|
arguments.hw_info.sm_count =
|
||||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||||
arguments.split_kv = num_kv_splits;
|
arguments.split_kv = static_cast<int>(num_kv_splits);
|
||||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||||
|
|
||||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ __device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
|
|||||||
: "r"(smem_int_ptr), "r"(phase_bit));
|
: "r"(smem_int_ptr), "r"(phase_bit));
|
||||||
return static_cast<bool>(wait_complete);
|
return static_cast<bool>(wait_complete);
|
||||||
#endif
|
#endif
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Barrier arrive
|
// Barrier arrive
|
||||||
|
|||||||
@@ -541,6 +541,11 @@ void quant_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Avoid redefinition warnings
|
||||||
|
#undef CHECK_CONTIGUOUS
|
||||||
|
#undef CHECK_TH_CUDA
|
||||||
|
#undef CHECK_INPUT
|
||||||
|
|
||||||
/*Quantization entry for fp4 experts quantization*/
|
/*Quantization entry for fp4 experts quantization*/
|
||||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||||
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||||
|
|||||||
Reference in New Issue
Block a user