From b1cc6ef6aef0593fbfe03b40021197ddf281b76c Mon Sep 17 00:00:00 2001
From: yydyzr
Date: Tue, 31 Mar 2026 21:51:10 +0800
Subject: [PATCH] [v0.18.0][BugFix] Fix precision bug when DSA-CP is enabled
 on GLM5 (#7843)

### What this PR does / why we need it?
This PR fixes an accuracy bug that appears in some cases by adding an extra
communication step. It is a specific fix for version 0.18.0.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.18.0
- vLLM main: https://github.com/vllm-project/vllm/commit/35141a7eeda941a60ad5a4956670c60fd5a77029

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
Signed-off-by: yydyzr
Co-authored-by: rjg-lyh <1318825571@qq.com>
---
 vllm_ascend/attention/sfa_v1.py | 17 +++++++++++++++++
 vllm_ascend/ops/linear_op.py    |  6 +++++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 692566bf..692b3444 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -445,6 +445,13 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup.
         self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard()
 
+        # Improves glm5 accuracy after enabling dsa-cp in scenarios with strict accuracy requirements,
+        # especially for customized cases, at the cost of performance degradation due to extra communication.
+        self.enable_dsa_cp_strict_accuracy = (
+            self.enable_dsa_cp_with_layer_shard
+            and self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]
+        )
+
         # use original TP o_proj weight in PD mix stage, and full gather
         # for o_proj weight for prefill stage.
         self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp()
@@ -1239,6 +1246,16 @@ class AscendSFAImpl(MLAAttentionImpl):
             return result
 
         attn_output = result
+        if self.enable_dsa_cp_strict_accuracy:
+            send = (
+                attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
+                .permute(1, 0, 2)
+                .reshape(-1, self.num_heads * self.v_head_dim)
+            )
+
+            attn_output = torch.empty_like(send)
+            torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group)
+
         output[...] = self.o_proj(attn_output)[0]
 
         maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 85d29d3f..ffc5fbba 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -663,7 +663,11 @@ def _get_row_parallel_op(
     | None
 ):
     if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
-        return ShardedCPRowParallelOp(layer)
+        from vllm.config import get_current_vllm_config
+
+        vllm_config = get_current_vllm_config()
+        if vllm_config.model_config.hf_config.model_type not in ["glm_moe_dsa"]:
+            return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
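
The token shuffle ahead of `o_proj` can be reasoned about without a multi-device setup. The sketch below is a single-process illustration only: `tp_size`, `tokens_per_rank`, and the emulated all-to-all are assumptions standing in for the real TP group, not the actual distributed path. It shows how the `view(...).permute(1, 0, 2).reshape(...)` step in the patch makes the rows bound for each peer rank contiguous, which is the layout `torch.distributed.all_to_all_single` expects when it splits its input into equal per-rank chunks.

```python
# Single-process sketch of the shuffle that precedes o_proj when
# enable_dsa_cp_strict_accuracy is set. tp_size, tokens_per_rank, and the
# emulated exchange below are illustrative assumptions, not the real setup.
import torch

tp_size, num_heads, v_head_dim = 4, 2, 8
tokens_per_rank = 3  # hypothetical per-rank token count after attention

# One attn_output per simulated TP rank, as it would look before the fix's
# extra communication: rows interleaved in groups of tp_size.
outputs = [
    torch.randn(tokens_per_rank * tp_size, num_heads * v_head_dim)
    for _ in range(tp_size)
]

# The patch's layout transform: group the rows bound for each peer rank
# contiguously, so equal chunks of `send` line up with the all-to-all split.
sends = [
    out.view(-1, tp_size, num_heads * v_head_dim)
       .permute(1, 0, 2)
       .reshape(-1, num_heads * v_head_dim)
    for out in outputs
]

# Emulate torch.distributed.all_to_all_single over the TP group: rank dst's
# result is the concatenation of chunk dst from every rank's send buffer.
def emulated_all_to_all(sends: list[torch.Tensor], tp_size: int) -> list[torch.Tensor]:
    chunked = [s.chunk(tp_size, dim=0) for s in sends]
    return [torch.cat([chunked[src][dst] for src in range(tp_size)], dim=0)
            for dst in range(tp_size)]

received = emulated_all_to_all(sends, tp_size)
for r, recv in enumerate(received):
    assert recv.shape == sends[r].shape  # exchange preserves the buffer shape
print("each simulated rank receives", received[0].shape[0], "rows for o_proj")
```

Using `torch.empty_like(send)` as the receive buffer in the patch works because every rank contributes and receives the same number of rows, so the exchange preserves the buffer shape.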
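
The `linear_op.py` change is the other half of the fix: for GLM5 (`glm_moe_dsa`), `o_proj` now falls through to the default row-parallel op instead of `ShardedCPRowParallelOp`, matching the all-to-all that `sfa_v1.py` now performs on its input. A minimal sketch of the dispatch guard, where `pick_row_parallel_op` is a hypothetical stand-in for `_get_row_parallel_op` and plain strings stand in for the real op classes; only the branch structure mirrors the patch:

```python
# Hypothetical stand-in for _get_row_parallel_op; strings replace the real
# op classes, and only the branch structure mirrors the patched code.
def pick_row_parallel_op(prefix: str, model_type: str,
                         dsa_cp_layer_shard: bool) -> str:
    if dsa_cp_layer_shard and "o_proj" in prefix:
        # GLM5 skips the sharded-CP op: its attention output is already
        # re-gathered by the all_to_all added in sfa_v1.py.
        if model_type not in ["glm_moe_dsa"]:
            return "ShardedCPRowParallelOp"
    return "default RowParallelOp"

# With DSA-CP layer sharding enabled, only non-GLM5 models keep the CP op.
# "some_other_model" is an illustrative model_type, not a real one.
assert pick_row_parallel_op("model.layers.0.self_attn.o_proj",
                            "glm_moe_dsa", True) == "default RowParallelOp"
assert pick_row_parallel_op("model.layers.0.self_attn.o_proj",
                            "some_other_model", True) == "ShardedCPRowParallelOp"
```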