diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index c3cbc41fe..faeac0bba 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -185,9 +185,12 @@ class CustomAllreduce: # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert is_cuda() + if is_cuda(): + assert is_cuda() - full_nvlink = is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) + else: + full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23dcb43d2..f1d57e906 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -73,7 +73,7 @@ def is_hip() -> bool: def is_cuda(): - return hasattr(torch, "cuda") and torch.cuda.is_available() + return hasattr(torch, "cuda") and torch.version.cuda is not None def is_cuda_alike():