From d2fdeac22f503939015fe31b42a07e4f5dea63f8 Mon Sep 17 00:00:00 2001
From: maxiao
Date: Mon, 3 Nov 2025 16:28:21 +0800
Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E7=94=A8vllm=E9=87=8Ccustom=20all=20r?=
 =?UTF-8?q?educe?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/sglang/srt/_custom_ops.py                          | 8 ++++++--
 .../distributed/device_communicators/custom_all_reduce.py | 3 ++-
 python/sglang/srt/distributed/parallel_state.py           | 1 -
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index cf63dd6c8..1f915e712 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
 
 if not is_hpu():
     # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
+    # if use_vllm_custom_allreduce and not is_hip():
+    if use_vllm_custom_allreduce:
         try:
             import vllm._C  # noqa: F401
+            print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
         except ImportError as e:
             logger.warning("Failed to import from vllm._C with %r", e)
     else:
@@ -34,9 +36,11 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)
 
 
-if not is_hip() and not is_npu():
+# if not is_hip() and not is_npu():
+if not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
+        print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
     else:
         custom_op = sgl_kernel.allreduce
 
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 72668bf2e..38103a02e 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -27,7 +27,8 @@ _is_hip = is_hip()
 
 
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
+    # if ops.use_vllm_custom_allreduce and not _is_hip:
+    if ops.use_vllm_custom_allreduce:
         # Use vLLM custom allreduce
         ops.meta_size()
     else:
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 464bd2b17..e0d533479 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1539,7 +1539,6 @@ def initialize_model_parallel(
         group_name="tp",
         pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
-        use_custom_allreduce = False,
     )
 
     if duplicate_tp_group: