From 6dc4af49377d25fb9d745c5dd14f13a04f9ffbdd Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Wed, 16 Jul 2025 22:08:46 +0800 Subject: [PATCH] fix greenctx stream compability (#8090) --- sgl-kernel/csrc/spatial/greenctx_stream.cu | 65 ++++++++++++++++------ 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index b549aea5f..8c2e6d813 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -7,52 +7,83 @@ #include "cuda_utils.h" #include "greenctx_stream.h" +std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { + CUstream streamA, streamB; + CUcontext ctx; + + // Stream A + CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0])); + CUDA_DRV(cuCtxPushCurrent(ctx)); + CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING)); + CUDA_DRV(cuCtxPopCurrent(nullptr)); + + // Stream B + CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1])); + CUDA_DRV(cuCtxPushCurrent(ctx)); + CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING)); + CUDA_DRV(cuCtxPopCurrent(nullptr)); + + return {(int64_t)streamA, (int64_t)streamB}; +} + +#if CUDA_VERSION >= 12050 +std::vector create_greenctx_stream_direct(CUgreenCtx gctx[2]) { + CUstream streamA; + CUstream streamB; + + CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); + CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); + + std::vector vec = {(int64_t)streamA, (int64_t)streamB}; + return vec; +} +#endif + std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { + TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); + CUgreenCtx gctx[3]; CUdevResourceDesc desc[3]; CUdevResource input; CUdevResource resources[4]; - CUstream streamA; - CUstream streamB; - unsigned int nbGroups = 1; if (smA <= 0 || smB <= 0) { TORCH_CHECK(false, "SM counts must be positive"); } - // Initialize device - CUDA_RT(cudaInitDevice(device, 0, 0)); - - // Query input SMs CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); - // We want 3/4 the device for our green context unsigned int minCount = (unsigned int)(smA + smB); unsigned int minCountA = (unsigned int)(smA); - TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); - // Split resources CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); - CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); - CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); - - CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); - CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); - int smCountA = resources[0].sm.smCount; int smCountB = resources[1].sm.smCount; + std::vector stream_handles; + +#if CUDA_VERSION >= 12050 + stream_handles = create_greenctx_stream_direct(gctx); +#else + stream_handles = create_greenctx_stream_fallback(gctx); +#endif + CUDA_DRV(cuGreenCtxDestroy(gctx[2])); - std::vector vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB}; + std::vector vec = { + stream_handles[0], // streamA + stream_handles[1], // streamB + (int64_t)smCountA, + (int64_t)smCountB}; + return vec; }