diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 57c444e76..89a298c34 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -246,6 +246,7 @@ set(SOURCES
   "csrc/moe/ep_moe_silu_and_mul_kernel.cu"
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/packbit.cu"
+  "csrc/spatial/greenctx_stream.cu"
   "csrc/speculative/speculative_sampling.cu"
   "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
   "csrc/kvcacheio/transfer.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index f5eb9bfe5..070fe4bd2 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -401,6 +401,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
+
+  /*
+   * From csrc/spatial
+   */
+  m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
+  m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
 }
 
 REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/csrc/spatial/cuda_utils.h b/sgl-kernel/csrc/spatial/cuda_utils.h
new file mode 100644
index 000000000..126ed05d8
--- /dev/null
+++ b/sgl-kernel/csrc/spatial/cuda_utils.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <iostream>
+
+#include <torch/all.h>
+
+// Check a CUDA runtime API call: log to stderr and raise a torch error on failure.
+#define CUDA_RT(call)                                                                                                \
+  do {                                                                                                               \
+    cudaError_t _status = (call);                                                                                    \
+    if (_status != cudaSuccess) {                                                                                    \
+      std::cerr << "ERROR: CUDA RT call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__         \
+                << " failed with " << cudaGetErrorString(_status) << std::endl;                                      \
+      TORCH_CHECK(                                                                                                   \
+          false,                                                                                                     \
+          c10::str(                                                                                                  \
+              "ERROR: CUDA RT call \"",                                                                              \
+              #call,                                                                                                 \
+              "\" in line ",                                                                                         \
+              __LINE__,                                                                                              \
+              " of file ",                                                                                           \
+              __FILE__,                                                                                              \
+              " failed with ",                                                                                       \
+              cudaGetErrorString(_status)));                                                                         \
+    }                                                                                                                \
+  } while (0)
+
+// Check a CUDA driver API call: log to stderr and raise a torch error on failure.
+#define CUDA_DRV(call)                                                                                               \
+  do {                                                                                                               \
+    CUresult _status = (call);                                                                                       \
+    if (_status != CUDA_SUCCESS) {                                                                                   \
+      const char* err_str;                                                                                           \
+      cuGetErrorString(_status, &err_str);                                                                           \
+      std::cerr << "ERROR: CUDA DRV call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__        \
+                << " failed with " << err_str << std::endl;                                                          \
+      TORCH_CHECK(                                                                                                   \
+          false,                                                                                                     \
+          c10::str(                                                                                                  \
+              "ERROR: CUDA DRV call \"",                                                                             \
+              #call,                                                                                                 \
+              "\" in line ",                                                                                         \
+              __LINE__,                                                                                              \
+              " of file ",                                                                                           \
+              __FILE__,                                                                                              \
+              " failed with ",                                                                                       \
+              err_str));                                                                                             \
+    }                                                                                                                \
+  } while (0)
diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu
new file mode 100644
index 000000000..b549aea5f
--- /dev/null
+++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu
@@ -0,0 +1,73 @@
+#include <cuda.h>
+
+#include <cstdint>
+#include <iostream>
+#include <vector>
+
+#include "cuda_utils.h"
+#include "greenctx_stream.h"
+
+// Create two CUDA streams whose green contexts own disjoint SM partitions.
+//
+// Args:
+//   smA:    number of SMs requested for stream A (must be > 0)
+//   smB:    number of SMs requested for stream B (must be > 0)
+//   device: CUDA device ordinal
+//
+// Returns {streamA, streamB, smCountA, smCountB}: the two stream handles
+// cast to integers plus the SM count actually present in each partition.
+std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
+  CUgreenCtx gctx[3];
+  CUdevResourceDesc desc[3];
+  CUdevResource input;
+  CUdevResource resources[4];
+  CUstream streamA;
+  CUstream streamB;
+
+  unsigned int nbGroups = 1;
+
+  TORCH_CHECK(smA > 0 && smB > 0, "SM counts must be positive");
+
+  // Initialize device
+  CUDA_RT(cudaInitDevice(device, 0, 0));
+
+  // Query the SM resource of the whole device.
+  CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
+
+  // Reserve smA + smB SMs for the parent green context; the rest of the
+  // device is left in resources[3].
+  unsigned int minCount = (unsigned int)(smA + smB);
+  unsigned int minCountA = (unsigned int)(smA);
+
+  TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
+
+  // Split resources
+  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
+
+  CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
+  CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
+  CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
+
+  // Split the parent partition into smA SMs for A (resources[0]) and the
+  // remainder for B (resources[1]).
+  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));
+
+  CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
+  CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
+  CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
+  CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
+
+  CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
+  CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));
+
+  int smCountA = resources[0].sm.smCount;
+  int smCountB = resources[1].sm.smCount;
+
+  // gctx[0]/gctx[1] intentionally stay alive: the returned streams were
+  // created from them. The parent context is no longer needed once its
+  // resources have been re-split.
+  CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
+
+  std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB};
+  return vec;
+}
diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.h b/sgl-kernel/csrc/spatial/greenctx_stream.h
new file mode 100644
index 000000000..2577e9f29
--- /dev/null
+++ b/sgl-kernel/csrc/spatial/greenctx_stream.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 4d4990041..df06bd3cd 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -661,3 +661,8 @@ void qserve_w4a8_per_group_gemm(
     const torch::Tensor& _wscales,
     const torch::Tensor& _ascales,
     torch::Tensor& _out_feats);
+
+/*
+ * From csrc/spatial
+ */
+std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 27666f0f6..5cecfc3c0 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -81,6 +81,7 @@ from sgl_kernel.sampling import (
     top_p_renorm_prob,
     top_p_sampling_from_probs,
 )
+from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
 from sgl_kernel.speculative import (
     build_tree_kernel_efficient,
     segment_packbits,
diff --git a/sgl-kernel/python/sgl_kernel/spatial.py b/sgl-kernel/python/sgl_kernel/spatial.py
new file mode 100644
index 000000000..8fe2a3dd7
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/spatial.py
@@ -0,0 +1,52 @@
+from typing import Optional
+
+import torch
+from torch.cuda.streams import ExternalStream
+
+
+def create_greenctx_stream_by_value(
+    SM_a: int, SM_b: int, device_id: Optional[int] = None
+) -> tuple[ExternalStream, ExternalStream]:
+    """
+    Create two streams backed by green contexts with dedicated SM partitions.
+
+    Args:
+        SM_a (int): The number of SMs assigned to stream A.
+        SM_b (int): The number of SMs assigned to stream B.
+        device_id (int, optional): The device id. Defaults to the current device.
+
+    Returns:
+        tuple[ExternalStream, ExternalStream]: The two streams.
+    """
+    if device_id is None:
+        device_id = torch.cuda.current_device()
+
+    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
+
+    stream_a = ExternalStream(
+        stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
+    )
+    stream_b = ExternalStream(
+        stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
+    )
+
+    return stream_a, stream_b
+
+
+def get_sm_available(device_id: Optional[int] = None) -> int:
+    """
+    Get the number of SMs available on the device.
+
+    Args:
+        device_id (int, optional): The device id. Defaults to the current device.
+
+    Returns:
+        int: The number of SMs available.
+    """
+    if device_id is None:
+        device_id = torch.cuda.current_device()
+
+    device_props = torch.cuda.get_device_properties(device_id)
+
+    # Number of Streaming Multiprocessors (SMs) on the device.
+    return device_props.multi_processor_count
diff --git a/sgl-kernel/tests/spatial/test_greenctx_stream.py b/sgl-kernel/tests/spatial/test_greenctx_stream.py
new file mode 100644
index 000000000..c57bc3360
--- /dev/null
+++ b/sgl-kernel/tests/spatial/test_greenctx_stream.py
@@ -0,0 +1,24 @@
+import pytest
+import torch
+from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
+
+
+def test_green_ctx():
+    A = torch.randn(5120, 5120).cuda()
+    B = torch.randn(5120, 5120).cuda()
+    C = torch.matmul(A, B)
+    sm_counts = get_sm_available(0)
+    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
+    with torch.cuda.stream(stream_group[0]):
+        for _ in range(100):
+            result_0 = torch.matmul(A, B)
+    with torch.cuda.stream(stream_group[1]):
+        for _ in range(100):
+            result_1 = torch.matmul(A, B)
+    torch.cuda.synchronize()
+    assert torch.allclose(result_0, C)
+    assert torch.allclose(result_1, C)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])