From 66283dbc0c052c6f32bde68451addc5b0d00cf3b Mon Sep 17 00:00:00 2001 From: yigex Date: Sun, 26 Jan 2025 13:33:51 +0800 Subject: [PATCH] [Fix] Not skip NVML Check on AMD Platform (#3135) --- .../distributed/device_communicators/custom_all_reduce.py | 7 +++++-- python/sglang/srt/utils.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index c3cbc41fe..faeac0bba 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -185,9 +185,12 @@ class CustomAllreduce: # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert is_cuda() + if is_cuda(): + assert is_cuda() - full_nvlink = is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) + else: + full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23dcb43d2..f1d57e906 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -73,7 +73,7 @@ def is_hip() -> bool: def is_cuda(): - return hasattr(torch, "cuda") and torch.cuda.is_available() + return hasattr(torch, "cuda") and torch.version.cuda is not None def is_cuda_alike():