diff --git a/sgl-kernel/csrc/spatial/cuda_utils.h b/sgl-kernel/csrc/spatial/cuda_utils.h index 126ed05d8..65132975f 100644 --- a/sgl-kernel/csrc/spatial/cuda_utils.h +++ b/sgl-kernel/csrc/spatial/cuda_utils.h @@ -1,6 +1,8 @@ #include #include +#include + #define CUDA_RT(call) \ do { \ cudaError_t _status = (call); \ diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index cf3e7da65..0a5e8cea8 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -1,13 +1,20 @@ +// Documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html #include #include -#include -#include #include "cuda_utils.h" #include "greenctx_stream.h" -#if CUDA_VERSION >= 12040 +static int CUDA_DRIVER_VERSION; + +using PFN_cuGreenCtxStreamCreate = CUresult(CUDAAPI*)(CUstream*, CUgreenCtx, unsigned int, int); + +auto probe_cuGreenCtxStreamCreate() -> PFN_cuGreenCtxStreamCreate { + static PFN_cuGreenCtxStreamCreate pfn = nullptr; + CUDA_DRV(cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast(&pfn), CUDA_DRIVER_VERSION, 0, nullptr)); + return pfn; +} static std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { CUstream streamA, streamB; @@ -26,18 +33,15 @@ static std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) return {(int64_t)streamA, (int64_t)streamB}; } -typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int); +inline void destroy_green_context(CUgreenCtx gctx) { + if (!gctx) return; + CUDA_DRV(cuGreenCtxDestroy(gctx)); +} static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) { - static PFN_cuGreenCtxStreamCreate pfn = nullptr; - static std::once_flag pfn_probed_flag; - - // detect compatibility in runtime - std::call_once(pfn_probed_flag, []() { - cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast(&pfn), 0, 0, nullptr); - }); - - if (!pfn) { // fallback if not compatible + // This symbol is introduced in CUDA 12.5 + const static auto pfn = probe_cuGreenCtxStreamCreate(); + if (!pfn) { return create_greenctx_stream_fallback(gctx); } @@ -48,12 +52,12 @@ static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gct return {(int64_t)streamA, (int64_t)streamB}; } -inline void destroy_green_context(int64_t h) { - if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast(h))); -} - std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { - TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); + CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION)); + + if (CUDA_DRIVER_VERSION < 12040) { + TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); + } CUgreenCtx gctx[3]; CUdevResourceDesc desc[3]; @@ -65,8 +69,8 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); - const unsigned minCount = smA + smB; - const unsigned minCountA = smA; + const unsigned minCount = static_cast(smA + smB); + const unsigned minCountA = static_cast(smA); TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); unsigned nbGroups = 1; @@ -86,7 +90,7 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i std::vector streams = create_greenctx_stream_direct_dynamic(gctx); - CUDA_DRV(cuGreenCtxDestroy(gctx[2])); + destroy_green_context(gctx[2]); std::vector vec = { streams[0], // streamA @@ -96,18 +100,3 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i return vec; } - -#else - -std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { - TORCH_CHECK( - false, - "Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: " + - std::to_string(CUDA_VERSION)); - - // This is a stub function that should never be reached - // Return empty vector to satisfy return type requirement - return {}; -} - -#endif diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py index 2960d3419..5fcbd6a9c 100644 --- a/sgl-kernel/python/sgl_kernel/utils.py +++ b/sgl-kernel/python/sgl_kernel/utils.py @@ -14,6 +14,7 @@ # ============================================================================== import functools +import subprocess from typing import Dict, Tuple import torch