Remove annoying warnings in sgl kernel build (#9905)
This commit is contained in:
@@ -3,6 +3,7 @@ project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# CMake
|
||||
cmake_policy(SET CMP0169 OLD)
|
||||
cmake_policy(SET CMP0177 NEW)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
||||
@@ -50,14 +51,7 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-fmt)
|
||||
|
||||
# DeepGEMM
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
||||
@@ -66,6 +60,14 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-fmt)
|
||||
|
||||
# Triton
|
||||
FetchContent_Declare(
|
||||
repo-triton
|
||||
@@ -148,21 +150,40 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
"--expt-extended-lambda"
|
||||
"--threads=32"
|
||||
|
||||
# Suppress warnings
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
# Supress warnings
|
||||
"-Xcompiler=-Wno-clang-format-violations"
|
||||
"-Xcompiler=-Wno-conversion"
|
||||
"-Xcompiler=-Wno-deprecated-declarations"
|
||||
"-Xcompiler=-Wno-terminate"
|
||||
"-Xcompiler=-Wfatal-errors"
|
||||
"-Xcompiler=-ftemplate-backtrace-limit=1"
|
||||
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
|
||||
|
||||
# uncomment to debug
|
||||
# "--ptxas-options=-v"
|
||||
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
|
||||
)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
|
||||
if (SGL_KERNEL_ENABLE_BF16)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_BF16"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_FP8)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_FP8"
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (ENABLE_BELOW_SM90)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -210,31 +231,12 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_BF16)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_BF16"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_FP8)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_FP8"
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DENABLE_NVFP4=1"
|
||||
)
|
||||
endif()
|
||||
|
||||
string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
|
||||
@@ -21,12 +21,11 @@ submodule: ## Initialize and update git submodules
|
||||
ln: submodule ## Create compilation database
|
||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||
|
||||
|
||||
install: submodule ## Install package in development mode
|
||||
@pip install -e . --no-build-isolation
|
||||
|
||||
build: install-deps submodule ## Build and install wheel package
|
||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
@rm -rf dist/* || true && CMAKE_POLICY_VERSION_MINIMUM=3.5 MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
|
||||
clean: ## Remove build artifacts
|
||||
@rm -rf build dist *.egg-info
|
||||
|
||||
@@ -162,7 +162,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
static_cast<int>(num_kv_splits), // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
@@ -259,7 +259,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
arguments.split_kv = static_cast<int>(num_kv_splits);
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
|
||||
@@ -131,6 +131,7 @@ __device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
|
||||
: "r"(smem_int_ptr), "r"(phase_bit));
|
||||
return static_cast<bool>(wait_complete);
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
// Barrier arrive
|
||||
|
||||
@@ -541,6 +541,11 @@ void quant_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid redefinition warnings
|
||||
#undef CHECK_CONTIGUOUS
|
||||
#undef CHECK_TH_CUDA
|
||||
#undef CHECK_INPUT
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
|
||||
Reference in New Issue
Block a user