diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index df97582c1..89e0a5918 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -274,7 +274,6 @@ set(SOURCES "csrc/kvcacheio/transfer.cu" "csrc/speculative/eagle_utils.cu" "csrc/speculative/packbit.cu" - "csrc/spatial/greenctx_stream.cu" "csrc/speculative/speculative_sampling.cu" "csrc/memory/store.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" @@ -417,6 +416,18 @@ if (SGL_KERNEL_ENABLE_FA3) target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS}) endif() +# Build spatial_ops as a separate, optional extension for green contexts +set(SPATIAL_SOURCES + "csrc/spatial/greenctx_stream.cu" + "csrc/spatial_extension.cc" +) + +Python_add_library(spatial_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SPATIAL_SOURCES}) +target_compile_options(spatial_ops PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}>) +target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) +install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel) + + # ============================ DeepGEMM (JIT) ============================= # # Create a separate library for DeepGEMM's Python API. # This keeps its compilation isolated from the main common_ops. diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 1915f176e..d11fe5b3a 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -433,12 +433,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, " "Tensor _ascales, Tensor! _out_feats) -> ()"); m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm); - - /* - * From csrc/spatial - */ - m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"); - m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index 0a5e8cea8..0366565ef 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -42,6 +42,7 @@ static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gct // This symbol is introduced in CUDA 12.5 const static auto pfn = probe_cuGreenCtxStreamCreate(); if (!pfn) { + TORCH_WARN("cuGreenCtxStreamCreate(cuda>=12.5) is not available, using fallback"); return create_greenctx_stream_fallback(gctx); } @@ -55,17 +56,12 @@ static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gct std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION)); - if (CUDA_DRIVER_VERSION < 12040) { - TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); - } - CUgreenCtx gctx[3]; CUdevResourceDesc desc[3]; CUdevResource input; CUdevResource resources[4]; - if (smA <= 0 || smB <= 0) { - TORCH_CHECK(false, "SM counts must be positive"); - } + + TORCH_CHECK(smA > 0 && smB > 0, "SM counts must be positive"); CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); diff --git a/sgl-kernel/csrc/spatial_extension.cc b/sgl-kernel/csrc/spatial_extension.cc new file mode 100644 index 000000000..27833cd70 --- /dev/null +++ b/sgl-kernel/csrc/spatial_extension.cc @@ -0,0 +1,29 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "sgl_kernel_ops.h" + +TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { + /* + * From csrc/spatial + */ + m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"); + m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value); +} + +REGISTER_EXTENSION(spatial_ops) diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index ee7c36541..d3099ba63 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -92,7 +92,20 @@ from sgl_kernel.sampling import ( top_p_renorm_prob, top_p_sampling_from_probs, ) -from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available + + +def create_greenctx_stream_by_value(*args, **kwargs): + from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl + + return _impl(*args, **kwargs) + + +def get_sm_available(*args, **kwargs): + from sgl_kernel.spatial import get_sm_available as _impl + + return _impl(*args, **kwargs) + + from sgl_kernel.speculative import ( build_tree_kernel_efficient, segment_packbits, diff --git a/sgl-kernel/python/sgl_kernel/spatial.py b/sgl-kernel/python/sgl_kernel/spatial.py index 25490d253..201ec7b75 100644 --- a/sgl-kernel/python/sgl_kernel/spatial.py +++ b/sgl-kernel/python/sgl_kernel/spatial.py @@ -1,6 +1,17 @@ import torch from torch.cuda.streams import ExternalStream +try: + from . import spatial_ops # triggers TORCH extension registration +except Exception as _e: + _spatial_import_error = _e +else: + _spatial_import_error = None + +_IMPORT_ERROR = ImportError( + "Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4" +) + def create_greenctx_stream_by_value( SM_a: int, SM_b: int, device_id: int = None @@ -14,11 +25,8 @@ def create_greenctx_stream_by_value( Returns: tuple[ExternalStream, ExternalStream]: The two streams. """ - if torch.version.cuda < "12.4": - raise RuntimeError( - "Green Contexts feature requires CUDA Toolkit 12.4 or newer." - ) - + if _spatial_import_error is not None: + raise _IMPORT_ERROR from _spatial_import_error if device_id is None: device_id = torch.cuda.current_device() @@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int: Returns: int: The SMs available. """ + if _spatial_import_error is not None: + raise _IMPORT_ERROR from _spatial_import_error if device_id is None: device_id = torch.cuda.current_device()