Optional extension for green context (#9231)

This commit is contained in:
Liangsheng Yin
2025-08-15 21:33:52 +08:00
committed by GitHub
parent c186feed7f
commit 0c8594e67d
6 changed files with 73 additions and 20 deletions

View File

@@ -92,7 +92,20 @@ from sgl_kernel.sampling import (
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
def create_greenctx_stream_by_value(*args, **kwargs):
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
return _impl(*args, **kwargs)
def get_sm_available(*args, **kwargs):
from sgl_kernel.spatial import get_sm_available as _impl
return _impl(*args, **kwargs)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
segment_packbits,

View File

@@ -1,6 +1,17 @@
import torch
from torch.cuda.streams import ExternalStream
try:
from . import spatial_ops # triggers TORCH extension registration
except Exception as _e:
_spatial_import_error = _e
else:
_spatial_import_error = None
_IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def create_greenctx_stream_by_value(
SM_a: int, SM_b: int, device_id: int = None
@@ -14,11 +25,8 @@ def create_greenctx_stream_by_value(
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
"""
if torch.version.cuda < "12.4":
raise RuntimeError(
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
)
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
@@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int:
Returns:
int: The SMs available.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()