From 44a4ff6960b9d4edbfd8df52695cdb1655009f39 Mon Sep 17 00:00:00 2001 From: drslark <96540755+drslark@users.noreply.github.com> Date: Fri, 23 Jan 2026 21:12:23 +0800 Subject: [PATCH] [main][BugFix] Avoided a bug of `torch_npu.npu_mm_reduce_scatter_base` when sp size >= 16 (#6168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? If `sp` is enabled and `tp_size` >= 16, `torch_npu.npu_mm_reduce_scatter_base` will raise an exception. After consulting with the operator developer, we learned that the operator does not work when `tp` >= 16. So, we disable the operator when `tp` >= 16. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? We started a server with `sp` enabled and `tp` = 16. It started successfully. ```text (APIServer pid=1855938) INFO: Started server process [1855938] (APIServer pid=1855938) INFO: Waiting for application startup. (APIServer pid=1855938) INFO: Application startup complete. ``` - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: drslark --- vllm_ascend/ascend_forward_context.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 49693474..ea137b04 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -72,13 +72,16 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False + # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16. + mmrs_fusion = tp_world_size <= 8 + # set for sequence parallelism, 1000 is the batch size concurrency threshold # for enabling the flashcomm_v1 or sequence_parallelism feature. # Currently, it is an empirical value. 
In normal scenarios, if the concurrency # exceeds this threshold, the performance benefits can be maximized. # Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. - mmrs_fusion = True + # main model and drafter model may have different architecture is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config) if is_context_moe_model: @@ -86,6 +89,7 @@ def set_ascend_forward_context( mmrs_fusion = False else: sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000 + forward_context.mmrs_fusion = mmrs_fusion forward_context.num_tokens = num_tokens forward_context.sp_enabled = sp_enabled