diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 692b3444..692566bf 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -445,13 +445,6 @@ class AscendSFAImpl(MLAAttentionImpl): # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() - # Improves glm5 accuracy after enabling dsa-cp in scenarios with strict accuracy requirements, - # especially for customized cases, at the cost of performance degradation due to extra communication. - self.enable_dsa_cp_strict_accuracy = ( - self.enable_dsa_cp_with_layer_shard - and self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"] - ) - # use original TP o_proj weight in PD mix stage, and full gather # for o_proj weight for prefill stage. self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() @@ -1246,16 +1239,6 @@ class AscendSFAImpl(MLAAttentionImpl): return result attn_output = result - if self.enable_dsa_cp_strict_accuracy: - send = ( - attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) - .permute(1, 0, 2) - .reshape(-1, self.num_heads * self.v_head_dim) - ) - - attn_output = torch.empty_like(send) - torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) - output[...] = self.o_proj(attn_output)[0] maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index ffc5fbba..85d29d3f 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -663,11 +663,7 @@ def _get_row_parallel_op( | None ): if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix: - from vllm.config import get_current_vllm_config - - vllm_config = get_current_vllm_config() - if vllm_config.model_config.hf_config.model_type not in ["glm_moe_dsa"]: - return ShardedCPRowParallelOp(layer) + return ShardedCPRowParallelOp(layer) if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable(): diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index 3c215860..7d52f8ca 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -29,7 +29,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group -from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable +from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable, mlp_tp_enable, oproj_tp_enable from .methods import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, is_mx_quant_type @@ -46,6 +46,7 @@ class AscendLinearMethod(LinearMethodBase): def __init__(self, scheme: AscendLinearScheme) -> None: self.quant_method = scheme + self._enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() def create_weights( self, @@ -145,6 +146,8 @@ class AscendLinearMethod(LinearMethodBase): tp_rank = 0 else: tp_rank = get_flashcomm2_otp_group().rank_in_group + elif layer.prefix.find("o_proj") != -1 and self._enable_dsa_cp_with_layer_shard: + tp_rank = 0 else: tp_rank = get_tensor_model_parallel_rank() else: