Unify SGL Kernel Releases (#10701)
This commit is contained in:
30
.github/workflows/pr-test.yml
vendored
30
.github/workflows/pr-test.yml
vendored
@@ -66,10 +66,6 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- python-version: "3.10"
|
|
||||||
cuda-version: "12.4"
|
|
||||||
- python-version: "3.10"
|
|
||||||
cuda-version: "12.8"
|
|
||||||
- python-version: "3.10"
|
- python-version: "3.10"
|
||||||
cuda-version: "12.9"
|
cuda-version: "12.9"
|
||||||
name: Build Wheel (CUDA ${{ matrix.cuda-version }})
|
name: Build Wheel (CUDA ${{ matrix.cuda-version }})
|
||||||
@@ -176,7 +172,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -209,7 +205,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -242,7 +238,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -275,7 +271,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -308,7 +304,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -337,7 +333,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -398,7 +394,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -451,7 +447,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -510,7 +506,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -542,7 +538,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -574,7 +570,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -603,7 +599,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.4
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -634,7 +630,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
path: sgl-kernel/dist/
|
path: sgl-kernel/dist/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
pattern: wheel-python3.10-cuda12.8
|
pattern: wheel-python3.10-cuda12.9
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -239,14 +239,9 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
|||||||
"-gencode=arch=compute_101a,code=sm_101a"
|
"-gencode=arch=compute_101a,code=sm_101a"
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
|
||||||
"-use_fast_math"
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
|
||||||
set(SGL_KERNEL_ENABLE_FA3 ON)
|
set(SGL_KERNEL_ENABLE_FA3 ON)
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
"-gencode=arch=compute_90a,code=sm_90a"
|
"-gencode=arch=compute_90a,code=sm_90a"
|
||||||
@@ -334,14 +329,47 @@ set(SOURCES
|
|||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
# Build SM90 library with fast math optimization (same namespace, different directory)
|
||||||
|
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||||
|
|
||||||
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
target_compile_definitions(common_ops_sm90_build PRIVATE
|
||||||
target_include_directories(common_ops PRIVATE
|
USE_FAST_MATH=1
|
||||||
|
)
|
||||||
|
target_compile_options(common_ops_sm90_build PRIVATE
|
||||||
|
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
|
||||||
|
)
|
||||||
|
target_include_directories(common_ops_sm90_build PRIVATE
|
||||||
|
${PROJECT_SOURCE_DIR}/csrc
|
||||||
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
||||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||||
)
|
)
|
||||||
|
# Set output name and separate build directory to avoid conflicts
|
||||||
|
set_target_properties(common_ops_sm90_build PROPERTIES
|
||||||
|
OUTPUT_NAME "common_ops"
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build SM100+ library with precise math (same namespace, different directory)
|
||||||
|
Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||||
|
|
||||||
|
target_compile_definitions(common_ops_sm100_build PRIVATE
|
||||||
|
USE_FAST_MATH=0
|
||||||
|
)
|
||||||
|
target_compile_options(common_ops_sm100_build PRIVATE
|
||||||
|
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
|
||||||
|
)
|
||||||
|
target_include_directories(common_ops_sm100_build PRIVATE
|
||||||
|
${PROJECT_SOURCE_DIR}/csrc
|
||||||
|
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
||||||
|
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||||
|
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||||
|
)
|
||||||
|
# Set output name and separate build directory to avoid conflicts
|
||||||
|
set_target_properties(common_ops_sm100_build PROPERTIES
|
||||||
|
OUTPUT_NAME "common_ops"
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm100"
|
||||||
|
)
|
||||||
|
|
||||||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||||
execute_process(
|
execute_process(
|
||||||
@@ -367,16 +395,26 @@ add_subdirectory(
|
|||||||
${repo-mscclpp_SOURCE_DIR}
|
${repo-mscclpp_SOURCE_DIR}
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
|
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
|
||||||
)
|
)
|
||||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||||
|
target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||||
|
|
||||||
# flash attention
|
# flash attention
|
||||||
target_compile_definitions(common_ops PRIVATE
|
target_compile_definitions(common_ops_sm90_build PRIVATE
|
||||||
|
FLASHATTENTION_DISABLE_BACKWARD
|
||||||
|
FLASHATTENTION_DISABLE_DROPOUT
|
||||||
|
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||||
|
)
|
||||||
|
target_compile_definitions(common_ops_sm100_build PRIVATE
|
||||||
FLASHATTENTION_DISABLE_BACKWARD
|
FLASHATTENTION_DISABLE_BACKWARD
|
||||||
FLASHATTENTION_DISABLE_DROPOUT
|
FLASHATTENTION_DISABLE_DROPOUT
|
||||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||||
)
|
)
|
||||||
|
|
||||||
install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
|
# Install to different subdirectories
|
||||||
|
# CMake will find the built libraries in their respective LIBRARY_OUTPUT_DIRECTORY locations
|
||||||
|
# and install them to the specified destinations
|
||||||
|
install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90)
|
||||||
|
install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100)
|
||||||
|
|
||||||
# ============================ Optional Install ============================= #
|
# ============================ Optional Install ============================= #
|
||||||
# set flash-attention sources file
|
# set flash-attention sources file
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
@@ -6,6 +7,183 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_compute_capability():
|
||||||
|
"""Get the compute capability of the current GPU."""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get the current device
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
properties = torch.cuda.get_device_properties(device)
|
||||||
|
|
||||||
|
# Return as integer (major * 10 + minor)
|
||||||
|
return properties.major * 10 + properties.minor
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_compiled_extensions(file_list):
|
||||||
|
"""Filter and prioritize compiled extensions over Python source files."""
|
||||||
|
compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes
|
||||||
|
compiled_files = []
|
||||||
|
other_files = []
|
||||||
|
|
||||||
|
for file_path in file_list:
|
||||||
|
path = Path(file_path)
|
||||||
|
# Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
|
||||||
|
if any(
|
||||||
|
str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions
|
||||||
|
):
|
||||||
|
compiled_files.append(file_path)
|
||||||
|
else:
|
||||||
|
other_files.append(file_path)
|
||||||
|
|
||||||
|
# Return compiled files first, then others
|
||||||
|
return compiled_files + other_files
|
||||||
|
|
||||||
|
|
||||||
|
def _load_architecture_specific_ops():
|
||||||
|
"""Load the appropriate common_ops library based on GPU architecture."""
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
compute_capability = _get_compute_capability()
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the directory where sgl_kernel is installed
|
||||||
|
sgl_kernel_dir = Path(__file__).parent
|
||||||
|
logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
|
||||||
|
|
||||||
|
# Determine which version to load based on GPU architecture
|
||||||
|
if compute_capability == 90:
|
||||||
|
ops_subdir = "sm90"
|
||||||
|
variant_name = "SM90 (Hopper/H100 with fast math optimization)"
|
||||||
|
elif compute_capability is not None:
|
||||||
|
ops_subdir = "sm100"
|
||||||
|
variant_name = f"SM{compute_capability} (precise math for compatibility)"
|
||||||
|
else:
|
||||||
|
ops_subdir = "sm100"
|
||||||
|
variant_name = "CPU/No GPU detected (using precise math)"
|
||||||
|
|
||||||
|
# Look for the compiled module with any valid extension
|
||||||
|
import glob
|
||||||
|
|
||||||
|
ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
|
||||||
|
raw_matching_files = glob.glob(ops_pattern)
|
||||||
|
matching_files = _filter_compiled_extensions(raw_matching_files)
|
||||||
|
|
||||||
|
logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
|
||||||
|
logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
|
||||||
|
logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
|
||||||
|
logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
|
||||||
|
|
||||||
|
# Try to load from the architecture-specific directory
|
||||||
|
if matching_files:
|
||||||
|
ops_path = Path(matching_files[0]) # Use the first prioritized file
|
||||||
|
logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
|
||||||
|
try:
|
||||||
|
# Load the module from specific path using importlib
|
||||||
|
spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Could not create module spec for {ops_path}")
|
||||||
|
|
||||||
|
common_ops = importlib.util.module_from_spec(spec)
|
||||||
|
if spec.loader is None:
|
||||||
|
raise ImportError(f"Module spec has no loader for {ops_path}")
|
||||||
|
|
||||||
|
logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
|
||||||
|
spec.loader.exec_module(common_ops)
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
|
||||||
|
return common_ops
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
# Continue to fallback
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try alternative directory (in case installation structure differs)
|
||||||
|
alt_pattern = str(sgl_kernel_dir / "common_ops.*")
|
||||||
|
raw_alt_files = glob.glob(alt_pattern)
|
||||||
|
alt_matching_files = _filter_compiled_extensions(raw_alt_files)
|
||||||
|
logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
|
||||||
|
logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
|
||||||
|
logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
|
||||||
|
|
||||||
|
if alt_matching_files:
|
||||||
|
alt_path = Path(alt_matching_files[0]) # Use the first prioritized file
|
||||||
|
logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
|
||||||
|
try:
|
||||||
|
spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Could not create module spec for {alt_path}")
|
||||||
|
|
||||||
|
common_ops = importlib.util.module_from_spec(spec)
|
||||||
|
if spec.loader is None:
|
||||||
|
raise ImportError(f"Module spec has no loader for {alt_path}")
|
||||||
|
|
||||||
|
logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
|
||||||
|
spec.loader.exec_module(common_ops)
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library")
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
|
||||||
|
return common_ops
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final attempt: try standard Python import (for backward compatibility)
|
||||||
|
logger.debug(
|
||||||
|
f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import common_ops
|
||||||
|
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import")
|
||||||
|
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
|
||||||
|
return common_ops
|
||||||
|
except ImportError as e:
|
||||||
|
logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
|
||||||
|
|
||||||
|
# All attempts failed
|
||||||
|
error_msg = f"""
|
||||||
|
[sgl_kernel] CRITICAL: Could not load any common_ops library!
|
||||||
|
|
||||||
|
Attempted locations:
|
||||||
|
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
|
||||||
|
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
|
||||||
|
3. Standard Python import: common_ops - failed
|
||||||
|
|
||||||
|
GPU Info:
|
||||||
|
- Compute capability: {compute_capability}
|
||||||
|
- Expected variant: {variant_name}
|
||||||
|
|
||||||
|
Please ensure sgl_kernel is properly installed with:
|
||||||
|
pip install --upgrade sgl_kernel
|
||||||
|
"""
|
||||||
|
logger.debug(error_msg)
|
||||||
|
raise ImportError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the ops library based on current GPU
|
||||||
|
logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
|
||||||
|
common_ops = _load_architecture_specific_ops()
|
||||||
|
logger.debug("[sgl_kernel] ✓ Operator library initialization complete")
|
||||||
|
|
||||||
|
|
||||||
# copy & modify from torch/utils/cpp_extension.py
|
# copy & modify from torch/utils/cpp_extension.py
|
||||||
def _find_cuda_home():
|
def _find_cuda_home():
|
||||||
@@ -42,7 +220,6 @@ if torch.version.cuda is not None:
|
|||||||
if cuda_include.exists():
|
if cuda_include.exists():
|
||||||
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
|
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
|
||||||
|
|
||||||
from sgl_kernel import common_ops
|
|
||||||
from sgl_kernel.allreduce import *
|
from sgl_kernel.allreduce import *
|
||||||
from sgl_kernel.attention import (
|
from sgl_kernel.attention import (
|
||||||
cutlass_mla_decode,
|
cutlass_mla_decode,
|
||||||
|
|||||||
Reference in New Issue
Block a user