From 5e24b26a543ae7ab92f7304c127ef27d43420369 Mon Sep 17 00:00:00 2001
From: realliujiaxu
Date: Sun, 1 Mar 2026 20:22:50 +0800
Subject: [PATCH] [Bugfix] rename enable_flash_comm_v1 back to enable_sp (#6883)

### What this PR does / why we need it?
PR #5632 introduced a bug by replacing several branches gated by `enable_sp` with `enable_flash_comm_v1`. Unlike `enable_sp`, `enable_flash_comm_v1` only checks the FLASHCOMM environment variables and ignores `enable_shared_expert_dp`. As a result, when `enable_shared_expert_dp` is enabled on its own (i.e., `VLLM_ASCEND_ENABLE_FLASHCOMM1=0` and `VLLM_ASCEND_ENABLE_FLASHCOMM=0`), the affected branches behave differently from the previous logic and cause accuracy issues. This PR restores the original `enable_sp`-based branching to recover the expected behavior and accuracy.
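For context, the root cause is the difference between the two predicates. The sketch below is a simplified model of the checks in `vllm_ascend/utils.py` (for brevity it reads both switches directly from the environment and omits the `vllm_config` handling and `_ENABLE_SP` state of the real `enable_sp`):

```python
import os


def enable_flash_comm_v1() -> bool:
    # Enabled via VLLM_ASCEND_ENABLE_FLASHCOMM1; VLLM_ASCEND_ENABLE_FLASHCOMM
    # is retained for backward compatibility.
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))) or bool(
        int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))
    )


def enable_sp(enable_shared_expert_dp: bool = False) -> bool:
    # Same env switches, plus the shared-expert-DP condition that PR #5632 dropped.
    return enable_flash_comm_v1() or enable_shared_expert_dp


# With VLLM_ASCEND_ENABLE_FLASHCOMM1=0 and VLLM_ASCEND_ENABLE_FLASHCOMM=0:
#   enable_flash_comm_v1()                  -> False  (gated branches skipped)
#   enable_sp(enable_shared_expert_dp=True) -> True   (gated branches taken again)
```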
### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
#### 1. Start the server
```bash
vllm serve /home/weights/DeepSeek-V2-Lite-W8A8/ \
    --port 8001 \
    --served-model-name auto \
    --max-model-len 1024 \
    --enforce-eager \
    --tensor-parallel-size 2 \
    --data-parallel-size 2 \
    --gpu-memory-utilization 0.9 \
    --enable-expert-parallel \
    --additional-config '{"enable_shared_expert_dp": true}'
```
Note that both FLASHCOMM environment variables are left unset, so `enable_shared_expert_dp` is the only switch driving the SP path, which is exactly the configuration that triggered the bug.

#### 2. Send a request
```bash
curl -s http://localhost:8001/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "auto",
        "messages": [
            {"role": "user", "content": "Hello. I have a question. Who are you?"}
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "ignore_eos_token": true
    }'
```

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

Signed-off-by: realliujiaxu
---
 vllm_ascend/ascend_forward_context.py |  6 +++---
 vllm_ascend/ops/linear.py             |  4 ++--
 vllm_ascend/ops/linear_op.py          |  8 ++++----
 vllm_ascend/platform.py               |  5 ++---
 vllm_ascend/utils.py                  | 18 +++++++-----------
 vllm_ascend/worker/model_runner_v1.py |  6 +++---
 vllm_ascend/worker/worker.py          |  6 +++---
 7 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index b3426451..a889b84c 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (
     AscendDeviceType,
-    enable_flash_comm_v1,
+    enable_sp,
     flashcomm2_enable,
     get_ascend_device_type,
     has_layer_idx,
@@ -92,14 +92,14 @@ def set_ascend_forward_context(
     # main model and drafter model may have different architecture
     is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
     if is_context_moe_model:
-        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
+        flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
         mmrs_fusion = False
     elif is_draft_model:
         # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
         # Disable it to avoid more problems.
         flash_comm_v1_enabled = False
     else:
-        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
+        flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
     forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 53bd6e06..b2118606 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
-from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz
+from vllm_ascend.utils import enable_sp, maybe_trans_nz


 class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
         disable_tp: bool = False,
     ):
         # TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
-        if enable_flash_comm_v1():
+        if enable_sp():
             compilation_config = get_current_vllm_config().compilation_config
             unique_prefix = prefix
             if prefix in compilation_config.static_forward_context:
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 9c78418f..6706313b 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (
     enable_dsa_cp,
     enable_dsa_cp_with_layer_shard,
-    enable_flash_comm_v1,
+    enable_sp,
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
         # Matrix multiply.
         assert self.quant_method is not None

-        if enable_flash_comm_v1():
+        if enable_sp():
             input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)

         # Trigger async broadcast before matmul to overlap communication.
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
     if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
         if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
             return Flashcomm2OshardQKVParallelOp(layer)
-    if enable_flash_comm_v1():
+    if enable_sp():
         if "shared_expert" in prefix:
             return None
         sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
     if flashcomm2_enable():
         if "o_proj" in prefix or "out_proj" in prefix:
             return Flashcomm2OProjRowParallelOp(layer)
-    if enable_flash_comm_v1():
+    if enable_sp():
         if "shared_expert" in prefix:
             return None
         sp_row_prefixes = [
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d3c926b6..c0e52984 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -39,7 +39,6 @@ from vllm_ascend.utils import (
     COMPRESSED_TENSORS_METHOD,
     AscendDeviceType,
     check_kv_extra_config,
-    enable_sp,
     flashcomm2_enable,
     get_ascend_device_type,
     is_moe_model,
@@ -48,7 +47,7 @@ from vllm_ascend.utils import (
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
     is_310p,
-    enable_flash_comm_v1,
+    enable_sp,
 )

 if TYPE_CHECKING:
@@ -402,7 +401,7 @@ class NPUPlatform(Platform):
             )
             vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size

-        if enable_flash_comm_v1():
+        if enable_sp(vllm_config):
             assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
                 Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
                 For optimal performance with VL models, we recommend enabling Sequence Parallelism \
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 792af587..a1f7083e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -719,15 +719,6 @@ def matmul_allreduce_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE


-def enable_flash_comm_v1():
-    return (
-        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
-        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
-    )
-
-
 def enable_sp_by_pass(vllm_config: VllmConfig):
     return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
@@ -739,7 +730,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
         from vllm.config import get_current_vllm_config

         vllm_config = get_current_vllm_config()
-    _ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
+    _ENABLE_SP = (
+        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
+        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
+    )
     if not _ENABLE_SP and enable_shared_expert_dp:
         _ENABLE_SP = True
@@ -1104,7 +1100,7 @@ def enable_dsa_cp() -> bool:
     is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
         vllm_config.model_config.hf_text_config, "index_topk"
     )
-    return bool(is_ds_v32 and enable_flash_comm_v1())
+    return bool(is_ds_v32 and enable_sp())


 @lru_cache(maxsize=1)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ef80131f..4ed3f587 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -113,8 +113,8 @@ from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (
     check_gdn_layer,
-    enable_flash_comm_v1,
     enable_sp,
+    enable_sp_by_pass,
     is_drafter_moe_model,
     is_moe_model,
     lmhead_tp_enable,
@@ -1745,7 +1745,7 @@ class NPUModelRunner(GPUModelRunner):
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if enable_sp(self.vllm_config):
+        if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens
@@ -2300,7 +2300,7 @@ class NPUModelRunner(GPUModelRunner):
         # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
         # to incorrect memory estimation and potentially causing OOM.
         intermediate_tokens = num_tokens_padded
-        if enable_flash_comm_v1():
+        if enable_sp():
             tp_size = get_tensor_model_parallel_world_size()
             intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
         if self.intermediate_tensors is None:
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 1cadf05b..5e65521f 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 from vllm_ascend.utils import (
     AscendDeviceType,
     check_ascend_device_type,
-    enable_flash_comm_v1,
+    enable_sp,
     get_ascend_device_type,
     register_ascend_customop,
 )
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
         if forward_pass and not get_pp_group().is_first_rank:
             # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
             # it will conflict with the all-gather operation in flashcomm1.
-            if enable_flash_comm_v1():
+            if enable_sp():
                 all_gather_group = None
             else:
                 all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@
             assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
             # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
             # it will conflict with the all-gather operation in flashcomm1.
-            if enable_flash_comm_v1():
+            if enable_sp():
                 all_gather_group = None
             else:
                 all_gather_group = get_tp_group()