[Feature] CUDA Green Context Support (#7649)
This commit is contained in:
@@ -246,6 +246,7 @@ set(SOURCES
|
|||||||
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
||||||
"csrc/speculative/eagle_utils.cu"
|
"csrc/speculative/eagle_utils.cu"
|
||||||
"csrc/speculative/packbit.cu"
|
"csrc/speculative/packbit.cu"
|
||||||
|
"csrc/spatial/greenctx_stream.cu"
|
||||||
"csrc/speculative/speculative_sampling.cu"
|
"csrc/speculative/speculative_sampling.cu"
|
||||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||||
"csrc/kvcacheio/transfer.cu"
|
"csrc/kvcacheio/transfer.cu"
|
||||||
|
|||||||
@@ -401,6 +401,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
||||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/spatial
|
||||||
|
*/
|
||||||
|
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
|
||||||
|
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
44
sgl-kernel/csrc/spatial/cuda_utils.h
Normal file
44
sgl-kernel/csrc/spatial/cuda_utils.h
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// Evaluates a CUDA runtime API expression exactly once. When the returned
// status is not cudaSuccess, the diagnostic (call text, line, file, runtime
// error string) is written to std::cerr and then passed to TORCH_CHECK(false, ...).
// Wrapped in do { } while (0) so the macro behaves as a single statement.
#define CUDA_RT(call)                                             \
  do {                                                            \
    const cudaError_t _rt_status = (call);                        \
    if (_rt_status != cudaSuccess) {                              \
      /* Build the message once; reuse it for stderr and Torch. */ \
      const auto _rt_message = c10::str(                          \
          "ERROR: CUDA RT call \"",                               \
          #call,                                                  \
          "\" in line ",                                          \
          __LINE__,                                               \
          " of file ",                                            \
          __FILE__,                                               \
          " failed with ",                                        \
          cudaGetErrorString(_rt_status));                        \
      std::cerr << _rt_message << std::endl;                      \
      TORCH_CHECK(false, _rt_message);                            \
    }                                                             \
  } while (0)
|
||||||
|
|
||||||
|
// Evaluates a CUDA driver API expression exactly once. When the returned
// CUresult is not CUDA_SUCCESS, the diagnostic (call text, line, file, driver
// error string) is written to std::cerr and then passed to TORCH_CHECK(false, ...).
//
// The error string is looked up with cuGetErrorString. That lookup can itself
// fail for an unrecognised code and leave the pointer NULL, so a fixed fallback
// text is substituted instead of streaming an uninitialised/NULL pointer.
// Wrapped in do { } while (0) so the macro behaves as a single statement.
#define CUDA_DRV(call)                                                        \
  do {                                                                        \
    const CUresult _drv_status = (call);                                      \
    if (_drv_status != CUDA_SUCCESS) {                                        \
      const char* _drv_err_str = nullptr;                                     \
      if (cuGetErrorString(_drv_status, &_drv_err_str) != CUDA_SUCCESS ||     \
          _drv_err_str == nullptr) {                                          \
        _drv_err_str = "unknown CUDA driver error";                           \
      }                                                                       \
      /* Build the message once; reuse it for stderr and Torch. */            \
      const auto _drv_message = c10::str(                                     \
          "ERROR: CUDA DRV call \"",                                          \
          #call,                                                              \
          "\" in line ",                                                      \
          __LINE__,                                                           \
          " of file ",                                                        \
          __FILE__,                                                           \
          " failed with ",                                                    \
          _drv_err_str);                                                      \
      std::cerr << _drv_message << std::endl;                                 \
      TORCH_CHECK(false, _drv_message);                                       \
    }                                                                         \
  } while (0)
|
||||||
58
sgl-kernel/csrc/spatial/greenctx_stream.cu
Normal file
58
sgl-kernel/csrc/spatial/greenctx_stream.cu
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "cuda_utils.h"
|
||||||
|
#include "greenctx_stream.h"
|
||||||
|
|
||||||
|
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
|
||||||
|
CUgreenCtx gctx[3];
|
||||||
|
CUdevResourceDesc desc[3];
|
||||||
|
CUdevResource input;
|
||||||
|
CUdevResource resources[4];
|
||||||
|
CUstream streamA;
|
||||||
|
CUstream streamB;
|
||||||
|
|
||||||
|
unsigned int nbGroups = 1;
|
||||||
|
|
||||||
|
if (smA <= 0 || smB <= 0) {
|
||||||
|
TORCH_CHECK(false, "SM counts must be positive");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize device
|
||||||
|
CUDA_RT(cudaInitDevice(device, 0, 0));
|
||||||
|
|
||||||
|
// Query input SMs
|
||||||
|
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
|
||||||
|
// We want 3/4 the device for our green context
|
||||||
|
unsigned int minCount = (unsigned int)(smA + smB);
|
||||||
|
unsigned int minCountA = (unsigned int)(smA);
|
||||||
|
|
||||||
|
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
|
||||||
|
|
||||||
|
// Split resources
|
||||||
|
CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
|
||||||
|
|
||||||
|
CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
|
||||||
|
CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||||
|
CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
|
||||||
|
CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));
|
||||||
|
|
||||||
|
CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
|
||||||
|
CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||||
|
CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
|
||||||
|
CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||||
|
|
||||||
|
CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
|
||||||
|
CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));
|
||||||
|
|
||||||
|
int smCountA = resources[0].sm.smCount;
|
||||||
|
int smCountB = resources[1].sm.smCount;
|
||||||
|
|
||||||
|
CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
|
||||||
|
|
||||||
|
std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB};
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
2
sgl-kernel/csrc/spatial/greenctx_stream.h
Normal file
2
sgl-kernel/csrc/spatial/greenctx_stream.h
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
#pragma once

#include <cstdint>
#include <vector>

// Creates two CUDA streams backed by green contexts that partition the SMs of
// `device`, requesting `smA` SMs for the first and `smB` for the second.
// Returns {streamA handle, streamB handle, SM count A, SM count B}.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
|
||||||
@@ -661,3 +661,8 @@ void qserve_w4a8_per_group_gemm(
|
|||||||
const torch::Tensor& _wscales,
|
const torch::Tensor& _wscales,
|
||||||
const torch::Tensor& _ascales,
|
const torch::Tensor& _ascales,
|
||||||
torch::Tensor& _out_feats);
|
torch::Tensor& _out_feats);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/spatial
|
||||||
|
*/
|
||||||
|
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ from sgl_kernel.sampling import (
|
|||||||
top_p_renorm_prob,
|
top_p_renorm_prob,
|
||||||
top_p_sampling_from_probs,
|
top_p_sampling_from_probs,
|
||||||
)
|
)
|
||||||
|
from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
|
||||||
from sgl_kernel.speculative import (
|
from sgl_kernel.speculative import (
|
||||||
build_tree_kernel_efficient,
|
build_tree_kernel_efficient,
|
||||||
segment_packbits,
|
segment_packbits,
|
||||||
|
|||||||
48
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
48
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import torch
|
||||||
|
from torch.cuda.streams import ExternalStream
|
||||||
|
|
||||||
|
|
||||||
|
def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create a pair of green-context CUDA streams and wrap them for PyTorch.

    Calls the ``sgl_kernel.create_greenctx_stream_by_value`` op and wraps the
    first two entries of its result as ``ExternalStream`` objects on the
    selected device. Any further entries returned by the op are ignored.

    Args:
        SM_a (int): SM count passed to the op for the first stream.
        SM_b (int): SM count passed to the op for the second stream.
        device_id (int, optional): CUDA device index. When ``None``, the
            value of ``torch.cuda.current_device()`` is used.
    Returns:
        tuple[ExternalStream, ExternalStream]: The first and second stream.
    """
    target_id = torch.cuda.current_device() if device_id is None else device_id

    handles = torch.ops.sgl_kernel.create_greenctx_stream_by_value(
        SM_a, SM_b, target_id
    )

    # Both streams live on the same device, so build the device object once.
    cuda_device = torch.device(f"cuda:{target_id}")
    first_stream = ExternalStream(stream_ptr=handles[0], device=cuda_device)
    second_stream = ExternalStream(stream_ptr=handles[1], device=cuda_device)

    return first_stream, second_stream
|
||||||
|
|
||||||
|
|
||||||
|
def get_sm_available(device_id: int = None) -> int:
    """
    Return the multiprocessor (SM) count PyTorch reports for a CUDA device.

    Args:
        device_id (int, optional): CUDA device index. When ``None``, the
            value of ``torch.cuda.current_device()`` is used.
    Returns:
        int: ``multi_processor_count`` from the device properties.
    """
    target_id = torch.cuda.current_device() if device_id is None else device_id
    return torch.cuda.get_device_properties(target_id).multi_processor_count
|
||||||
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
|
||||||
|
|
||||||
|
|
||||||
|
def test_green_ctx():
    """
    Run the same matmul on two green-context streams and compare both results
    with a reference computed on the current stream.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required for green context streams")

    A = torch.randn(5120, 5120).cuda()
    B = torch.randn(5120, 5120).cuda()
    C = torch.matmul(A, B)
    sm_counts = get_sm_available(0)
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)

    # A and B are produced on the current stream. The side streams must wait
    # for that work explicitly before reading the inputs.
    producer = torch.cuda.current_stream()
    for side_stream in stream_group:
        side_stream.wait_stream(producer)

    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(A, B)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(A, B)
    torch.cuda.synchronize()

    # Explicit tolerances: a different GEMM kernel / reduction order may be
    # chosen on a smaller SM partition, and fp32 sums over 5120 terms then
    # differ by more than the default allclose tolerances near zero.
    assert torch.allclose(result_0, C, rtol=1e-4, atol=1e-3)
    assert torch.allclose(result_1, C, rtol=1e-4, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this file directly: hand it to pytest as the only test target.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user