[Fix] Not skip NVML Check on AMD Platform (#3135)
@@ -185,9 +185,12 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        assert is_cuda()
+        if is_cuda():
+            assert is_cuda()
 
-        full_nvlink = is_full_nvlink(physical_device_ids)
+            full_nvlink = is_full_nvlink(physical_device_ids)
+        else:
+            full_nvlink = False
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
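On AMD platforms the NVLink/NVML probe does not apply, so this hunk replaces the unconditional `assert is_cuda()` with a platform gate: CUDA builds still run the topology check, while ROCm builds fall through with `full_nvlink = False` and let the existing `world_size > 2` guard disable custom allreduce. Below is a minimal, self-contained sketch of that control flow; `is_full_nvlink` is a hypothetical stub here (the real helper queries NVML), and the wrapper function name is an assumption, not the actual sglang method.

import logging

import torch

logger = logging.getLogger(__name__)


def is_cuda() -> bool:
    # Build check, mirroring the second hunk below: true only on
    # CUDA builds of PyTorch, not on ROCm builds.
    return hasattr(torch, "cuda") and torch.version.cuda is not None


def is_full_nvlink(physical_device_ids) -> bool:
    # Hypothetical stub: the real helper asks NVML whether every
    # pair of the given devices is connected via NVLink.
    return False


def should_enable_custom_allreduce(world_size, physical_device_ids) -> bool:
    # Sketch of the gated check from the hunk above: only CUDA
    # platforms probe NVLink; everything else reports no NVLink.
    if is_cuda():
        full_nvlink = is_full_nvlink(physical_device_ids)
    else:
        full_nvlink = False
    if world_size > 2 and not full_nvlink:
        logger.warning(
            "Custom allreduce is disabled because it's not supported on"
            " this platform/topology."
        )
        return False
    return True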
@@ -73,7 +73,7 @@ def is_hip() -> bool:
 
 
 def is_cuda():
-    return hasattr(torch, "cuda") and torch.cuda.is_available()
+    return hasattr(torch, "cuda") and torch.version.cuda is not None
 
 
 def is_cuda_alike():
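This hunk changes what `is_cuda()` means. On ROCm wheels of PyTorch the HIP backend is exposed through the `torch.cuda` namespace, so `torch.cuda.is_available()` can return True on an AMD machine and the old check wrongly took the CUDA-only path above. `torch.version.cuda` is a version string only on CUDA builds and None on ROCm and CPU builds, so it identifies the build rather than runtime device visibility, which is why the NVML check is no longer skipped (or hit incorrectly) on AMD platforms. A small illustration of the difference; the example version strings in the comments are assumptions about typical wheels:

import torch

# On a CUDA wheel:  torch.version.cuda is a string such as "12.1",
#                   and torch.version.hip is None.
# On a ROCm wheel:  torch.version.cuda is None,
#                   torch.version.hip is a string such as "6.0",
#                   yet torch.cuda.is_available() is still True
#                   when an AMD GPU is present.
print("CUDA build:    ", torch.version.cuda is not None)
print("HIP build:     ", torch.version.hip is not None)
print("device visible:", torch.cuda.is_available())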