From 2bd9bd4cc2a5df9aa734c16e0cc187327b6528c4 Mon Sep 17 00:00:00 2001 From: xiezhongtao Date: Tue, 20 Jan 2026 10:14:31 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=BB=9F=E4=B8=80=E7=A1=AC?= =?UTF-8?q?=E4=BB=B6=E7=9B=B8=E5=85=B3=E5=A4=B4=E6=96=87=E4=BB=B6=E5=BC=95?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。 --- CMakeLists.txt | 948 +----------------- cmake/comm.cmake | 23 + cmake/cuda.cmake | 753 ++++++++++++++ cmake/hip.cmake | 146 +++ cmake/musa.cmake | 125 +++ csrc/activation_kernels.cu | 5 +- csrc/attention/attention_kernels.cuh | 6 +- csrc/attention/dtype_bfloat16.cuh | 10 +- csrc/attention/dtype_float16.cuh | 4 +- csrc/attention/dtype_fp8.cuh | 6 +- csrc/attention/merge_attn_states.cu | 4 +- csrc/attention/vertical_slash_index.cu | 4 +- csrc/cache.h | 4 +- csrc/cache_kernels.cu | 13 +- csrc/cub_helpers.h | 18 +- csrc/cuda_utils_kernels.cu | 5 +- csrc/cuda_view.cu | 5 +- csrc/cumem_allocator_compat.h | 4 +- csrc/custom_all_reduce.cu | 6 +- csrc/custom_all_reduce.cuh | 8 +- csrc/custom_all_reduce_test.cu | 4 +- csrc/custom_quickreduce.cu | 6 +- csrc/dispatch_utils.h | 3 +- csrc/fused_qknorm_rope_kernel.cu | 5 +- csrc/launch_bounds_utils.h | 3 +- csrc/layernorm_kernels.cu | 4 +- csrc/layernorm_quant_kernels.cu | 4 +- csrc/mamba/mamba_ssm/selective_scan.h | 7 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 24 +- csrc/moe/dynamic_4bit_int_moe_cpu.cpp | 5 +- csrc/moe/grouped_topk_kernels.cu | 9 +- csrc/moe/moe_align_sum_kernels.cu | 8 +- csrc/moe/moe_ops.h | 2 +- csrc/moe/moe_permute_unpermute_op.cu | 4 +- csrc/moe/moe_wna16.cu | 8 +- csrc/moe/moe_wna16_utils.h | 3 +- csrc/moe/permute_unpermute_kernels/dispatch.h | 2 +- .../moe_permute_unpermute_kernel.h | 7 +- csrc/moe/topk_softmax_kernels.cu | 6 +- csrc/ops.h | 1 + csrc/permute_cols.cu | 5 +- csrc/pos_encoding_kernels.cu | 5 +- csrc/quantization/activation_kernels.cu | 26 +- csrc/quantization/awq/gemm_kernels.cu | 5 +- .../cutlass_w4a8/get_group_starts.cuh | 4 +- .../cutlass_w4a8/w4a8_grouped_mm_entry.cu | 5 +- .../cutlass_w4a8/w4a8_mm_entry.cu | 8 +- csrc/quantization/cutlass_w4a8/w4a8_utils.cu | 5 +- .../activation_nvfp4_quant_fusion_kernels.cu | 9 +- .../fp4/nvfp4_blockwise_moe_kernel.cu | 6 +- csrc/quantization/fp4/nvfp4_experts_quant.cu | 9 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 4 +- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 8 +- .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 6 +- .../fp4/nvfp4_scaled_mm_kernels.cu | 5 +- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 4 +- csrc/quantization/fp4/nvfp4_utils.cuh | 4 +- ...fused_layernorm_dynamic_per_token_quant.cu | 5 +- csrc/quantization/gguf/gguf_kernel.cu | 5 +- csrc/quantization/gptq/matrix_view.cuh | 4 +- csrc/quantization/gptq/q_gemm.cu | 7 +- .../gptq_allspark/allspark_qgemm_w8a16.cu | 2 +- .../gptq_allspark/allspark_repack.cu | 2 +- .../gptq_allspark/allspark_utils.cuh | 7 +- csrc/quantization/gptq_marlin/marlin.cuh | 7 +- .../hadacore/hadamard_transform_cuda.cu | 8 +- .../machete/machete_prepacked_layout.cuh | 4 +- csrc/quantization/marlin/sparse/common/mma.h | 3 +- csrc/quantization/vectorization.cuh | 3 +- .../w8a8/cutlass/c3x/cutlass_gemm_caller.cuh | 4 +- .../w8a8/cutlass/c3x/scaled_mm_helper.hpp | 2 +- .../w8a8/cutlass/c3x/scaled_mm_kernels.hpp | 2 +- .../moe/blockwise_scaled_group_mm_sm100.cu | 9 +- .../w8a8/cutlass/moe/get_group_starts.cuh | 4 +- .../w8a8/cutlass/moe/grouped_mm_c3x.cuh | 2 + .../w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu | 4 +- .../w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu | 4 +- .../quantization/w8a8/cutlass/moe/moe_data.cu | 5 +- .../w8a8/cutlass/scaled_mm_c2x.cuh | 2 +- .../w8a8/cutlass/scaled_mm_entry.cu | 6 +- csrc/quantization/w8a8/fp8/common.cu | 5 +- .../w8a8/fp8/per_token_group_quant.cu | 3 +- .../w8a8/int8/per_token_group_quant.cu | 4 +- csrc/quantization/w8a8/int8/scaled_quant.cu | 6 +- .../w8a8/per_token_group_quant_8bit.h | 3 +- csrc/quickreduce/base.h | 7 +- csrc/quickreduce/quick_reduce.h | 2 +- csrc/quickreduce/quick_reduce_impl.cuh | 2 +- csrc/sampler.cu | 8 +- csrc/sparse/cutlass/sparse_compressor_c3x.cuh | 2 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 2 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 5 +- csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 4 +- csrc/type_convert.cuh | 12 +- csrc/vendors/cuda.h | 50 + csrc/vendors/functions.h | 9 + csrc/vendors/hip.h | 307 ++++++ csrc/vendors/musa.h | 181 ++++ 98 files changed, 1757 insertions(+), 1286 deletions(-) create mode 100644 cmake/comm.cmake create mode 100644 cmake/cuda.cmake create mode 100644 cmake/hip.cmake create mode 100644 cmake/musa.cmake create mode 100644 csrc/vendors/cuda.h create mode 100644 csrc/vendors/functions.h create mode 100644 csrc/vendors/hip.h create mode 100644 csrc/vendors/musa.h diff --git a/CMakeLists.txt b/CMakeLists.txt index cd52df8..b371b1b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,9 +36,6 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") -# Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") - # ROCm installation prefix. Default to /opt/rocm but allow override via # -DROCM_PATH=/your/rocm/path when invoking cmake. if(NOT DEFINED ROCM_PATH) @@ -46,18 +43,6 @@ if(NOT DEFINED ROCM_PATH) else() set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE) endif() -# -# Supported/expected torch versions for CUDA/ROCm. -# -# Currently, having an incorrect pytorch version results in a warning -# rather than an error. -# -# Note: the CUDA torch version is derived from pyproject.toml and various -# requirements.txt files and should be kept consistent. The ROCm torch -# versions are derived from docker/Dockerfile.rocm -# -set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") # # Try to find python package with an executable that exactly matches @@ -76,12 +61,6 @@ endif() # append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") -# Ensure the 'nvcc' command is in the PATH -find_program(NVCC_EXECUTABLE nvcc) -if (CUDA_FOUND AND NOT NVCC_EXECUTABLE) - message(FATAL_ERROR "nvcc not found") -endif() - # # Import torch cmake configuration. # Torch also imports CUDA (and partially HIP) languages with some customizations, @@ -90,18 +69,6 @@ endif() # find_package(Torch REQUIRED) -# Supported NVIDIA architectures. -# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined -if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND - CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) - set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0") -elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND - CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) - set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") -else() - set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") -endif() - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -119,152 +86,36 @@ endif() # Set up GPU language and check the torch version and warn if it isn't # what is expected. # -if (NOT HIP_FOUND AND CUDA_FOUND) - set(VLLM_GPU_LANG "CUDA") - - if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} " - "expected for CUDA build, saw ${Torch_VERSION} instead.") - endif() -elseif(HIP_FOUND) - set(VLLM_GPU_LANG "HIP") - - # Importing torch recognizes and sets up some HIP/ROCm configuration but does - # not let cmake recognize .hip files. In order to get cmake to understand the - # .hip extension automatically, HIP must be enabled explicitly. - enable_language(HIP) - - # ROCm 5.X and 6.X - if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND - Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM}) - message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " - "expected for ROCm build, saw ${Torch_VERSION} instead.") - endif() +if (VLLM_TARGET_DEVICE STREQUAL "cuda") + # Include CUDA specific configuration + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cuda.cmake) +elseif(VLLM_TARGET_DEVICE STREQUAL "rocm") + # Include ROCm specific configuration + include(${CMAKE_CURRENT_LIST_DIR}/cmake/hip.cmake) +elseif(VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) +elseif(VLLM_TARGET_DEVICE STREQUAL "musa") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/musa.cmake) else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - # - # For cuda we want to be able to control which architectures we compile for on - # a per-file basis in order to cut down on compile time. So here we extract - # the set of architectures we want to compile for and remove the from the - # CMAKE_CUDA_FLAGS so that they are not applied globally. - # - clear_cuda_arches(CUDA_ARCH_FLAGS) - extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") - message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") - # Filter the target architectures by the supported supported archs - # since for some files we will build for all CUDA_ARCHS. - cuda_archs_loose_intersection(CUDA_ARCHS - "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") - message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") -else() - # - # For other GPU targets override the GPU architectures detected by cmake/torch - # and filter them by the supported versions for the current language. - # The final set of arches is stored in `VLLM_GPU_ARCHES`. - # - override_gpu_arches(VLLM_GPU_ARCHES - ${VLLM_GPU_LANG} - "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") -endif() - -# -# Query torch for additional GPU compilation flags for the given -# `VLLM_GPU_LANG`. -# The final set of arches is stored in `VLLM_GPU_FLAGS`. -# -get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) - -# -# Set nvcc parallelism. -# -if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") -endif() - -# -# Set compression mode for CUDA >=13.x. -# -if(VLLM_GPU_LANG STREQUAL "CUDA" AND - DEFINED CMAKE_CUDA_COMPILER_VERSION AND - CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) - list(APPEND VLLM_GPU_FLAGS "--compress-mode=size") -endif() - -# -# Set CUDA include flags for CXX compiler. -# -if(VLLM_GPU_LANG STREQUAL "CUDA") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include") - if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl") - endif() -endif() - -# # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. -# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. -# Each dependency that produces build artifacts should override its BINARY_DIR to avoid -# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/. -# include(FetchContent) file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") -if(VLLM_GPU_LANG STREQUAL "HIP") - # - # Overriding the default -O set up by cmake, adding ggdb3 for the most verbose devug info - # - set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3") - - # - # Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates - # a lot of warnings that always mask real issues. Suppressing until this is properly addressed. - # - set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") -endif() - # # Define other extension targets # # # cumem_allocator extension +# Architecture-specific cumem configurations are included from cmake/cuda.cmake or cmake/hip.cmake # -set(VLLM_CUMEM_EXT_SRC - "csrc/cumem_allocator.cpp") - -set_gencode_flags_for_srcs( - SRCS "${VLLM_CUMEM_EXT_SRC}" - CUDA_ARCHS "${CUDA_ARCHS}") - if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling cumem allocator extension.") - if(VLLM_GPU_LANG STREQUAL "CUDA") - # link against cuda driver library - list(APPEND CUMEM_LIBS CUDA::cuda_driver) - else() - # link against rocm driver library. Prefer an absolute path to - # libamdhip64.so inside ${ROCM_PATH}/lib if available, otherwise fall - # back to linking by name "amdhip64". - find_library(AMDHIP64_LIB - NAMES amdhip64 libamdhip64.so - PATHS ${ROCM_PATH}/lib - NO_DEFAULT_PATH) - if(AMDHIP64_LIB) - message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}") - list(APPEND CUMEM_LIBS ${AMDHIP64_LIB}) - else() - message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name") - list(APPEND CUMEM_LIBS amdhip64) - endif() - endif() define_extension_target( cumem_allocator DESTINATION vllm @@ -279,647 +130,7 @@ endif() # _C extension # -set(VLLM_EXT_SRC - "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/cache_kernels.cu" - "csrc/attention/paged_attention_v1.cu" - "csrc/attention/paged_attention_v2.cu" - "csrc/attention/merge_attn_states.cu" - "csrc/attention/vertical_slash_index.cu" - "csrc/pos_encoding_kernels.cu" - "csrc/activation_kernels.cu" - "csrc/layernorm_kernels.cu" - "csrc/fused_qknorm_rope_kernel.cu" - "csrc/layernorm_quant_kernels.cu" - "csrc/sampler.cu" - "csrc/cuda_view.cu" - "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/w8a8/int8/scaled_quant.cu" - "csrc/quantization/w8a8/fp8/common.cu" - "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" - "csrc/quantization/gguf/gguf_kernel.cu" - "csrc/quantization/activation_kernels.cu" - "csrc/cuda_utils_kernels.cu" - "csrc/custom_all_reduce.cu" - "csrc/torch_bindings.cpp") - -if(VLLM_GPU_LANG STREQUAL "CUDA") - SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") - - # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v4.2.1") - - # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided - if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) - set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) - endif() - - if(VLLM_CUTLASS_SRC_DIR) - if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) - get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) - endif() - message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") - FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) - else() - FetchContent_Declare( - cutlass - GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG ${CUTLASS_REVISION} - GIT_PROGRESS TRUE - - # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. - # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. - # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE - ) - endif() - FetchContent_MakeAvailable(cutlass) - - list(APPEND VLLM_EXT_SRC - "csrc/quantization/awq/gemm_kernels.cu" - "csrc/permute_cols.cu" - "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" - "csrc/quantization/fp4/nvfp4_quant_entry.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" - "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" - "csrc/cutlass_extensions/common.cpp" - "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" - "csrc/quantization/w8a8/int8/per_token_group_quant.cu") - - set_gencode_flags_for_srcs( - SRCS "${VLLM_EXT_SRC}" - CUDA_ARCHS "${CUDA_ARCHS}") - - # Only build Marlin kernels if we are building for at least some compatible archs. - # Keep building Marlin for 9.0 as there are some group sizes and shapes that - # are not supported by Machete yet. - - # marlin arches for fp16 output - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") - # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) - cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") - # marlin arches for fp8 input - # - sm80 doesn't support fp8 computation - # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction - # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") - - if (MARLIN_ARCHS) - - # - # For the Marlin kernels we automatically generate sources for various - # preselected input type pairs and schedules. - # Generate sources: - set(MARLIN_GEN_SCRIPT - ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) - file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) - list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) - set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - - message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - - if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} - OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) - execute_process( - COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$ENV{PYTHONPATH} - ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} - RESULT_VARIABLE marlin_generation_result - OUTPUT_VARIABLE marlin_generation_result - OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log - ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log - ) - - if (NOT marlin_generation_result EQUAL 0) - message(FATAL_ERROR "Marlin generation failed." - " Result: \"${marlin_generation_result}\"" - "\nCheck the log for details: " - "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") - else() - set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} - CACHE STRING "Last run Marlin generate script hash and arch" FORCE) - message(STATUS "Marlin generation completed successfully.") - endif() - else() - message(STATUS "Marlin generation script has not changed, skipping generation.") - endif() - - file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" - CUDA_ARCHS "${MARLIN_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) - - file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" - CUDA_ARCHS "${MARLIN_BF16_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) - - if (MARLIN_FP8_ARCHS) - file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" - CUDA_ARCHS "${MARLIN_FP8_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) - endif() - - set(MARLIN_SRCS - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/gptq_marlin/gptq_marlin.cu" - "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" - "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_SRCS}" - CUDA_ARCHS "${MARLIN_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") - - message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") - else() - message(STATUS "Not building Marlin kernels as no compatible archs found" - " in CUDA target architectures") - endif() - - # Only build AllSpark kernels if we are building for at least some compatible archs. - cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") - if (ALLSPARK_ARCHS) - set(ALLSPARK_SRCS - "csrc/quantization/gptq_allspark/allspark_repack.cu" - "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") - set_gencode_flags_for_srcs( - SRCS "${ALLSPARK_SRCS}" - CUDA_ARCHS "${ALLSPARK_ARCHS}") - list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") - message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") - else() - message(STATUS "Not building AllSpark kernels as no compatible archs found" - " in CUDA target architectures") - endif() - - - set(SCALED_MM_3X_ARCHS) - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require - # CUDA 12.0 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " - "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " - "Hopper.") - else() - message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - - # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require - # CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" - ) - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or " - "later if you intend on running FP8 quantized models on " - "Blackwell.") - else() - message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - - # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) - # require CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" - ) - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or " - "later if you intend on running FP8 quantized models on " - "Blackwell.") - else() - message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - # - # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) - # kernels for the remaining archs that are not already built for 3x. - # (Build 8.9 for FP8) - cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}") - # subtract out the archs that are already built for 3x - list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) - if (SCALED_MM_2X_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") - message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") - else() - if (SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c2x as all archs are already built" - " for and covered by scaled_mm_c3x") - else() - message(STATUS "Not building scaled_mm_c2x as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - # - # 2:4 Sparse Kernels - - # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper). - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) - set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") - message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) - message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " - "if you intend on running FP8 sparse quantized models on Hopper.") - else() - message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require - # CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) - set(SRCS - "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_experts_quant.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${FP4_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") - message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") - else() - message(STATUS "Not building NVFP4 as no compatible archs were found.") - # clear FP4_ARCHS - set(FP4_ARCHS) - endif() - - # FP4 Archs and flags - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) - set(SRCS - "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_experts_quant.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${FP4_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") - else() - message(STATUS "Not building NVFP4 as no compatible archs were found.") - # clear FP4_ARCHS - set(FP4_ARCHS) - endif() - - # CUTLASS MLA Archs and flags - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) - set(SRCS - "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${MLA_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") - # Add MLA-specific include directories only to MLA source files - set_source_files_properties(${SRCS} - PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") - message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") - else() - message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") - # clear MLA_ARCHS - set(MLA_ARCHS) - endif() - - # CUTLASS MoE kernels - - # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works - # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled - # if it's possible to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " - "if you intend on running FP8 quantized MoE models on Hopper.") - else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " - "if you intend on running FP8 quantized MoE models on Blackwell.") - else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - - # moe_data.cu is used by all CUTLASS MoE kernels. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - message(STATUS "Not building moe_data as CUDA Compiler version is " - "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " - "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") - else() - message(STATUS "Not building moe_data as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " - "if you intend on running FP8 quantized MoE models on Blackwell.") - else() - message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - # - # Machete kernels - - # The machete kernels only work on hopper and require CUDA 12.0 or later. - # Only build Machete kernels if we are building for something compatible with sm90a - cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) - # - # For the Machete kernels we automatically generate sources for various - # preselected input type pairs and schedules. - # Generate sources: - set(MACHETE_GEN_SCRIPT - ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) - file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) - - message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") - message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") - - if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} - OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) - execute_process( - COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH} - ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} - RESULT_VARIABLE machete_generation_result - OUTPUT_VARIABLE machete_generation_output - OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log - ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log - ) - - if (NOT machete_generation_result EQUAL 0) - message(FATAL_ERROR "Machete generation failed." - " Result: \"${machete_generation_result}\"" - "\nCheck the log for details: " - "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") - else() - set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} - CACHE STRING "Last run machete generate script hash" FORCE) - message(STATUS "Machete generation completed successfully.") - endif() - else() - message(STATUS "Machete generation script has not changed, skipping generation.") - endif() - - # Add machete generated sources - file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") - list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) - - # forward compatible - set_gencode_flags_for_srcs( - SRCS "${MACHETE_GEN_SOURCES}" - CUDA_ARCHS "${MACHETE_ARCHS}") - - list(APPEND VLLM_EXT_SRC - csrc/quantization/machete/machete_pytorch.cu) - - message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 - AND MACHETE_ARCHS) - message(STATUS "Not building Machete kernels as CUDA Compiler version is " - "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running w4a16 quantized models on " - "Hopper.") - else() - message(STATUS "Not building Machete kernels as no compatible archs " - "found in CUDA target architectures") - endif() - endif() - - # Only build W4A8 kernels if we are building for something compatible with sm90a - cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) - set(SRCS - "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" - "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" - "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" - ) - - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${W4A8_ARCHS}") - - list(APPEND VLLM_EXT_SRC "${SRCS}") - - message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 - AND W4A8_ARCHS) - message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " - "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running w4a16 quantized models on " - "Hopper.") - else() - message(STATUS "Not building W4A8 kernels as no compatible archs " - "found in CUDA target architectures") - endif() - endif() - - # Hadacore kernels - cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") - if(HADACORE_ARCHS) - set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${HADACORE_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - message(STATUS "Building hadacore") - endif() - -# if CUDA endif -endif() - -if (VLLM_GPU_LANG STREQUAL "HIP") - # Add QuickReduce kernels - list(APPEND VLLM_EXT_SRC - "csrc/custom_quickreduce.cu" - ) -# if ROCM endif -endif() +# VLLM_EXT_SRC is defined in the architecture-specific cmake files (cuda.cmake or hip.cmake) message(STATUS "Enabling C extension.") define_extension_target( @@ -940,121 +151,8 @@ define_extension_target( # Setting this variable sidesteps the issue by calling the driver directly. target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) -# # _moe_C extension -# - -set(VLLM_MOE_EXT_SRC - "csrc/moe/torch_bindings.cpp" - "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/topk_softmax_kernels.cu") - -if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC - "csrc/moe/moe_wna16.cu" - "csrc/moe/grouped_topk_kernels.cu") -endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - set(MOE_PERMUTE_SRC - "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" - "csrc/moe/moe_permute_unpermute_op.cu") - - list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") -endif() - -set_gencode_flags_for_srcs( - SRCS "${VLLM_MOE_EXT_SRC}" - CUDA_ARCHS "${CUDA_ARCHS}") - -if(VLLM_GPU_LANG STREQUAL "CUDA") - set(VLLM_MOE_WNA16_SRC - "csrc/moe/moe_wna16.cu") - - set_gencode_flags_for_srcs( - SRCS "${VLLM_MOE_WNA16_SRC}" - CUDA_ARCHS "${CUDA_ARCHS}") - - list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - # moe marlin arches - # note that we always set `use_atomic_add=False` for moe marlin now, - # so we don't need 9.0 for bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") - # moe marlin arches for fp8 input - # - sm80 doesn't support fp8 computation - # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction - # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") - if (MARLIN_MOE_ARCHS) - - # - # For the Marlin MOE kernels we automatically generate sources for various - # preselected input type pairs and schedules. - # Generate sources: - set(MOE_MARLIN_GEN_SCRIPT - ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) - file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) - list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) - set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - - message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - - if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} - OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) - execute_process( - COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$ENV{PYTHONPATH} - ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} - RESULT_VARIABLE moe_marlin_generation_result - OUTPUT_VARIABLE moe_marlin_generation_output - OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log - ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log - ) - - if (NOT moe_marlin_generation_result EQUAL 0) - message(FATAL_ERROR "Marlin MOE generation failed." - " Result: \"${moe_marlin_generation_result}\"" - "\nCheck the log for details: " - "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") - else() - set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} - CACHE STRING "Last run Marlin MOE generate script hash" FORCE) - message(STATUS "Marlin MOE generation completed successfully.") - endif() - else() - message(STATUS "Marlin MOE generation script has not changed, skipping generation.") - endif() - - file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") - list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_SRC}" - CUDA_ARCHS "${MARLIN_MOE_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_MOE_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) - - if (MARLIN_MOE_FP8_ARCHS) - file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") - set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_FP8_SRC}" - CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MARLIN_MOE_FP8_SRC} - PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") - endif() - list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) - endif() - - message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") - else() - message(STATUS "Not building Marlin MOE kernels as no compatible archs found" - " in CUDA target architectures") - endif() -endif() +# Architecture-specific MOE configurations are included from cmake/cuda.cmake or cmake/hip.cmake message(STATUS "Enabling moe extension.") define_extension_target( @@ -1069,25 +167,7 @@ define_extension_target( USE_SABI 3 WITH_SOABI) -if(VLLM_GPU_LANG STREQUAL "HIP") - # - # _rocm_C extension - # - set(VLLM_ROCM_EXT_SRC - "csrc/rocm/torch_bindings.cpp" - "csrc/rocm/skinny_gemms.cu" - "csrc/rocm/attention.cu") - - define_extension_target( - _rocm_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_ROCM_EXT_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -endif() +# Architecture-specific ROCm configurations are included from cmake/hip.cmake # For CUDA and HIP builds also build the triton_kernels external package. if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") diff --git a/cmake/comm.cmake b/cmake/comm.cmake new file mode 100644 index 0000000..4d65ab4 --- /dev/null +++ b/cmake/comm.cmake @@ -0,0 +1,23 @@ +set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/cache_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" + "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" + "csrc/cuda_view.cu" + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" + "csrc/cuda_utils_kernels.cu" + "csrc/custom_all_reduce.cu" + "csrc/torch_bindings.cpp") \ No newline at end of file diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake new file mode 100644 index 0000000..c227a49 --- /dev/null +++ b/cmake/cuda.cmake @@ -0,0 +1,753 @@ +# +# CUDA-specific configuration for vLLM +# + +set(VLLM_GPU_LANG "CUDA") + +# Set the supported torch version for CUDA +set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") + +# Warn if the torch version doesn't match what we expect +if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} " + "expected for CUDA build, saw ${Torch_VERSION} instead.") +endif() + +# Extract and filter CUDA architectures +clear_cuda_arches(CUDA_ARCH_FLAGS) +extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") +message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") + +# Filter the target architectures by the supported archs +cuda_archs_loose_intersection(CUDA_ARCHS + "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") +message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") + +# Query torch for additional GPU compilation flags +set(VLLM_GPU_ARCHES "${CUDA_ARCHS}") +get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) + +# Set nvcc parallelism +if(NVCC_THREADS) + list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") +endif() + +# Set compression mode for CUDA >=13.x +if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + list(APPEND VLLM_GPU_FLAGS "--compress-mode=size") +endif() + +# Set CUDA include flags for CXX compiler +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include") +if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl") +endif() + +# Set up CUTLASS for CUDA builds +SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + +# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. +set(CUTLASS_REVISION "v4.2.1") + +# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided +if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) +endif() + +if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) +else() + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) +endif() +FetchContent_MakeAvailable(cutlass) + +# Set CUDA extension sources +set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/cache_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" + "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" + "csrc/cuda_view.cu" + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" + "csrc/cuda_utils_kernels.cu" + "csrc/custom_all_reduce.cu" + "csrc/torch_bindings.cpp") + +# Add CUDA-specific sources +list(APPEND VLLM_EXT_SRC + "csrc/quantization/awq/gemm_kernels.cu" + "csrc/permute_cols.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/cutlass_extensions/common.cpp" + "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/quantization/w8a8/int8/per_token_group_quant.cu") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +# Marlin kernels configuration +# marlin arches for fp16 output +cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") +# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) +cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") +# marlin arches for fp8 input +# - sm80 doesn't support fp8 computation +# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction +# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) +cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + +if (MARLIN_ARCHS) + # Generate Marlin kernel sources + set(MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) + file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") + + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} + RESULT_VARIABLE marlin_generation_result + OUTPUT_VARIABLE marlin_generation_result + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ) + + if (NOT marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin generation failed." + " Result: \"${marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") + else() + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin generate script hash" FORCE) + message(STATUS "Marlin generation completed successfully.") + endif() + else() + message(STATUS "Marlin generation script has not changed, skipping generation.") + endif() + + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + + if (MARLIN_FP8_ARCHS) + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) + endif() + + set(MARLIN_SRCS + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_SRCS}" + CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") +else() + message(STATUS "Not building Marlin kernels as no compatible archs found" + " in CUDA target architectures") +endif() + +# AllSpark kernels configuration +cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") +if (ALLSPARK_ARCHS) + set(ALLSPARK_SRCS + "csrc/quantization/gptq_allspark/allspark_repack.cu" + "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${ALLSPARK_SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") +else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") +endif() + +# Scaled MM 3X (Hopper) kernels +set(SCALED_MM_3X_ARCHS) +cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# Scaled MM 3X (Geforce Blackwell SM120) kernels +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# Scaled MM 3X (Blackwell SM100) kernels +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# Scaled MM 2X kernels for remaining archs +cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}") +# subtract out the archs that are already built for 3x +list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) +if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") +else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# 2:4 Sparse Kernels +cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) + set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") + message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) + message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " + "if you intend on running FP8 sparse quantized models on Hopper.") + else() + message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# NVFP4 kernels for Geforce Blackwell SM120 +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") +else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) +endif() + +# NVFP4 kernels for other Blackwell architectures +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") +else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) +endif() + +# CUTLASS MLA kernels +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") +else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) +endif() + +# CUTLASS MoE kernels for Hopper +cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() +endif() + +# CUTLASS MoE kernels for Blackwell SM100 +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() +endif() + +# MoE data kernel (used by all CUTLASS MoE kernels) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + message(STATUS "Not building moe_data as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") + else() + message(STATUS "Not building moe_data as no compatible archs found " + "in CUDA target architectures.") + endif() +endif() + +# Blockwise scaled group MM for Blackwell SM100 +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " + "in CUDA target architectures") + endif() +endif() + +# Machete kernels +cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) + # Generate Machete kernel sources + set(MACHETE_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) + file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) + + message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") + message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} + OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." + " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} + CACHE STRING "Last run machete generate script hash" FORCE) + message(STATUS "Machete generation completed successfully.") + endif() + else() + message(STATUS "Machete generation script has not changed, skipping generation.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + + # forward compatible + set_gencode_flags_for_srcs( + SRCS "${MACHETE_GEN_SOURCES}" + CUDA_ARCHS "${MACHETE_ARCHS}") + + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) + + message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND MACHETE_ARCHS) + message(STATUS "Not building Machete kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building Machete kernels as no compatible archs " + "found in CUDA target architectures") + endif() +endif() + +# W4A8 kernels +cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" + ) + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") +else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() +endif() + +# Hadacore kernels +cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") +if(HADACORE_ARCHS) + set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") +endif() + +# MOE extension sources for CUDA +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/topk_softmax_kernels.cu") + +list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") + +set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + +list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +# Marlin MOE kernels +# note that we always set `use_atomic_add=False` for moe marlin now, +# so we don't need 9.0 for bf16 atomicAdd PTX +cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") +# moe marlin arches for fp8 input +# - sm80 doesn't support fp8 computation +# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction +# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) +cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") +if (MARLIN_MOE_ARCHS) + # Generate Marlin MOE kernel sources + set(MOE_MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) + file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") + + message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} + RESULT_VARIABLE moe_marlin_generation_result + OUTPUT_VARIABLE moe_marlin_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ) + + if (NOT moe_marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin MOE generation failed." + " Result: \"${moe_marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") + else() + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin MOE generate script hash" FORCE) + message(STATUS "Marlin MOE generation completed successfully.") + endif() + else() + message(STATUS "Marlin MOE generation script has not changed, skipping generation.") + endif() + + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) + + if (MARLIN_MOE_FP8_ARCHS) + file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_FP8_SRC}" + CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_FP8_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) + endif() + + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") +else() + message(STATUS "Not building Marlin MOE kernels as no compatible archs found" + " in CUDA target architectures") +endif() + +# Cumem allocator for CUDA +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +# Link against cuda driver library for cumem +list(APPEND CUMEM_LIBS CUDA::cuda_driver) diff --git a/cmake/hip.cmake b/cmake/hip.cmake new file mode 100644 index 0000000..b2266d4 --- /dev/null +++ b/cmake/hip.cmake @@ -0,0 +1,146 @@ +# +# HIP/ROCm-specific configuration for vLLM +# + +set(VLLM_GPU_LANG "HIP") + +# Set the supported torch version for ROCm +set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") + +# Warn if the torch version doesn't match what we expect +if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND + Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM}) + message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " + "expected for ROCm build, saw ${Torch_VERSION} instead.") +endif() + +# Enable HIP language support +# Importing torch recognizes and sets up some HIP/ROCm configuration but does +# not let cmake recognize .hip files. In order to get cmake to understand the +# .hip extension automatically, HIP must be enabled explicitly. +enable_language(HIP) + +# Supported AMD GPU architectures +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") + +# Override GPU architectures detected by cmake/torch and filter by supported versions +override_gpu_arches(VLLM_GPU_ARCHES + ${VLLM_GPU_LANG} + "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") + +# Query torch for additional GPU compilation flags +get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) + +# Overriding the default -O set up by cmake, adding ggdb3 for the most verbose debug info +set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3") + +# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates +# a lot of warnings that always mask real issues. Suppressing until this is properly addressed. +set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") + +# Set up CUTLASS for HIP builds +SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + +# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. +set(CUTLASS_REVISION "v4.2.1") + +# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided +if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) +endif() + +if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) +else() + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) +endif() +FetchContent_MakeAvailable(cutlass) + +# Set HIP extension sources +set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/cache_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" + "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" + "csrc/cuda_view.cu" # Note: Keeping this name for compatibility + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" + "csrc/cuda_utils_kernels.cu" # Note: Keeping this name for compatibility + "csrc/custom_all_reduce.cu" + "csrc/torch_bindings.cpp") + +# Add QuickReduce kernels for ROCm +list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" +) + +# MOE extension sources for ROCm +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/topk_softmax_kernels.cu") + +# Cumem allocator for ROCm +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +# Link against rocm driver library for cumem +# Prefer an absolute path to libamdhip64.so inside ${ROCM_PATH}/lib if available, +# otherwise fall back to linking by name "amdhip64". +find_library(AMDHIP64_LIB + NAMES amdhip64 libamdhip64.so + PATHS ${ROCM_PATH}/lib + NO_DEFAULT_PATH) +if(AMDHIP64_LIB) + message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}") + list(APPEND CUMEM_LIBS ${AMDHIP64_LIB}) +else() + message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name") + list(APPEND CUMEM_LIBS amdhip64) +endif() + +# ROCm-specific extension sources +set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/skinny_gemms.cu" + "csrc/rocm/attention.cu") + +# Define ROCm-specific extension target +define_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) diff --git a/cmake/musa.cmake b/cmake/musa.cmake new file mode 100644 index 0000000..4422e88 --- /dev/null +++ b/cmake/musa.cmake @@ -0,0 +1,125 @@ +# +# MUSA-specific configuration for vLLM +# + +set(VLLM_GPU_LANG "MUSA") + +# Set the supported torch version for MUSA +set(TORCH_SUPPORTED_VERSION_MUSA "2.7.1") + +# Warn if the torch version doesn't match what we expect +if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_MUSA}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_MUSA} " + "expected for MUSA build, saw ${Torch_VERSION} instead.") +endif() + +# Find MUSA package +list(APPEND CMAKE_MODULE_PATH $ENV{MUSA_HOME}/cmake) +find_package(MUSA REQUIRED) + +# Extract and filter MUSA architectures +# MUSA architectures are similar to CUDA, but may have different naming +message(STATUS "MUSA target architectures: ${MUSA_ARCHS}") + +# Filter the target architectures by the supported archs +# MUSA SDK 4.3.0 supports the following architectures +set(MUSA_SUPPORTED_ARCHS "21;22") + +# Override GPU architectures detected by cmake/torch +override_gpu_arches(VLLM_GPU_ARCHES + ${VLLM_GPU_LANG} + "${MUSA_SUPPORTED_ARCHS}") + +# Query torch for additional GPU compilation flags +get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) + +# Set nvcc parallelism (MUSA compiler also supports --threads flag) +if(NVCC_THREADS) + list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") +endif() + +# Set MUSA include flags for CXX compiler +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I$ENV{MUSA_HOME}/include") + +# Set up CUTLASS for MUSA builds +# MUSA is compatible with CUDA, so we can use the same CUTLASS configuration +SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + +# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. +set(CUTLASS_REVISION "v4.2.1") + +# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided +if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) +endif() + +if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) +else() + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) +endif() +FetchContent_MakeAvailable(cutlass) + +# Set MUSA extension sources +# These are the same source files as CUDA, since MUSA is compatible with CUDA code +set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/cache_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" + "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" + "csrc/cuda_view.cu" # Note: Keeping this name for compatibility + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" + "csrc/cuda_utils_kernels.cu" # Note: Keeping this name for compatibility + "csrc/custom_all_reduce.cu" + "csrc/torch_bindings.cpp") + +# MOE extension sources for MUSA +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/topk_softmax_kernels.cu") + +list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") + +set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + +list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") + +# Cumem allocator for MUSA +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +# Link against musa driver library for cumem +list(APPEND CUMEM_LIBS musa::musa_driver) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index a4a880f..d61de1f 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,5 @@ -#include -#include -#include +#include "vendors/functions.h" + #include diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 052ff16..bd6bf81 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -17,9 +17,7 @@ * limitations under the License. */ -#include -#include -#include +#include "../vendors/functions.h" #include #include "attention_dtypes.h" @@ -27,9 +25,7 @@ #include "../cuda_compat.h" #ifdef USE_ROCM - #include #include "../quantization/w8a8/fp8/amd/quant_utils.cuh" -typedef __hip_bfloat16 __nv_bfloat16; #else #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh" #endif diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 97a25ba..a39ca8e 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -23,16 +23,8 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#ifndef USE_ROCM - #include - #include -#else - #include - #include -typedef __hip_bfloat162 __nv_bfloat162; -typedef __hip_bfloat16 __nv_bfloat16; -#endif +#include "../vendors/functions.h" #include diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 3a1815f..0e675c1 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -23,9 +23,7 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#ifdef USE_ROCM - #include -#endif +#include "../vendors/functions.h" #include diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index e714e32..7b4252e 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,11 +3,7 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8 - #ifndef USE_ROCM - #include - #endif // USE_ROCM -#endif // ENABLE_FP8 +#include "../vendors/functions.h" namespace vllm { diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 27d1e99..6f4bf38 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -1,7 +1,5 @@ #include -#include -#include -#include +#include "../vendors/functions.h" #include #include "attention_dtypes.h" diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu index c1b45b1..852b116 100644 --- a/csrc/attention/vertical_slash_index.cu +++ b/csrc/attention/vertical_slash_index.cu @@ -3,9 +3,7 @@ #include -#include - -#include +#include "../vendors/functions.h" __device__ int64_t save_blocks(int* block_offset, int64_t range_start, int64_t range_end, int64_t block_size, diff --git a/csrc/cache.h b/csrc/cache.h index cbe44c0..f210bd6 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include "vendors/functions.h" + #include #include diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index f11c5f2..dcafc3a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,8 +1,6 @@ -#include -#include -#include -#include -#include +#include "vendors/functions.h" + + #include "cuda_utils.h" #include "cuda_compat.h" @@ -19,10 +17,7 @@ #include #include -#ifdef USE_ROCM - #include -typedef __hip_bfloat16 __nv_bfloat16; -#endif + void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping) { diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h index 18e4e34..4b283ba 100644 --- a/csrc/cub_helpers.h +++ b/csrc/cub_helpers.h @@ -1,18 +1,4 @@ #pragma once -#ifndef USE_ROCM - #include - #if CUB_VERSION >= 200800 - #include -using CubAddOp = cuda::std::plus<>; -using CubMaxOp = cuda::maximum<>; - #else // if CUB_VERSION < 200800 -using CubAddOp = cub::Sum; -using CubMaxOp = cub::Max; - #endif // CUB_VERSION -#else - #include -namespace cub = hipcub; -using CubAddOp = hipcub::Sum; -using CubMaxOp = hipcub::Max; -#endif // USE_ROCM +#include "vendors/functions.h" + diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 0627a42..c1d9fd3 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,8 +1,5 @@ #include "cuda_utils.h" -#ifdef USE_ROCM - #include - #include -#endif +#include "vendors/functions.h" int64_t get_device_attribute(int64_t attribute, int64_t device_id) { // Return the cached value on subsequent calls diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu index 9853fc9..755e4e2 100644 --- a/csrc/cuda_view.cu +++ b/csrc/cuda_view.cu @@ -1,6 +1,5 @@ -#include -#include -#include +#include "vendors/functions.h" + // This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned // memory, and that UVA (Unified Virtual Addressing) is enabled. diff --git a/csrc/cumem_allocator_compat.h b/csrc/cumem_allocator_compat.h index 74f4bc9..a29d2cc 100644 --- a/csrc/cumem_allocator_compat.h +++ b/csrc/cumem_allocator_compat.h @@ -104,6 +104,6 @@ CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) { //////////////////////////////////////// // Import CUDA headers for NVIDIA GPUs //////////////////////////////////////// - #include - #include +#include "vendors/functions.h" + #endif diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index a38d6fa..a6eb2f0 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,7 +1,5 @@ -#include -#include -#include -#include +#include "vendors/functions.h" + #include "custom_all_reduce.cuh" diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 58926f6..4ec3799 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -1,13 +1,7 @@ #pragma once -#include -#include -#include -#include +#include "vendors/functions.h" -#if defined(USE_ROCM) -typedef __hip_bfloat16 nv_bfloat16; -#endif #include #include diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7f0823..4a7c2a5 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -11,8 +11,8 @@ * To run: * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ -#include -#include +#include "vendors/functions.h" + #include #include diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 33d0d4a..5bcfa6a 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -1,7 +1,5 @@ -#include -#include -#include -#include +#include "vendors/functions.h" + #ifdef USE_ROCM diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index de0c505..0a2e5de 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,7 +4,8 @@ */ #pragma once -#include +#include "vendors/functions.h" + // Need a special dispatch case macro since we will nest the FP8 dispatch. // Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu index baff836..6786efc 100644 --- a/csrc/fused_qknorm_rope_kernel.cu +++ b/csrc/fused_qknorm_rope_kernel.cu @@ -15,11 +15,10 @@ */ #include -#include #include -#include -#include +#include "vendors/functions.h" + #include "cuda_compat.h" #include "dispatch_utils.h" diff --git a/csrc/launch_bounds_utils.h b/csrc/launch_bounds_utils.h index 92d7ef8..0d17ee8 100644 --- a/csrc/launch_bounds_utils.h +++ b/csrc/launch_bounds_utils.h @@ -1,6 +1,7 @@ #pragma once -#include +#include "vendors/functions.h" + #include // maximum blocks per SM cap diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index dfc67b9..8817c5e 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -4,8 +4,8 @@ #include "core/batch_invariant.hpp" #include "quantization/vectorization_utils.cuh" -#include -#include +#include "vendors/functions.h" + namespace vllm { diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0880b8d..c62e371 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -12,8 +12,8 @@ #include "core/batch_invariant.hpp" #include "quantization/vectorization_utils.cuh" -#include -#include +#include "vendors/functions.h" + namespace vllm { diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 7d22dd8..ad04829 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -6,12 +6,7 @@ #pragma once -#ifndef USE_ROCM - #include -#else - #include -#endif -#include +#include "vendors/functions.h" //////////////////////////////////////////////////////////////////////////////////////////////////// struct SSMParamsBase { diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index fb2a2e5..622893d 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -1,27 +1,9 @@ // clang-format off // adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh -#include -#include -#include + +#include "vendors/functions.h" + #include "selective_scan.h" - -#include -#include -#ifdef USE_ROCM - #include // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK -#else - #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK -#endif - -#ifndef USE_ROCM - #include - #include - #include -#else - #include - namespace cub = hipcub; -#endif - #include "selective_scan.h" #include "static_switch.h" diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index 58dc402..baa3434 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -1,7 +1,4 @@ -#include -#include -#include - +#include "../../vendors/functions.h" // _dyn_quant_matmul_4bit is only available on AArch64. #if defined(__aarch64__) #include diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 5fa367a..9cb0c59 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -17,13 +17,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include + +#include "../../vendors/functions.h" namespace cg = cooperative_groups; namespace vllm { diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 5c9e474..484fc51 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -1,10 +1,4 @@ -#include -#include -#include -#include - -#include -#include +#include "../vendors/functions.h" #include "../cuda_compat.h" #include "../dispatch_utils.h" diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 337dcc5..daeb313 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../vendors/functions.h" void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index ca0c873..891ff0c 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -1,6 +1,4 @@ -#include -#include -#include +#include "../vendors/functions.h" #include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h" #include "permute_unpermute_kernels/dispatch.h" #include "core/registration.h" diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu index 7b6a111..e3cb529 100644 --- a/csrc/moe/moe_wna16.cu +++ b/csrc/moe/moe_wna16.cu @@ -1,11 +1,7 @@ -#include -#include -#include -#include +#include "../vendors/functions.h" + -#include -#include #include "moe_wna16_utils.h" #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h index 8ef03f0..14bd27c 100644 --- a/csrc/moe/moe_wna16_utils.h +++ b/csrc/moe/moe_wna16_utils.h @@ -1,6 +1,5 @@ -#include -#include +#include "../vendors/functions.h" template class ScalarType {}; diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h index d0f1ea4..33f8bf3 100644 --- a/csrc/moe/permute_unpermute_kernels/dispatch.h +++ b/csrc/moe/permute_unpermute_kernels/dispatch.h @@ -1,5 +1,5 @@ #pragma once -#include +#include "vendors/functions.h" #define MOE_SWITCH(TYPE, ...) \ at::ScalarType _st = ::detail::scalar_type(TYPE); \ switch (_st) { \ diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 108091e..846ea93 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -2,12 +2,11 @@ // reference from tensorrt_llm moe kernel implementation archive in // https://github.com/BBuf/tensorrt-llm-moe/tree/master -#include #include #include "dispatch.h" -#include -#include -#include + + +#include "../../vendors/functions.h" #include "cutlass/numeric_size.h" #include "cutlass/array.h" diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index af6e6fc..1a878a5 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -17,9 +17,9 @@ * limitations under the License. */ #include -#include -#include -#include +#include "../vendors/functions.h" + + #include "../cuda_compat.h" #include "../cub_helpers.h" diff --git a/csrc/ops.h b/csrc/ops.h index 37e3aaf..9212855 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,6 +1,7 @@ #pragma once #include + #include #include "core/scalar_type.hpp" diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu index f51fa73..58538e2 100644 --- a/csrc/permute_cols.cu +++ b/csrc/permute_cols.cu @@ -1,9 +1,6 @@ -#include +#include "vendors/functions.h" -#include -#include -#include static constexpr int default_threads = 256; static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b5645b3..c820ff2 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,5 @@ -#include -#include -#include +#include "vendors/functions.h" + #include "cuda_compat.h" #include "dispatch_utils.h" diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 0c3bcf3..97f7980 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -1,6 +1,5 @@ -#include -#include -#include + +#include "../vendors/functions.h" #include #include "core/math.hpp" @@ -9,29 +8,8 @@ #include "quantization/w8a8/fp8/common.cuh" -#include -#ifndef USE_ROCM - #include - #include - #include -#else - #include - #include - #include -typedef __hip_bfloat162 __nv_bfloat162; -typedef __hip_bfloat16 __nv_bfloat16; -typedef __hip_bfloat16_raw __nv_bfloat16_raw; - #if defined(HIP_FP8_TYPE_OCP) -typedef __hip_fp8_e4m3 __nv_fp8_e4m3; -typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; - #else -// ROCm 6.2 fallback: only *_fnuz types exist -typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3; -typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3; - #endif -#endif #include "core/registration.h" namespace vllm { diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 53c4767..bd296c4 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -7,12 +7,11 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include -#include + #include "dequantize.cuh" -#include +#include "../../vendors/functions.h" namespace vllm { namespace awq { diff --git a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh index fec142d..5ed3155 100644 --- a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh @@ -1,9 +1,7 @@ // see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh #pragma once -#include -#include -#include +#include "../../vendors/functions.h" #include "core/scalar_type.hpp" #include "cutlass/bfloat16.h" diff --git a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu index 4b42579..5beae42 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -14,9 +14,8 @@ #include "cutlass/util/mixed_dtype_utils.hpp" // vllm includes -#include -#include -#include +#include "../../vendors/functions.h" + #include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/common.hpp" diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu index f77af06..1079335 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -3,9 +3,10 @@ // https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu // -#include -#include -#include + +#include "../../vendors/functions.h" + + #include "cutlass_extensions/torch_utils.hpp" #include "w4a8_utils.cuh" @@ -26,7 +27,6 @@ #include "cutlass_extensions/common.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include namespace vllm::cutlass_w4a8 { diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu index f238d0a..c80bea0 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu @@ -1,7 +1,10 @@ #include "w4a8_utils.cuh" + +#include "../../vendors/functions.h" + + #include -#include #include namespace vllm::cutlass_w4a8_utils { diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 7539f83..c0aac62 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -14,15 +14,10 @@ * limitations under the License. */ -#include -#include -#include + #include "../../vendors/functions.h" -#include -#include - -#include + #include "dispatch_utils.h" #include "cuda_utils.h" diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 6744402..4d87906 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -19,9 +19,9 @@ #include #include -#include -#include -#include +#include "../../vendors/functions.h" + + #include "cutlass_extensions/common.hpp" #include "cute/tensor.hpp" diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 82c53c2..db29c42 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -14,15 +14,10 @@ * limitations under the License. */ -#include -#include -#include + #include "../../vendors/functions.h" -#include -#include - -#include + #include "dispatch_utils.h" #include "nvfp4_utils.cuh" diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index fb6d22f..1bfaee1 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -14,7 +14,9 @@ * limitations under the License. */ -#include + +#include "../../vendors/functions.h" + #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 6d69852..ac7d204 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -14,15 +14,9 @@ * limitations under the License. */ -#include +#include "../../vendors/functions.h" -#include -#include -#include -#include - -#include #include "dispatch_utils.h" #include "cuda_utils.h" diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index d9c4d24..c5c3f0a 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -14,8 +14,10 @@ * limitations under the License. */ -#include -#include + +#include "../../vendors/functions.h" + + #include "cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 5bc4c38..498d3c5 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -14,10 +14,7 @@ * limitations under the License. */ -#include - -#include -#include +#include "../../vendors/functions.h" #include "cutlass_extensions/common.hpp" diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index 89de23b..66310a8 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -14,10 +14,8 @@ * limitations under the License. */ -#include +#include "../../vendors/functions.h" -#include -#include #include "cutlass_extensions/common.hpp" diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 48e4959..2f1b01a 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -16,8 +16,8 @@ #pragma once -#include -#include +#include "../../vendors/functions.h" + #define ELTS_PER_THREAD 8 diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 2080ef3..d3bce88 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -1,6 +1,7 @@ -#include -#include + +#include "../../vendors/functions.h" + #include "../../dispatch_utils.h" #include "layernorm_utils.cuh" diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 76fe73e..484f439 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -1,8 +1,5 @@ -#include -#include +#include "../../vendors/functions.h" -#include -#include #include "../../cuda_compat.h" #include "dispatch_utils.h" diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index 2b6719f..72459e2 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -6,8 +6,8 @@ https://github.com/turboderp/exllama #ifndef _matrix_view_cuh #define _matrix_view_cuh -#include -#include +#include "../../vendors/functions.h" + #include "qdq_util.cuh" diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 8869d7c..35e43c7 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -6,11 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include -#include -#include -#include -#include +#include "../../vendors/functions.h" + #include "compat.cuh" #include "matrix_view.cuh" diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index e306ff0..8b75fab 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -1,5 +1,5 @@ #include "allspark_utils.cuh" -#include +#include "../../vendors/functions.h" #include "core/registration.h" #include diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/quantization/gptq_allspark/allspark_repack.cu index 7a5b2f9..764f630 100644 --- a/csrc/quantization/gptq_allspark/allspark_repack.cu +++ b/csrc/quantization/gptq_allspark/allspark_repack.cu @@ -1,5 +1,5 @@ #include "allspark_utils.cuh" -#include +#include "../../vendors/functions.h" #include "core/registration.h" namespace allspark { diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh index 14a61ad..1c9d31e 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -1,11 +1,6 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include "../../vendors/functions.h" #include #include "../gptq_marlin/marlin_dtypes.cuh" using marlin::MarlinScalarType2; diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh index 2505e22..9a65c3b 100644 --- a/csrc/quantization/gptq_marlin/marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -1,12 +1,7 @@ #pragma once -#include +#include "../../vendors/functions.h" -#include -#include -#include -#include -#include #include #ifndef MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu index aff1132..a21a22c 100644 --- a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu +++ b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -11,15 +11,9 @@ Redistribution and use in source and binary forms, with or without modification, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ***********/ -#include +#include "../../../vendors/functions.h" #include -#include #include -#include -#include - -#include -#include #include "core/registration.h" #include "dispatch_utils.h" diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 4a7d634..ec51933 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -1,8 +1,6 @@ #pragma once -#include -#include -#include +#include "../../vendors/functions.h" // clang-format off // The cutlass include order matters (annoyingly) diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index b26505f..93bdae6 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -17,8 +17,7 @@ #pragma once #include "base.h" -#include - +#include "../../../../vendors/functions.h" namespace marlin_24 { // On CUDA earlier than 12.5, the ordered_metadata version of this instruction diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh index 11d57a5..d3b0eea 100644 --- a/csrc/quantization/vectorization.cuh +++ b/csrc/quantization/vectorization.cuh @@ -4,8 +4,7 @@ */ // Include both AMD and NVIDIA fp8 types to avoid circular import -#include -#include +#include "../vendors/functions.h" namespace vllm { diff --git a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh index 26de32c..7f2c454 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh @@ -2,9 +2,9 @@ // clang-format will break include orders // clang-format off -#include +#include "../../../../vendors/functions.h" + -#include #include "cutlass/cutlass.h" diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index 2204a49..37428bd 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -1,4 +1,4 @@ -#include +#include "../../../../vendors/functions.h" #include "cuda_utils.h" #include "cutlass_extensions/common.hpp" diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp index 9ceb3a3..adbb7ca 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../../../vendors/functions.h" namespace vllm { diff --git a/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu index 6c8f630..0c042df 100644 --- a/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu +++ b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu @@ -1,11 +1,10 @@ #include "core/registration.h" -#include -#include +// #include + +#include "../../../../vendors/functions.h" + -#include -#include -#include #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh index 15bb2c3..63c2e78 100644 --- a/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh +++ b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh @@ -1,8 +1,6 @@ #pragma once -#include -#include -#include +#include "../../../../vendors/functions.h" #include "core/scalar_type.hpp" #include "cutlass/bfloat16.h" diff --git a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh index 659941d..585b6e8 100644 --- a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh @@ -1,5 +1,7 @@ #pragma once +#include "../../../../vendors/functions.h" + #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/collective_builder.hpp" diff --git a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu index 641e599..cc1a805 100644 --- a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu +++ b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu @@ -1,7 +1,5 @@ -#include -#include -#include +#include "../../../../vendors/functions.h" #include "cutlass/cutlass.h" #include "grouped_mm_c3x.cuh" diff --git a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu index 8f21623..fa048ee 100644 --- a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu +++ b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu @@ -1,7 +1,5 @@ -#include -#include -#include +#include "../../../../vendors/functions.h" #include "cutlass/cutlass.h" #include "grouped_mm_c3x.cuh" diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 99fec8f..40857e9 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -1,7 +1,4 @@ -#include - -#include -#include +#include "../../../../vendors/functions.h" #include diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh index ce7cf2f..46952b6 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh @@ -2,7 +2,7 @@ #include #include -#include +#include "../../../../vendors/functions.h" // clang-format will break include orders // clang-format off diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 5de21cf..d98e672 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -1,7 +1,7 @@ -#include -#include -#include +#include "../../../vendors/functions.h" + + #include "cutlass_extensions/common.hpp" diff --git a/csrc/quantization/w8a8/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu index 7a822fb..723b6ef 100644 --- a/csrc/quantization/w8a8/fp8/common.cu +++ b/csrc/quantization/w8a8/fp8/common.cu @@ -2,8 +2,9 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "quantization/vectorization_utils.cuh" -#include -#include +#include "../../../../vendors/functions.h" + + namespace vllm { diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu index 49d1b20..64e2b0b 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -1,5 +1,4 @@ -#include - +#include "../../../vendors/functions.h" #include "quantization/w8a8/per_token_group_quant_8bit.h" #include diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu index 9d808a1..8c30991 100644 --- a/csrc/quantization/w8a8/int8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/int8/per_token_group_quant.cu @@ -1,6 +1,4 @@ -#include -#include - +#include "../../../vendors/functions.h" #include "quantization/w8a8/per_token_group_quant_8bit.h" void per_token_group_quant_int8(const torch::Tensor& input, diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu index be8ecfe..b0e12dc 100644 --- a/csrc/quantization/w8a8/int8/scaled_quant.cu +++ b/csrc/quantization/w8a8/int8/scaled_quant.cu @@ -1,6 +1,6 @@ -#include -#include -#include +#include "../../../vendors/functions.h" + + #include diff --git a/csrc/quantization/w8a8/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h index 25d4ecd..48b1d5f 100644 --- a/csrc/quantization/w8a8/per_token_group_quant_8bit.h +++ b/csrc/quantization/w8a8/per_token_group_quant_8bit.h @@ -1,6 +1,5 @@ #pragma once -#include - +#include "../../vendors/functions.h" // 8-bit per-token-group quantization helper used by both FP8 and INT8 void per_token_group_quant_8bit(const torch::Tensor& input, torch::Tensor& output_q, diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index a2170e4..2b46dd6 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -1,9 +1,7 @@ #pragma once #include -#include -#include -#include +#include "../vendors/functions.h" #define __quickreduce_device_inline__ __device__ __forceinline__ #define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) @@ -11,8 +9,7 @@ namespace quickreduce { -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 4cc3530..196276f 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include "../vendors/functions.h" #include "quick_reduce_impl.cuh" #define HIP_CHECK(err) \ diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 38dc993..f6d84a7 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include "../vendors/functions.h" #include "base.h" namespace quickreduce { diff --git a/csrc/sampler.cu b/csrc/sampler.cu index fc2154b..b9f3822 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -1,13 +1,7 @@ #include "dispatch_utils.h" -#include -#include +#include "vendors/functions.h" -#ifndef USE_ROCM - #include -#else - #include -#endif namespace vllm { diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh index 2cc235f..6145911 100644 --- a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh @@ -2,7 +2,7 @@ // clang-format will break include orders // clang-format off -#include +#include "../../vendors/functions.h" #if defined CUDA_VERSION && CUDA_VERSION >= 12020 #include "sparse_scaled_mm_c3x.cuh" diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index d053ecc..cc7e17d 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -1,6 +1,6 @@ // clang-format will break include orders // clang-format off -#include +#include "../../vendors/functions.h" #if defined CUDA_VERSION && CUDA_VERSION >= 12020 #include "sparse_scaled_mm_c3x.cuh" diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index 637bba1..3bd3c7a 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -2,11 +2,8 @@ // clang-format will break include orders // clang-format off -#include +#include "../../vendors/functions.h" -#include - -#include #include "cuda_utils.h" diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu index 38b929b..793ac17 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu @@ -1,7 +1,5 @@ -#include +#include "../../vendors/functions.h" -#include -#include #include "cutlass_extensions/common.hpp" diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 2678f69..323002c 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -1,17 +1,7 @@ #pragma once -#include +#include "vendors/functions.h" -#ifndef USE_ROCM - #include - #include -#else - #include - #include - -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; -#endif namespace vllm { /* Converter structs for the conversion from torch types to HIP/CUDA types, diff --git a/csrc/vendors/cuda.h b/csrc/vendors/cuda.h new file mode 100644 index 0000000..d904bc8 --- /dev/null +++ b/csrc/vendors/cuda.h @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if CUDART_VERSION >= 12050 +#include +#endif // CUDART_VERSION >= 12050 + +#if CUDART_VERSION < 11020 +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t +#endif // CUDART_VERSION < 11020 + +#if CUB_VERSION >= 200800 + #include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; + #else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; +#endif // CUB_VERSION \ No newline at end of file diff --git a/csrc/vendors/functions.h b/csrc/vendors/functions.h new file mode 100644 index 0000000..b994ce5 --- /dev/null +++ b/csrc/vendors/functions.h @@ -0,0 +1,9 @@ +#ifdef USE_MUSA +#include "musa.h" +#elif USE_HIP +#include "hip.h" +#elif USE_CUDA +#include "cuda.h" +#else +"No Support" +#endif diff --git a/csrc/vendors/hip.h b/csrc/vendors/hip.h new file mode 100644 index 0000000..94eaf0d --- /dev/null +++ b/csrc/vendors/hip.h @@ -0,0 +1,307 @@ +#pragma once +#include + +#define HIP_DISABLE_WARP_SYNC_BUILTINS 1 +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +namespace cub = hipcub; + +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; + +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_16BF HIPBLAS_R_16B +#define CUDA_R_32F HIPBLAS_R_32F +#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define __all_sync(mask, var) __all(var) +#define __any_sync(mask, var) __any(var) +#define cublasStrsmBatched hipblasStrsmBatched +#define cublasCreate hipblasCreate +#define cublasDestroy hipblasDestroy +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaHostRegister hipHostRegister +#define cudaHostRegisterPortable hipHostRegisterPortable +#define cudaHostRegisterReadOnly hipHostRegisterReadOnly +#define cudaHostUnregister hipHostUnregister +#define cudaLaunchHostFunc hipLaunchHostFunc +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMallocManaged hipMallocManaged +#define cudaMemAdvise hipMemAdvise +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemGetInfo hipMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamPerThread hipStreamPerThread +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResult hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#define __trap() do { abort(); __builtin_unreachable(); } while(0) +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED + + +#define __ldg(arg) *(arg) + + +#if HIP_VERSION >= 60500000 +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F +#define cublasComputeType_t hipblasComputeType_t +#define cudaDataType_t hipDataType +#else +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define cublasComputeType_t hipblasDatatype_t +#define cudaDataType_t hipblasDatatype_t +#endif // HIP_VERSION >= 6050000 + +#if !defined(__HIP_PLATFORM_AMD__) +#error "The HIP backend supports only AMD targets" +#endif // !defined(__HIP_PLATFORM_AMD__) + +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx900__) || defined(__gfx906__) +#define GCN5 +#endif // defined(__gfx900__) || defined(__gfx906__) + +#if defined(__gfx803__) +#define GCN4 +#endif // defined(__gfx803__) + +#if defined(GCN5) || defined(GCN4) +#define GCN +#endif // defined(GCN5) || defined(GCN4) + +#if defined(__gfx942__) +#define CDNA3 +#endif // defined(__gfx942__) + +#if defined(__gfx90a__) +#define CDNA2 +#endif // defined(__gfx90a__) + +#if defined(__gfx908__) +#define CDNA1 +#endif // defined(__gfx908__) + +#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#define CDNA // For the entire family +#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) + +#if defined(__GFX12__) +#define RDNA4 +#endif // defined(__GFX12__) + +#if defined(__GFX11__) +#define RDNA3 +#endif // defined(__GFX11__) + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif // defined(__gfx1010__) || defined(__gfx1012__) + +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) +#define RDNA // For the entire family +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __vsub4(const int a, const int b) { + return __vsubss4(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } + return c; +} + + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16_raw __nv_bfloat16_raw; + #if defined(HIP_FP8_TYPE_OCP) +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; + #else +// ROCm 6.2 fallback: only *_fnuz types exist +typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3; + #include +namespace cub = hipcub; +using CubAddOp = hipcub::Sum; +using CubMaxOp = hipcub::Max; \ No newline at end of file diff --git a/csrc/vendors/musa.h b/csrc/vendors/musa.h new file mode 100644 index 0000000..9f1c6ac --- /dev/null +++ b/csrc/vendors/musa.h @@ -0,0 +1,181 @@ +// All header files + +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; + +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH +#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_16BF MUSA_R_16BF +#define CUDA_R_32F MUSA_R_32F +#define cublasStrsmBatched mublasStrsmBatched +#define cublasComputeType_t cudaDataType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t +#define cublasGetStatusString mublasGetStatusString +#define cudaDataType_t musaDataType_t +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// Additional mappings for MUSA virtual memory pool +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// Additional mappings for MUSA graphs +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResult musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamBeginCapture musaStreamBeginCapture +#define cudaStreamEndCapture musaStreamEndCapture +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor +#define __ldg(arg) *(arg) +typedef __mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat16 __nv_bfloat16; +typedef __mt_bfloat162 nv_bfloat162; +typedef __mt_bfloat162 __nv_bfloat162; +typedef __mt_bfloat162 __nv_bfloat162; +typedef __mt_bfloat16 __nv_bfloat16; +typedef __mt_bfloat16_raw __nv_bfloat16_raw; +typedef __mt_fp8_e4m3 __nv_fp8_e4m3; +typedef __mt_fp8x4_e4m3 __nv_fp8x4_e4m3;