From 719b29f218a09642193c4bda2a7ffa32829d5604 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Fri, 18 Jul 2025 17:45:03 +0800 Subject: [PATCH] feat: enhance robustness of green context stream creation with backward compatibility (#8136) --- sgl-kernel/csrc/spatial/greenctx_stream.cu | 59 ++++++++++++---------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index 8c2e6d813..9d7a44a1a 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -7,17 +7,15 @@ #include "cuda_utils.h" #include "greenctx_stream.h" -std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { +static std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { CUstream streamA, streamB; CUcontext ctx; - // Stream A CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0])); CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING)); CUDA_DRV(cuCtxPopCurrent(nullptr)); - // Stream B CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1])); CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING)); @@ -26,18 +24,31 @@ std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { return {(int64_t)streamA, (int64_t)streamB}; } -#if CUDA_VERSION >= 12050 -std::vector create_greenctx_stream_direct(CUgreenCtx gctx[2]) { - CUstream streamA; - CUstream streamB; +typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int); - CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); - CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); +static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) { + static PFN_cuGreenCtxStreamCreate pfn = nullptr; + static std::once_flag pfn_probed_flag; - std::vector vec = {(int64_t)streamA, (int64_t)streamB}; - return vec; + // detect compatibility in runtime + 
std::call_once(pfn_probed_flag, []() { + cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast(&pfn), 0, 0, nullptr); + }); + + if (!pfn) { // fallback if not compatible + return create_greenctx_stream_fallback(gctx); + } + + CUstream streamA, streamB; + CUDA_DRV(pfn(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); + CUDA_DRV(pfn(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); + + return {(int64_t)streamA, (int64_t)streamB}; +} + +inline void destroy_green_context(int64_t h) { + if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast(h))); } -#endif std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); @@ -46,42 +57,38 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i CUdevResourceDesc desc[3]; CUdevResource input; CUdevResource resources[4]; - unsigned int nbGroups = 1; - if (smA <= 0 || smB <= 0) { TORCH_CHECK(false, "SM counts must be positive"); } CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); - unsigned int minCount = (unsigned int)(smA + smB); - unsigned int minCountA = (unsigned int)(smA); + + const unsigned minCount = smA + smB; + const unsigned minCountA = smA; TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); + unsigned nbGroups = 1; CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); + nbGroups = 1; CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, 
CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); - int smCountA = resources[0].sm.smCount; - int smCountB = resources[1].sm.smCount; - std::vector stream_handles; + const int smCountA = resources[0].sm.smCount; + const int smCountB = resources[1].sm.smCount; -#if CUDA_VERSION >= 12050 - stream_handles = create_greenctx_stream_direct(gctx); -#else - stream_handles = create_greenctx_stream_fallback(gctx); -#endif + std::vector streams = create_greenctx_stream_direct_dynamic(gctx); CUDA_DRV(cuGreenCtxDestroy(gctx[2])); std::vector vec = { - stream_handles[0], // streamA - stream_handles[1], // streamB + streams[0], // streamA + streams[1], // streamB (int64_t)smCountA, (int64_t)smCountB};