diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index b30a9834..43023b6b 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -100,8 +100,22 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert builder.device == device
         assert builder.vllm_config == vllm_config
 
+    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
-    def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
+    @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
+    def test_ascend_sfa_metadata_builder_build(
+        self,
+        mock_enable_dsa_cp,
+        mock_get_cos_and_sin_mla,
+        mock_get_current_vllm_config,
+    ):
+        mock_enable_dsa_cp.return_value = False
+
+        cfg = MagicMock()
+        cfg.model_config = MagicMock()
+        cfg.model_config.hf_text_config = MagicMock()
+
+        mock_get_current_vllm_config.return_value = cfg
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
@@ -144,9 +158,16 @@ class TestAscendSFAMetadataBuilder(TestBase):
         assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
         assert metadata.slot_mapping.shape == (100, 4, 1024)
 
+    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
     def test_ascend_sfa_metadata_builder_build_for_graph_capture(
-            self, mock_get_cos_and_sin_mla):
+            self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
+        cfg = MagicMock()
+        cfg.model_config = MagicMock()
+        cfg.model_config.hf_text_config = MagicMock()
+
+        mock_get_current_vllm_config.return_value = cfg
+
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index d3b5b4b2..a32817dc 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -10,8 +10,7 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
 from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -34,7 +33,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
-                               enable_sp, maybe_trans_nz, replace_layer)
+                               enable_dsa_cp, maybe_trans_nz)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 
 if TYPE_CHECKING:
@@ -149,8 +148,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             got {self.decode_threshold}"
 
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.enable_sfa_cp = enable_sp() and \
-            hasattr(self.model_config.hf_text_config, "index_topk")
+        self.enable_sfa_cp = enable_dsa_cp()
 
         assert not (
             self.enable_sfa_cp
@@ -368,13 +366,11 @@ class AscendSFAImpl(MLAAttentionImpl):
             assert self.indexer is not None, "Indexer is required for DSA."
 
-            self.enable_sfa_cp = enable_sp()
+            self.enable_sfa_cp = enable_dsa_cp()
             self.local_num_heads = self.num_heads
             self.vllm_config = get_current_vllm_config()
             if self.enable_sfa_cp:
                 self.local_num_heads = self.num_heads * self.tp_size
-
-                self._replace_linear_class_for_sfa_cp()
 
         self.layer_sharding_kwargs = []
         for layer_name in (get_ascend_config().layer_sharding or []):
             if layer_name in kwargs:
@@ -925,42 +921,3 @@ class AscendSFAImpl(MLAAttentionImpl):
                                                      sparse_count=2048,
                                                      sparse_mode=3)
         return topk_indices
-
-    def _replace_linear_class_for_sfa_cp(self):
-
-        vllm_config = get_current_vllm_config()
-        # Dispose tensor from the original q_proj
-        dispose_layer(self.q_proj)
-        # Construct the new q_proj using ReplicatedLinear
-        new_q_proj = ReplicatedLinear(self.q_lora_rank,
-                                      self.local_num_heads * self.qk_head_dim,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.q_proj.prefix)
-        # Replace the q_proj with the new one
-        replace_layer(self.q_proj, new_q_proj)
-
-        # Dispose tensor from the original kv_b_proj
-        dispose_layer(self.kv_b_proj)
-        # Construct the new kv_b_proj using ReplicatedLinear
-        new_kv_b_proj = ReplicatedLinear(
-            self.kv_lora_rank,
-            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.kv_b_proj.prefix)
-        # Replace the kv_b_proj with the new one
-        replace_layer(self.kv_b_proj, new_kv_b_proj)
-
-        # Dispose tensor from the original o_proj
-        dispose_layer(self.o_proj)
-        # Construct the new o_proj using ReplicatedLinear
-        config = vllm_config.model_config.hf_text_config
-        new_o_proj = ReplicatedLinear(config.num_attention_heads *
-                                      config.v_head_dim,
-                                      config.hidden_size,
-                                      bias=False,
-                                      quant_config=vllm_config.quant_config,
-                                      prefix=self.o_proj.prefix)
-        # Replace the o_proj with the new one
-        replace_layer(self.o_proj, new_o_proj)
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 53130e67..8d8ecbe0 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -38,6 +38,7 @@ Row parallel op follows a similar approach - inherit from RowColumnParallelOp an
 
 import re
 from functools import lru_cache
+from types import SimpleNamespace
 from typing import Optional, Union
 
 import torch
@@ -59,7 +60,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                     get_flashcomm2_otp_group,
                                                     get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable, shared_expert_dp_enabled)
@@ -609,9 +610,60 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         self.unique_prefix = self.layer.unique_prefix
 
 
+class ShardedCPRowParallelOp(CustomRowParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias_)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.layer.reduce_results = False
+
+
+class ShardedCPColumnParallelOp(CustomColumnParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 def _get_column_parallel_op(
-    prefix, layer
-) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+    prefix, layer
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
+                    ShardedCPColumnParallelOp]]:
+    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
+        return ShardedCPColumnParallelOp(layer)
     if "gate_up_proj" in prefix and mlp_tp_enable(
     ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
-                    SequenceRowParallelOp]]:
+                    SequenceRowParallelOp, ShardedCPRowParallelOp]]:
+    if enable_dsa_cp() and "o_proj" in prefix:
+        return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
                              MLPRowParallelOp, OProjRowParallelOp,
                              Flashcomm2OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
-                             SequenceRowParallelOp]] = None
+                             SequenceRowParallelOp, ShardedCPRowParallelOp,
+                             ShardedCPColumnParallelOp]] = None
 
     if direct == "row":
         custom_op = _get_row_parallel_op(prefix, layer)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 09266716..9874661e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
             dispose_tensor(attr_value)
 
 
-def replace_layer(original_layer: Any, new_layer: Any):
-    original_layer.__class__ = new_layer.__class__
-    original_layer.__dict__ = new_layer.__dict__
-
-
 def check_kv_extra_config(vllm_config):
 
     def _check(name: str, config: dict):
@@ -1166,17 +1161,31 @@ def singleton(cls):
     return get_instance
 
 
-@lru_cache(maxsize=1)
-def get_current_model_config():
-    from vllm.config import get_current_vllm_config
-    vllm_config = get_current_vllm_config()
-    return vllm_config.model_config
-
-
 #TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces.  --zzhx1
 @lru_cache(maxsize=1)
 def enable_dsa_cp() -> bool:
     from vllm.config import get_current_vllm_config
+
     vllm_config = get_current_vllm_config()
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    return is_ds_v32 and enable_sp()
+    if vllm_config is None:
+        return False
+
+    model_config = getattr(vllm_config, "model_config", None)
+    if model_config is None:
+        return False
+
+    hf_text_config = getattr(model_config, "hf_text_config", None)
+    if hf_text_config is None:
+        return False
+
+    return hasattr(hf_text_config, "index_topk")
+
+
+@lru_cache(maxsize=1)
+def enable_dsa_cp_with_layer_shard() -> bool:
+    if not enable_dsa_cp():
+        return False
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
+    return is_prefill_instance
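
Note on the pattern this patch introduces: enable_dsa_cp() now probes the config defensively for index_topk instead of piggybacking on enable_sp(), and linear_op.py routes q_b_proj/kv_b_proj/o_proj to the new ShardedCP ops when the feature is on. The snippet below is a minimal, self-contained sketch of that detection-and-routing flow for illustration only; it does not import vllm or vllm_ascend, and enable_dsa_cp_sketch, pick_custom_op, and the SimpleNamespace configs are illustrative stand-ins, not part of this change.

# Standalone sketch (illustrative; assumes nothing beyond the Python stdlib).
from types import SimpleNamespace


def enable_dsa_cp_sketch(vllm_config) -> bool:
    # Mirror of the defensive getattr chain: any missing config level
    # disables the feature instead of raising.
    if vllm_config is None:
        return False
    model_config = getattr(vllm_config, "model_config", None)
    if model_config is None:
        return False
    hf_text_config = getattr(model_config, "hf_text_config", None)
    if hf_text_config is None:
        return False
    # Sparse-attention (DSA) models are detected via the index_topk field.
    return hasattr(hf_text_config, "index_topk")


def pick_custom_op(prefix: str, direct: str, dsa_cp_enabled: bool) -> str:
    # Toy version of the prefix-based routing added to linear_op.py:
    # under DSA CP, q_b_proj/kv_b_proj go to the sharded-CP column op and
    # o_proj goes to the sharded-CP row op; everything else keeps the
    # existing behaviour (represented here by "default").
    if direct == "column" and dsa_cp_enabled and (
            "q_b_proj" in prefix or "kv_b_proj" in prefix):
        return "ShardedCPColumnParallelOp"
    if direct == "row" and dsa_cp_enabled and "o_proj" in prefix:
        return "ShardedCPRowParallelOp"
    return "default"


if __name__ == "__main__":
    # A config whose hf_text_config carries index_topk enables the feature;
    # a plain config or a missing config does not.
    dsa_cfg = SimpleNamespace(model_config=SimpleNamespace(
        hf_text_config=SimpleNamespace(index_topk=2048)))
    plain_cfg = SimpleNamespace(model_config=SimpleNamespace(
        hf_text_config=SimpleNamespace()))

    assert enable_dsa_cp_sketch(dsa_cfg) is True
    assert enable_dsa_cp_sketch(plain_cfg) is False
    assert enable_dsa_cp_sketch(None) is False

    enabled = enable_dsa_cp_sketch(dsa_cfg)
    assert pick_custom_op("model.layers.0.self_attn.q_b_proj", "column",
                          enabled) == "ShardedCPColumnParallelOp"
    assert pick_custom_op("model.layers.0.self_attn.o_proj", "row",
                          enabled) == "ShardedCPRowParallelOp"
    assert pick_custom_op("model.layers.0.mlp.down_proj", "row",
                          enabled) == "default"

The fake comm_group returned by the ShardedCP ops (world_size=1, rank_in_group=0) plays the same role as the ReplicatedLinear swap this patch removes: downstream tensor-parallel logic sees an unsharded layer, so no weight slicing or all-reduce is applied, without having to rebuild the layer objects.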