diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index c6d1a8307..d97c348ef 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -41,6 +41,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_hip,
     supports_custom_op,
 )
 
@@ -952,6 +953,9 @@ _ENABLE_CUSTOM_ALL_REDUCE = True
 def set_custom_all_reduce(enable: bool):
     global _ENABLE_CUSTOM_ALL_REDUCE
     _ENABLE_CUSTOM_ALL_REDUCE = enable
+    if enable and is_hip():
+        logger.warning("HIP doesn't support custom_all_reduce, so disable it.")
+        _ENABLE_CUSTOM_ALL_REDUCE = False
 
 
 def init_distributed_environment(