[Feature] CUDA Green Context Support (#7649)
This commit is contained in:
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
|
||||
|
||||
|
||||
def test_green_ctx():
|
||||
A = torch.randn(5120, 5120).cuda()
|
||||
B = torch.randn(5120, 5120).cuda()
|
||||
C = torch.matmul(A, B)
|
||||
sm_counts = get_sm_available(0)
|
||||
stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
|
||||
with torch.cuda.stream(stream_group[0]):
|
||||
for _ in range(100):
|
||||
result_0 = torch.matmul(A, B)
|
||||
with torch.cuda.stream(stream_group[1]):
|
||||
for _ in range(100):
|
||||
result_1 = torch.matmul(A, B)
|
||||
torch.cuda.synchronize()
|
||||
assert torch.allclose(result_0, C)
|
||||
assert torch.allclose(result_1, C)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user