diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index cf63dd6c8..1f915e712 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
 
 if not is_hpu():
     # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
+    # if use_vllm_custom_allreduce and not is_hip():
+    if use_vllm_custom_allreduce:
         try:
             import vllm._C  # noqa: F401
+            print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
         except ImportError as e:
             logger.warning("Failed to import from vllm._C with %r", e)
     else:
@@ -34,9 +36,11 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)
 
 
-if not is_hip() and not is_npu():
+# if not is_hip() and not is_npu():
+if not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
+        print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
     else:
         custom_op = sgl_kernel.allreduce
 
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 72668bf2e..38103a02e 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -27,7 +27,8 @@ _is_hip = is_hip()
 
 
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
+    # if ops.use_vllm_custom_allreduce and not _is_hip:
+    if ops.use_vllm_custom_allreduce:
         # Use vLLM custom allreduce
         ops.meta_size()
     else:
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 464bd2b17..e0d533479 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1539,7 +1539,6 @@ def initialize_model_parallel(
         group_name="tp",
         pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
-        use_custom_allreduce = False,
     )
 
     if duplicate_tp_group: