Runtime check CUDA driver version to avoid unresolved green context symbols (#9021)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#define CUDA_RT(call) \
|
||||
do { \
|
||||
cudaError_t _status = (call); \
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
// Documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "greenctx_stream.h"
|
||||
|
||||
#if CUDA_VERSION >= 12040
|
||||
static int CUDA_DRIVER_VERSION;
|
||||
|
||||
using PFN_cuGreenCtxStreamCreate = CUresult(CUDAAPI*)(CUstream*, CUgreenCtx, unsigned int, int);
|
||||
|
||||
// Looks up the driver entry point `cuGreenCtxStreamCreate` at runtime rather
// than referencing the symbol directly.
//
// The query goes through cuGetProcAddress and passes the file-level
// CUDA_DRIVER_VERSION as the requested version, so that variable must already
// hold the driver version when this function runs.
//
// Returns the function pointer stored by the lookup; callers treat a null
// result as "entry point not available".
// NOTE(review): CUDA_DRV is defined outside this view — presumably it reports a
// non-success CUresult; confirm how a failing lookup surfaces.
auto probe_cuGreenCtxStreamCreate() -> PFN_cuGreenCtxStreamCreate {
  static PFN_cuGreenCtxStreamCreate resolved = nullptr;
  // cuGetProcAddress writes through a void**, hence the cast of the typed slot.
  void** slot = reinterpret_cast<void**>(&resolved);
  CUDA_DRV(cuGetProcAddress("cuGreenCtxStreamCreate", slot, CUDA_DRIVER_VERSION, 0, nullptr));
  return resolved;
}
|
||||
|
||||
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
|
||||
CUstream streamA, streamB;
|
||||
@@ -26,18 +33,15 @@ static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2])
|
||||
return {(int64_t)streamA, (int64_t)streamB};
|
||||
}
|
||||
|
||||
typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int);
|
||||
// Releases a green context through cuGreenCtxDestroy.
// A null handle is a no-op, so the driver is only called for non-null contexts.
inline void destroy_green_context(CUgreenCtx gctx) {
  if (gctx) {
    CUDA_DRV(cuGreenCtxDestroy(gctx));
  }
}
|
||||
|
||||
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
|
||||
static PFN_cuGreenCtxStreamCreate pfn = nullptr;
|
||||
static std::once_flag pfn_probed_flag;
|
||||
|
||||
// detect compatibility in runtime
|
||||
std::call_once(pfn_probed_flag, []() {
|
||||
cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
|
||||
});
|
||||
|
||||
if (!pfn) { // fallback if not compatible
|
||||
// This symbol is introduced in CUDA 12.5
|
||||
const static auto pfn = probe_cuGreenCtxStreamCreate();
|
||||
if (!pfn) {
|
||||
return create_greenctx_stream_fallback(gctx);
|
||||
}
|
||||
|
||||
@@ -48,12 +52,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
|
||||
return {(int64_t)streamA, (int64_t)streamB};
|
||||
}
|
||||
|
||||
inline void destroy_green_context(int64_t h) {
|
||||
if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
|
||||
}
|
||||
|
||||
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
|
||||
TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
|
||||
CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
|
||||
|
||||
if (CUDA_DRIVER_VERSION < 12040) {
|
||||
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
|
||||
}
|
||||
|
||||
CUgreenCtx gctx[3];
|
||||
CUdevResourceDesc desc[3];
|
||||
@@ -65,8 +69,8 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
||||
|
||||
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
const unsigned minCount = smA + smB;
|
||||
const unsigned minCountA = smA;
|
||||
const unsigned minCount = static_cast<unsigned>(smA + smB);
|
||||
const unsigned minCountA = static_cast<unsigned>(smA);
|
||||
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
|
||||
|
||||
unsigned nbGroups = 1;
|
||||
@@ -86,7 +90,7 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
||||
|
||||
std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
|
||||
|
||||
CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
|
||||
destroy_green_context(gctx[2]);
|
||||
|
||||
std::vector<int64_t> vec = {
|
||||
streams[0], // streamA
|
||||
@@ -96,18 +100,3 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
||||
|
||||
return vec;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Stub compiled in the `#else` branch of the `CUDA_VERSION >= 12040` guard.
// It always fails a TORCH_CHECK whose message names the CUDA_VERSION the file
// was built against; the arguments are not used.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  const std::string reason =
      "Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: " +
      std::to_string(CUDA_VERSION);
  TORCH_CHECK(false, reason);

  // Not expected to be reached; an empty vector satisfies the return type.
  return {};
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user