misc: add compute capability in check_env (#965)

This commit is contained in:
Yineng Zhang
2024-08-07 16:39:36 +08:00
committed by GitHub
parent 5f6fa04a3f
commit 05abd1261c

View File

@@ -73,10 +73,26 @@ def _get_gpu_info():
Get information about available GPUs.
"""
devices = defaultdict(list)
capabilities = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
capability = torch.cuda.get_device_capability(k)
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
gpu_info = {}
for name, device_ids in devices.items():
gpu_info[f"GPU {','.join(device_ids)}"] = name
if len(capabilities) == 1:
# All GPUs have the same compute capability
cap, gpu_ids = list(capabilities.items())[0]
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
else:
# GPUs have different compute capabilities
for cap, gpu_ids in capabilities.items():
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
return gpu_info
def _get_cuda_version_info():
@@ -118,6 +134,7 @@ def _get_cuda_driver_version():
"""
Get CUDA driver version.
"""
versions = set()
try:
output = subprocess.check_output(
[
@@ -126,7 +143,11 @@ def _get_cuda_driver_version():
"--format=csv,noheader,nounits",
]
)
return {"CUDA Driver Version": output.decode().strip()}
versions = set(output.decode().strip().split("\n"))
if len(versions) == 1:
return {"CUDA Driver Version": versions.pop()}
else:
return {"CUDA Driver Versions": ", ".join(sorted(versions))}
except subprocess.SubprocessError:
return {"CUDA Driver Version": "Not Available"}