diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 692566bf..692b3444 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -445,6 +445,13 @@ class AscendSFAImpl(MLAAttentionImpl): # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() + # Improves glm5 accuracy after enabling dsa-cp in scenarios with strict accuracy requirements, + # especially for customized cases, at the cost of performance degradation due to extra communication. + self.enable_dsa_cp_strict_accuracy = ( + self.enable_dsa_cp_with_layer_shard + and self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"] + ) + # use original TP o_proj weight in PD mix stage, and full gather # for o_proj weight for prefill stage. self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() @@ -1239,6 +1246,16 @@ class AscendSFAImpl(MLAAttentionImpl): return result attn_output = result + if self.enable_dsa_cp_strict_accuracy: + send = ( + attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) + .permute(1, 0, 2) + .reshape(-1, self.num_heads * self.v_head_dim) + ) + + attn_output = torch.empty_like(send) + torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) + output[...] = self.o_proj(attn_output)[0] maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 85d29d3f..ffc5fbba 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -663,7 +663,11 @@ def _get_row_parallel_op( | None ): if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix: - return ShardedCPRowParallelOp(layer) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.model_config.hf_config.model_type not in ["glm_moe_dsa"]: + return ShardedCPRowParallelOp(layer) if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable():