Runtime check CUDA driver version to avoid unresolved green context symbols (#9021)
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#define CUDA_RT(call) \
|
#define CUDA_RT(call) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t _status = (call); \
|
cudaError_t _status = (call); \
|
||||||
|
|||||||
@@ -1,13 +1,20 @@
|
|||||||
|
// Documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <iomanip>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "greenctx_stream.h"
|
#include "greenctx_stream.h"
|
||||||
|
|
||||||
#if CUDA_VERSION >= 12040
|
static int CUDA_DRIVER_VERSION;
|
||||||
|
|
||||||
|
using PFN_cuGreenCtxStreamCreate = CUresult(CUDAAPI*)(CUstream*, CUgreenCtx, unsigned int, int);
|
||||||
|
|
||||||
|
auto probe_cuGreenCtxStreamCreate() -> PFN_cuGreenCtxStreamCreate {
|
||||||
|
static PFN_cuGreenCtxStreamCreate pfn = nullptr;
|
||||||
|
CUDA_DRV(cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), CUDA_DRIVER_VERSION, 0, nullptr));
|
||||||
|
return pfn;
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
|
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
|
||||||
CUstream streamA, streamB;
|
CUstream streamA, streamB;
|
||||||
@@ -26,18 +33,15 @@ static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2])
|
|||||||
return {(int64_t)streamA, (int64_t)streamB};
|
return {(int64_t)streamA, (int64_t)streamB};
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int);
|
inline void destroy_green_context(CUgreenCtx gctx) {
|
||||||
|
if (!gctx) return;
|
||||||
|
CUDA_DRV(cuGreenCtxDestroy(gctx));
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
|
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
|
||||||
static PFN_cuGreenCtxStreamCreate pfn = nullptr;
|
// This symbol is introduced in CUDA 12.5
|
||||||
static std::once_flag pfn_probed_flag;
|
const static auto pfn = probe_cuGreenCtxStreamCreate();
|
||||||
|
if (!pfn) {
|
||||||
// detect compatibility in runtime
|
|
||||||
std::call_once(pfn_probed_flag, []() {
|
|
||||||
cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!pfn) { // fallback if not compatible
|
|
||||||
return create_greenctx_stream_fallback(gctx);
|
return create_greenctx_stream_fallback(gctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,12 +52,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
|
|||||||
return {(int64_t)streamA, (int64_t)streamB};
|
return {(int64_t)streamA, (int64_t)streamB};
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void destroy_green_context(int64_t h) {
|
|
||||||
if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
|
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
|
||||||
TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
|
CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
|
||||||
|
|
||||||
|
if (CUDA_DRIVER_VERSION < 12040) {
|
||||||
|
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
|
||||||
|
}
|
||||||
|
|
||||||
CUgreenCtx gctx[3];
|
CUgreenCtx gctx[3];
|
||||||
CUdevResourceDesc desc[3];
|
CUdevResourceDesc desc[3];
|
||||||
@@ -65,8 +69,8 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
|||||||
|
|
||||||
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
|
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
|
||||||
|
|
||||||
const unsigned minCount = smA + smB;
|
const unsigned minCount = static_cast<unsigned>(smA + smB);
|
||||||
const unsigned minCountA = smA;
|
const unsigned minCountA = static_cast<unsigned>(smA);
|
||||||
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
|
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
|
||||||
|
|
||||||
unsigned nbGroups = 1;
|
unsigned nbGroups = 1;
|
||||||
@@ -86,7 +90,7 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
|||||||
|
|
||||||
std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
|
std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
|
||||||
|
|
||||||
CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
|
destroy_green_context(gctx[2]);
|
||||||
|
|
||||||
std::vector<int64_t> vec = {
|
std::vector<int64_t> vec = {
|
||||||
streams[0], // streamA
|
streams[0], // streamA
|
||||||
@@ -96,18 +100,3 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
|||||||
|
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
false,
|
|
||||||
"Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: " +
|
|
||||||
std::to_string(CUDA_VERSION));
|
|
||||||
|
|
||||||
// This is a stub function that should never be reached
|
|
||||||
// Return empty vector to satisfy return type requirement
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import subprocess
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
Reference in New Issue
Block a user