Optional extension for green context (#9231)
This commit is contained in:
@@ -1,6 +1,17 @@
|
||||
import torch
|
||||
from torch.cuda.streams import ExternalStream
|
||||
|
||||
try:
|
||||
from . import spatial_ops # triggers TORCH extension registration
|
||||
except Exception as _e:
|
||||
_spatial_import_error = _e
|
||||
else:
|
||||
_spatial_import_error = None
|
||||
|
||||
_IMPORT_ERROR = ImportError(
|
||||
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
|
||||
)
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(
|
||||
SM_a: int, SM_b: int, device_id: int = None
|
||||
@@ -14,11 +25,8 @@ def create_greenctx_stream_by_value(
|
||||
Returns:
|
||||
tuple[ExternalStream, ExternalStream]: The two streams.
|
||||
"""
|
||||
if torch.version.cuda < "12.4":
|
||||
raise RuntimeError(
|
||||
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
|
||||
)
|
||||
|
||||
if _spatial_import_error is not None:
|
||||
raise _IMPORT_ERROR from _spatial_import_error
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
@@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int:
|
||||
Returns:
|
||||
int: The SMs available.
|
||||
"""
|
||||
if _spatial_import_error is not None:
|
||||
raise _IMPORT_ERROR from _spatial_import_error
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user