Support one-shot allreduce on 1-node and 2-node setups using mscclpp (#6277)
This commit is contained in:
@@ -73,6 +73,14 @@ FetchContent_Declare(
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
# Populate the flash-attention content declared by the preceding
# FetchContent_Declare call.
FetchContent_Populate(repo-flash-attention)
|
||||
# mscclpp: fetched from GitHub, pinned to an exact commit, with shallow
# cloning turned off.
FetchContent_Declare(
  repo-mscclpp
  GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
  GIT_TAG        51eca89d20f0cfb3764ccd764338d7b22cd486a6
  GIT_SHALLOW    OFF
)
# Populate only: this call does not add mscclpp to the build; the
# add_subdirectory() call later in this file does.
FetchContent_Populate(repo-mscclpp)
|
||||
|
||||
# ccache option: ENABLE_CCACHE is a boolean build option that defaults to ON.
|
||||
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
||||
@@ -99,6 +107,7 @@ include_directories(
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
# Remove -D__CUDA_NO_HALF2_OPERATORS__ from CMAKE_CUDA_FLAGS if it is present.
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
)
|
||||
# NOTE(review): a later target_link_libraries(common_ops ...) call in this file
# lists these same libraries plus mscclpp_static, so this call looks redundant
# (possibly the pre-change line of the diff) — confirm before removing.
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
# Ask the Python interpreter which value of _GLIBCXX_USE_CXX11_ABI torch reports
# and append the matching definition to the C++ and CUDA flag variables.
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
  RESULT_VARIABLE TORCH_CXX11_ABI_RESULT
  OUTPUT_VARIABLE TORCH_CXX11_ABI
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

# If the probe fails (e.g. torch cannot be imported) or prints something other
# than 0/1, the branch below falls through to the new-ABI case. Keep that
# fallback, but make it visible instead of silent.
if(NOT TORCH_CXX11_ABI_RESULT EQUAL 0 OR NOT TORCH_CXX11_ABI MATCHES "^[01]$")
  message(WARNING
    "Could not query torch._C._GLIBCXX_USE_CXX11_ABI using ${Python3_EXECUTABLE} "
    "(result: '${TORCH_CXX11_ABI_RESULT}', output: '${TORCH_CXX11_ABI}'); "
    "falling back to -D_GLIBCXX_USE_CXX11_ABI=1.")
endif()

if(TORCH_CXX11_ABI STREQUAL "0")
  message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
  string(APPEND CMAKE_CXX_FLAGS " -D_GLIBCXX_USE_CXX11_ABI=0")
  string(APPEND CMAKE_CUDA_FLAGS " -D_GLIBCXX_USE_CXX11_ABI=0")
else()
  message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
  string(APPEND CMAKE_CXX_FLAGS " -D_GLIBCXX_USE_CXX11_ABI=1")
  string(APPEND CMAKE_CUDA_FLAGS " -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
|
||||
# Variables set ahead of adding the mscclpp subproject.
# NOTE(review): presumably these select the CUDA backend, skip mscclpp's GPU
# detection and disable its tests — confirm against mscclpp's CMakeLists.txt.
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)

# The populated sources are not guaranteed to sit below the current source
# directory, and add_subdirectory() requires an explicit binary directory in
# that case. Use the one FetchContent_Populate(repo-mscclpp) provides.
add_subdirectory("${repo-mscclpp_SOURCE_DIR}" "${repo-mscclpp_BINARY_DIR}")

# Link common_ops against torch, the CUDA libraries and the static mscclpp library.
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
|
||||
Reference in New Issue
Block a user