Sync from v0.13
This commit is contained in:
81
tests/cuda/test_cuda_context.py
Normal file
81
tests/cuda/test_cuda_context.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def check_cuda_context():
|
||||
"""Check CUDA driver context status"""
|
||||
try:
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
device = ctypes.c_int()
|
||||
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
||||
return (True, device.value) if result == 0 else (False, None)
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
|
||||
def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
"""Run CUDA context test in separate thread for isolation"""
|
||||
try:
|
||||
# New thread should have no CUDA context initially
|
||||
valid_before, device_before = check_cuda_context()
|
||||
if valid_before:
|
||||
return (
|
||||
False,
|
||||
"CUDA context should not exist in new thread, "
|
||||
f"got device {device_before}",
|
||||
)
|
||||
|
||||
# Test setting CUDA context
|
||||
current_platform.set_device(device_input)
|
||||
|
||||
# Verify context is created correctly
|
||||
valid_after, device_id = check_cuda_context()
|
||||
if not valid_after:
|
||||
return False, "CUDA context should be valid after set_cuda_context"
|
||||
if device_id != expected_device_id:
|
||||
return False, f"Expected device {expected_device_id}, got {device_id}"
|
||||
|
||||
return True, "Success"
|
||||
except Exception as e:
|
||||
return False, f"Exception in thread: {str(e)}"
|
||||
|
||||
|
||||
class TestSetCudaContext:
|
||||
"""Test suite for the set_cuda_context function."""
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
@pytest.mark.parametrize(
|
||||
argnames="device_input,expected_device_id",
|
||||
argvalues=[
|
||||
(0, 0),
|
||||
(torch.device("cuda:0"), 0),
|
||||
("cuda:0", 0),
|
||||
],
|
||||
ids=["int", "torch_device", "string"],
|
||||
)
|
||||
def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
|
||||
"""Test setting CUDA context in isolated threads."""
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(
|
||||
run_cuda_test_in_thread, device_input, expected_device_id
|
||||
)
|
||||
success, message = future.result(timeout=30)
|
||||
assert success, message
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
def test_set_cuda_context_invalid_device_type(self):
|
||||
"""Test error handling for invalid device type."""
|
||||
with pytest.raises(ValueError, match="Expected a cuda device"):
|
||||
current_platform.set_device(torch.device("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user