misc: add compute capability in check_env (#965)
This commit is contained in:
@@ -73,10 +73,26 @@ def _get_gpu_info():
|
|||||||
Get information about available GPUs.
|
Get information about available GPUs.
|
||||||
"""
|
"""
|
||||||
devices = defaultdict(list)
|
devices = defaultdict(list)
|
||||||
|
capabilities = defaultdict(list)
|
||||||
for k in range(torch.cuda.device_count()):
|
for k in range(torch.cuda.device_count()):
|
||||||
devices[torch.cuda.get_device_name(k)].append(str(k))
|
devices[torch.cuda.get_device_name(k)].append(str(k))
|
||||||
|
capability = torch.cuda.get_device_capability(k)
|
||||||
|
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
|
||||||
|
|
||||||
return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
|
gpu_info = {}
|
||||||
|
for name, device_ids in devices.items():
|
||||||
|
gpu_info[f"GPU {','.join(device_ids)}"] = name
|
||||||
|
|
||||||
|
if len(capabilities) == 1:
|
||||||
|
# All GPUs have the same compute capability
|
||||||
|
cap, gpu_ids = list(capabilities.items())[0]
|
||||||
|
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
||||||
|
else:
|
||||||
|
# GPUs have different compute capabilities
|
||||||
|
for cap, gpu_ids in capabilities.items():
|
||||||
|
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
||||||
|
|
||||||
|
return gpu_info
|
||||||
|
|
||||||
|
|
||||||
def _get_cuda_version_info():
|
def _get_cuda_version_info():
|
||||||
@@ -118,6 +134,7 @@ def _get_cuda_driver_version():
|
|||||||
"""
|
"""
|
||||||
Get CUDA driver version.
|
Get CUDA driver version.
|
||||||
"""
|
"""
|
||||||
|
versions = set()
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
[
|
[
|
||||||
@@ -126,7 +143,11 @@ def _get_cuda_driver_version():
|
|||||||
"--format=csv,noheader,nounits",
|
"--format=csv,noheader,nounits",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return {"CUDA Driver Version": output.decode().strip()}
|
versions = set(output.decode().strip().split("\n"))
|
||||||
|
if len(versions) == 1:
|
||||||
|
return {"CUDA Driver Version": versions.pop()}
|
||||||
|
else:
|
||||||
|
return {"CUDA Driver Versions": ", ".join(sorted(versions))}
|
||||||
except subprocess.SubprocessError:
|
except subprocess.SubprocessError:
|
||||||
return {"CUDA Driver Version": "Not Available"}
|
return {"CUDA Driver Version": "Not Available"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user