Unify SGL Kernel Releases (#10701)

This commit is contained in:
Kangyan-Zhou
2025-09-28 19:48:28 -07:00
committed by GitHub
parent 2572886367
commit 0c9174108a
3 changed files with 241 additions and 30 deletions

View File

@@ -1,4 +1,5 @@
import ctypes
import logging
import os
import platform
import shutil
@@ -6,6 +7,183 @@ from pathlib import Path
import torch
# Module-level logger named after this module; the library-loading helpers
# defined below report every step through it at DEBUG level.
logger = logging.getLogger(__name__)
def _get_compute_capability():
    """Return the current CUDA device's compute capability as one integer.

    The value is encoded as ``major * 10 + minor`` (so capability 9.0 becomes
    ``90``). ``None`` is returned when torch reports that CUDA is unavailable.
    """
    if torch.cuda.is_available():
        # Query the properties of whichever device torch currently selects.
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return 10 * props.major + props.minor
    return None
def _filter_compiled_extensions(file_list):
    """Order candidate files so compiled extension modules come first.

    A file counts as "compiled" when one of the dot-suffixes of its *file
    name* is ``.so``, ``.pyd`` or ``.dll``. This covers tagged names such as
    ``common_ops.abi3.so`` or ``common_ops.cpython-312-x86_64-linux-gnu.so``
    as well as versioned names such as ``common_ops.so.1``.

    Args:
        file_list: Iterable of file paths (strings or path-like objects).

    Returns:
        A new list with the compiled files first, followed by all other
        files. The relative order inside each group is preserved, and the
        original entries (not ``Path`` objects) are returned.
    """
    compiled_suffixes = {".so", ".pyd", ".dll"}
    compiled_files = []
    other_files = []
    for file_path in file_list:
        # Look only at the suffixes of the final path component. A plain
        # substring test on the full path would also match parent directories
        # (e.g. "/opt/.software/") or longer suffixes (e.g. ".source.py").
        suffixes = Path(file_path).suffixes
        if any(suffix in compiled_suffixes for suffix in suffixes):
            compiled_files.append(file_path)
        else:
            other_files.append(file_path)
    # Compiled files first, then everything else.
    return compiled_files + other_files
def _import_common_ops_from(lib_path, kind=""):
    """Import the ``common_ops`` module from an explicit file path.

    Args:
        lib_path: Path of the file to load.
        kind: Optional word (with trailing space) inserted into the debug
            message, e.g. ``"fallback "``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no module spec or no loader can be created for
            ``lib_path``. Any exception raised while executing the module is
            propagated unchanged.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("common_ops", str(lib_path))
    if spec is None:
        raise ImportError(f"Could not create module spec for {lib_path}")
    module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        raise ImportError(f"Module spec has no loader for {lib_path}")
    logger.debug(f"[sgl_kernel] Loading {kind}module from {lib_path}...")
    spec.loader.exec_module(module)
    return module


def _load_architecture_specific_ops():
    """Load the appropriate common_ops library based on GPU architecture.

    Three locations are tried in order, and the first that loads wins:

    1. ``<package dir>/<sm90|sm100>/common_ops.*`` -- ``sm90`` when the
       detected compute capability is exactly 90, ``sm100`` otherwise
       (including when no GPU is detected).
    2. ``<package dir>/common_ops.*`` as a fallback layout.
    3. A plain ``import common_ops``.

    For the two file-based locations only the first candidate is tried, with
    compiled extensions prioritized over other files.

    Returns:
        The loaded ``common_ops`` module.

    Raises:
        ImportError: If none of the locations yields a loadable module. The
            message lists the patterns searched and the reason each attempt
            failed; the last underlying exception is chained as the cause.
    """
    import glob

    compute_capability = _get_compute_capability()
    logger.debug(
        f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
    )

    # Get the directory where sgl_kernel is installed
    sgl_kernel_dir = Path(__file__).parent
    logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")

    # Determine which version to load based on GPU architecture
    if compute_capability == 90:
        ops_subdir = "sm90"
        variant_name = "SM90 (Hopper/H100 with fast math optimization)"
    elif compute_capability is not None:
        ops_subdir = "sm100"
        variant_name = f"SM{compute_capability} (precise math for compatibility)"
    else:
        ops_subdir = "sm100"
        variant_name = "CPU/No GPU detected (using precise math)"

    # Reasons for every failed attempt; surfaced in the final ImportError so
    # the cause is visible even when DEBUG logging is disabled.
    failures = []
    last_error = None

    # Look for the compiled module with any valid extension. glob returns
    # matches in arbitrary order, so sort to make the chosen file deterministic.
    ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
    raw_matching_files = sorted(glob.glob(ops_pattern))
    matching_files = _filter_compiled_extensions(raw_matching_files)
    logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
    logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
    logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
    logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")

    # Try to load from the architecture-specific directory
    if matching_files:
        ops_path = Path(matching_files[0])  # Use the first prioritized file
        logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
        try:
            common_ops = _import_common_ops_from(ops_path)
            logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            # Deliberately broad: any load failure falls through to the
            # next location instead of aborting.
            last_error = e
            failures.append(f"{ops_path}: {type(e).__name__}: {e}")
            logger.debug(
                f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
            )
    else:
        logger.debug(
            f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
        )

    # Try alternative directory (in case installation structure differs)
    alt_pattern = str(sgl_kernel_dir / "common_ops.*")
    raw_alt_files = sorted(glob.glob(alt_pattern))
    alt_matching_files = _filter_compiled_extensions(raw_alt_files)
    logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
    logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
    logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
    if alt_matching_files:
        alt_path = Path(alt_matching_files[0])  # Use the first prioritized file
        logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
        try:
            common_ops = _import_common_ops_from(alt_path, kind="fallback ")
            logger.debug("[sgl_kernel] ✓ Successfully loaded fallback library")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            last_error = e
            failures.append(f"{alt_path}: {type(e).__name__}: {e}")
            logger.debug(
                f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
            )
    else:
        logger.debug(
            f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
        )

    # Final attempt: try standard Python import (for backward compatibility)
    logger.debug(
        "[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
    )
    try:
        import common_ops

        logger.debug("[sgl_kernel] ✓ Successfully imported via standard Python import")
        logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
        return common_ops
    except ImportError as e:
        last_error = e
        failures.append(f"import common_ops: {type(e).__name__}: {e}")
        logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")

    # All attempts failed
    failure_details = "\n".join(f"- {reason}" for reason in failures)
    error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
Failure details:
{failure_details}
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
"""
    logger.debug(error_msg)
    raise ImportError(error_msg) from last_error
# Initialize the ops library based on current GPU.
# This runs as a side effect of importing this module: the loader is called
# once and its result is bound to the module-level name `common_ops`. The
# loader raises ImportError when no library can be loaded, so that error
# propagates out of this module's import.
logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
common_ops = _load_architecture_specific_ops()
logger.debug("[sgl_kernel] ✓ Operator library initialization complete")
# Copied and modified from torch/utils/cpp_extension.py
def _find_cuda_home():
@@ -42,7 +220,6 @@ if torch.version.cuda is not None:
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
cutlass_mla_decode,