From 0c9174108abe50bd65a30302d2b89bf86ab9cf3a Mon Sep 17 00:00:00 2001 From: Kangyan-Zhou Date: Sun, 28 Sep 2025 19:48:28 -0700 Subject: [PATCH] Unify SGL Kernel Releases (#10701) --- .github/workflows/pr-test.yml | 30 ++-- sgl-kernel/CMakeLists.txt | 62 ++++++-- sgl-kernel/python/sgl_kernel/__init__.py | 179 ++++++++++++++++++++++- 3 files changed, 241 insertions(+), 30 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6e8a041a1..bc7d48594 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -66,10 +66,6 @@ jobs: strategy: matrix: include: - - python-version: "3.10" - cuda-version: "12.4" - - python-version: "3.10" - cuda-version: "12.8" - python-version: "3.10" cuda-version: "12.9" name: Build Wheel (CUDA ${{ matrix.cuda-version }}) @@ -176,7 +172,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -209,7 +205,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -242,7 +238,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -275,7 +271,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -308,7 +304,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -337,7 +333,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -398,7 +394,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -451,7 +447,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -510,7 +506,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -542,7 +538,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -574,7 +570,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -603,7 +599,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.4 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | @@ -634,7 +630,7 @@ jobs: with: path: sgl-kernel/dist/ merge-multiple: true - pattern: wheel-python3.10-cuda12.8 + pattern: wheel-python3.10-cuda12.9 - name: Install dependencies run: | diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 5a5c3ef39..ea39e239a 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -239,14 +239,9 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) "-gencode=arch=compute_101a,code=sm_101a" ) endif() - -else() - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-use_fast_math" - ) endif() -if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A) +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4") set(SGL_KERNEL_ENABLE_FA3 ON) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_90a,code=sm_90a" @@ -334,14 +329,47 @@ set(SOURCES "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp" ) -Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) +# Build SM90 library with fast math optimization (same namespace, different directory) +Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) -target_compile_options(common_ops PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}>) -target_include_directories(common_ops PRIVATE +target_compile_definitions(common_ops_sm90_build PRIVATE + USE_FAST_MATH=1 +) +target_compile_options(common_ops_sm90_build PRIVATE + $<$:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math> +) +target_include_directories(common_ops_sm90_build PRIVATE + ${PROJECT_SOURCE_DIR}/csrc ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha ${repo-cutlass_SOURCE_DIR}/examples/common ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src ) +# Set output name and separate build directory to avoid conflicts +set_target_properties(common_ops_sm90_build PROPERTIES + OUTPUT_NAME "common_ops" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90" +) + +# Build SM100+ library with precise math (same namespace, different directory) +Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) + +target_compile_definitions(common_ops_sm100_build PRIVATE + USE_FAST_MATH=0 +) +target_compile_options(common_ops_sm100_build PRIVATE + $<$:${SGL_KERNEL_CUDA_FLAGS}> +) +target_include_directories(common_ops_sm100_build PRIVATE + ${PROJECT_SOURCE_DIR}/csrc + ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha + ${repo-cutlass_SOURCE_DIR}/examples/common + ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src +) +# Set output name and separate build directory to avoid conflicts +set_target_properties(common_ops_sm100_build PROPERTIES + OUTPUT_NAME "common_ops" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm100" +) find_package(Python3 COMPONENTS Interpreter REQUIRED) execute_process( @@ -367,16 +395,26 @@ add_subdirectory( ${repo-mscclpp_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build ) -target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) +target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) +target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) # flash attention -target_compile_definitions(common_ops PRIVATE +target_compile_definitions(common_ops_sm90_build PRIVATE + FLASHATTENTION_DISABLE_BACKWARD + FLASHATTENTION_DISABLE_DROPOUT + FLASHATTENTION_DISABLE_UNEVEN_K +) +target_compile_definitions(common_ops_sm100_build PRIVATE FLASHATTENTION_DISABLE_BACKWARD FLASHATTENTION_DISABLE_DROPOUT FLASHATTENTION_DISABLE_UNEVEN_K ) -install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel) +# Install to different subdirectories +# CMake will find the built libraries in their respective LIBRARY_OUTPUT_DIRECTORY locations +# and install them to the specified destinations +install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90) +install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100) # ============================ Optional Install ============================= # # set flash-attention sources file diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 49a97bccc..d077fc3fb 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -1,4 +1,5 @@ import ctypes +import logging import os import platform import shutil @@ -6,6 +7,183 @@ from pathlib import Path import torch +logger = logging.getLogger(__name__) + + +def _get_compute_capability(): + """Get the compute capability of the current GPU.""" + if not torch.cuda.is_available(): + return None + + # Get the current device + device = torch.cuda.current_device() + properties = torch.cuda.get_device_properties(device) + + # Return as integer (major * 10 + minor) + return properties.major * 10 + properties.minor + + +def _filter_compiled_extensions(file_list): + """Filter and prioritize compiled extensions over Python source files.""" + compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes + compiled_files = [] + other_files = [] + + for file_path in file_list: + path = Path(file_path) + # Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so) + if any( + str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions + ): + compiled_files.append(file_path) + else: + other_files.append(file_path) + + # Return compiled files first, then others + return compiled_files + other_files + + +def _load_architecture_specific_ops(): + """Load the appropriate common_ops library based on GPU architecture.""" + import importlib.util + import sys + from pathlib import Path + + compute_capability = _get_compute_capability() + logger.debug( + f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}" + ) + + # Get the directory where sgl_kernel is installed + sgl_kernel_dir = Path(__file__).parent + logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}") + + # Determine which version to load based on GPU architecture + if compute_capability == 90: + ops_subdir = "sm90" + variant_name = "SM90 (Hopper/H100 with fast math optimization)" + elif compute_capability is not None: + ops_subdir = "sm100" + variant_name = f"SM{compute_capability} (precise math for compatibility)" + else: + ops_subdir = "sm100" + variant_name = "CPU/No GPU detected (using precise math)" + + # Look for the compiled module with any valid extension + import glob + + ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*") + raw_matching_files = glob.glob(ops_pattern) + matching_files = _filter_compiled_extensions(raw_matching_files) + + logger.debug(f"[sgl_kernel] Attempting to load {variant_name}") + logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}") + logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}") + logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}") + + # Try to load from the architecture-specific directory + if matching_files: + ops_path = Path(matching_files[0]) # Use the first prioritized file + logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}") + try: + # Load the module from specific path using importlib + spec = importlib.util.spec_from_file_location("common_ops", str(ops_path)) + if spec is None: + raise ImportError(f"Could not create module spec for {ops_path}") + + common_ops = importlib.util.module_from_spec(spec) + if spec.loader is None: + raise ImportError(f"Module spec has no loader for {ops_path}") + + logger.debug(f"[sgl_kernel] Loading module from {ops_path}...") + spec.loader.exec_module(common_ops) + logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}") + logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}") + return common_ops + + except Exception as e: + logger.debug( + f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}" + ) + # Continue to fallback + else: + logger.debug( + f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}" + ) + + # Try alternative directory (in case installation structure differs) + alt_pattern = str(sgl_kernel_dir / "common_ops.*") + raw_alt_files = glob.glob(alt_pattern) + alt_matching_files = _filter_compiled_extensions(raw_alt_files) + logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}") + logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}") + logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}") + + if alt_matching_files: + alt_path = Path(alt_matching_files[0]) # Use the first prioritized file + logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}") + try: + spec = importlib.util.spec_from_file_location("common_ops", str(alt_path)) + if spec is None: + raise ImportError(f"Could not create module spec for {alt_path}") + + common_ops = importlib.util.module_from_spec(spec) + if spec.loader is None: + raise ImportError(f"Module spec has no loader for {alt_path}") + + logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...") + spec.loader.exec_module(common_ops) + logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library") + logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}") + return common_ops + + except Exception as e: + logger.debug( + f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}" + ) + else: + logger.debug( + f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}" + ) + + # Final attempt: try standard Python import (for backward compatibility) + logger.debug( + f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'" + ) + try: + import common_ops + + logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import") + logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}") + return common_ops + except ImportError as e: + logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}") + + # All attempts failed + error_msg = f""" +[sgl_kernel] CRITICAL: Could not load any common_ops library! + +Attempted locations: +1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files} +2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files} +3. Standard Python import: common_ops - failed + +GPU Info: +- Compute capability: {compute_capability} +- Expected variant: {variant_name} + +Please ensure sgl_kernel is properly installed with: +pip install --upgrade sgl_kernel +""" + logger.debug(error_msg) + raise ImportError(error_msg) + + +# Initialize the ops library based on current GPU +logger.debug("[sgl_kernel] Initializing architecture-specific operator library...") +common_ops = _load_architecture_specific_ops() +logger.debug("[sgl_kernel] ✓ Operator library initialization complete") + # copy & modify from torch/utils/cpp_extension.py def _find_cuda_home(): @@ -42,7 +220,6 @@ if torch.version.cuda is not None: if cuda_include.exists(): ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL) -from sgl_kernel import common_ops from sgl_kernel.allreduce import * from sgl_kernel.attention import ( cutlass_mla_decode,