From 0954fd0912a12b8f1c5320f15ad24ace774b84e5 Mon Sep 17 00:00:00 2001 From: aipaes <82140963+aipaes@users.noreply.github.com> Date: Fri, 17 Apr 2026 22:46:36 +0800 Subject: [PATCH] [BugFix][0.18.0] Fix quant_bias missing in w8a8_static when flashcomm1 is enabled for GLM-5 (#8304) ### What this PR does / why we need it? PR #8220 in v0.18.0 In a previous PR #7843 , the o_proj layer of GLM-5 was reverted to TP (Tensor Parallel) splitting when flashcomm1 was enabled. However, this was a temporary workaround and did not address the root cause of the precision issues observed in the o_proj layer under flashcomm1. I am working on a definitive fix for this issue. Currently, a clear bug has been identified in https://github.com/vllm-project/vllm-ascend/blob/880e20fdde0b4e903685741b06e23b5fb8c287a1/vllm_ascend/quantization/methods/w8a8_static.py#L124: during quantized matrix multiplication, quant_bias is not added if tp_rank > 0. In the flashcomm1 scenario, all ranks actually require the addition of quant_bias, meaning tp_rank=0 should be passed to ensure the bias is applied correctly. This PR aims to resolve this logic error and fix the underlying precision issue. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? glm5 e2e test --------- Signed-off-by: zjks98 Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: triomino <15924998+triomino@users.noreply.github.com> Co-authored-by: zjks98 --- vllm_ascend/attention/sfa_v1.py | 17 ----------------- vllm_ascend/ops/linear_op.py | 6 +----- vllm_ascend/quantization/method_adapters.py | 5 ++++- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 692b3444..692566bf 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -445,13 +445,6 @@ class AscendSFAImpl(MLAAttentionImpl): # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() - # Improves glm5 accuracy after enabling dsa-cp in scenarios with strict accuracy requirements, - # especially for customized cases, at the cost of performance degradation due to extra communication. - self.enable_dsa_cp_strict_accuracy = ( - self.enable_dsa_cp_with_layer_shard - and self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"] - ) - # use original TP o_proj weight in PD mix stage, and full gather # for o_proj weight for prefill stage. self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() @@ -1246,16 +1239,6 @@ class AscendSFAImpl(MLAAttentionImpl): return result attn_output = result - if self.enable_dsa_cp_strict_accuracy: - send = ( - attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) - .permute(1, 0, 2) - .reshape(-1, self.num_heads * self.v_head_dim) - ) - - attn_output = torch.empty_like(send) - torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) - output[...] = self.o_proj(attn_output)[0] maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index ffc5fbba..85d29d3f 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -663,11 +663,7 @@ def _get_row_parallel_op( | None ): if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix: - from vllm.config import get_current_vllm_config - - vllm_config = get_current_vllm_config() - if vllm_config.model_config.hf_config.model_type not in ["glm_moe_dsa"]: - return ShardedCPRowParallelOp(layer) + return ShardedCPRowParallelOp(layer) if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable(): diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index 3c215860..7d52f8ca 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -29,7 +29,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group -from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable +from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable, mlp_tp_enable, oproj_tp_enable from .methods import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, is_mx_quant_type @@ -46,6 +46,7 @@ class AscendLinearMethod(LinearMethodBase): def __init__(self, scheme: AscendLinearScheme) -> None: self.quant_method = scheme + self._enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() def create_weights( self, @@ -145,6 +146,8 @@ class AscendLinearMethod(LinearMethodBase): tp_rank = 0 else: tp_rank = get_flashcomm2_otp_group().rank_in_group + elif layer.prefix.find("o_proj") != -1 and self._enable_dsa_cp_with_layer_shard: + tp_rank = 0 else: tp_rank = get_tensor_model_parallel_rank() else: