[Feature] CUDA Green Context Support (#7649)
This commit is contained in:
48
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
48
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch.cuda.streams import ExternalStream
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(
|
||||
SM_a: int, SM_b: int, device_id: int = None
|
||||
) -> tuple[ExternalStream, ExternalStream]:
|
||||
"""
|
||||
Create two streams for greenctx.
|
||||
Args:
|
||||
sm_A (int): The SM of stream A.
|
||||
sm_B (int): The weight of stream B.
|
||||
device_id (int): The device id.
|
||||
Returns:
|
||||
tuple[ExternalStream, ExternalStream]: The two streams.
|
||||
"""
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
|
||||
|
||||
stream_a = ExternalStream(
|
||||
stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
|
||||
)
|
||||
stream_b = ExternalStream(
|
||||
stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
|
||||
)
|
||||
|
||||
return stream_a, stream_b
|
||||
|
||||
|
||||
def get_sm_available(device_id: int = None) -> int:
|
||||
"""
|
||||
Get the SMs available on the device.
|
||||
Args:
|
||||
device_id (int): The device id.
|
||||
Returns:
|
||||
int: The SMs available.
|
||||
"""
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
device_props = torch.cuda.get_device_properties(device_id)
|
||||
|
||||
# Get the number of Streaming Multiprocessors (SMs)
|
||||
sm_count = device_props.multi_processor_count
|
||||
|
||||
return sm_count
|
||||
Reference in New Issue
Block a user