From eac72f5f23ac271849bee934859bb8e1c1f16d39 Mon Sep 17 00:00:00 2001
From: zzhxxx <2783294813@qq.com>
Date: Thu, 11 Dec 2025 12:43:04 +0800
Subject: [PATCH] [Feat] Flashcomm2 use o_shared linear (#4188)

### What this PR does / why we need it?
As noted in the [flashcomm2 technical report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf), FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on device memory. The report proposes otp2 as a compromise, but that adds an extra reduce-scatter communication. We instead use the shared linear feature (#2931), which distributes the o_proj weights layer by layer across the cards, avoiding TP splitting and resolving the memory issue. (A conceptual sketch of this layer-by-layer storage/broadcast pattern is appended after the diff.)

This PR depends on #3232 and #2931.

### Flashcomm2 flowchart
PixPin_2025-11-14_13-37-39

### Does this PR introduce _any_ user-facing change?
Yes. Enable the feature with the following environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: zzhx1
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ
---
 tests/ut/attention/test_mla_v1.py           |  3 ++
 tests/ut/distributed/test_parallel_state.py |  1 +
 vllm_ascend/ascend_config.py                |  5 +--
 vllm_ascend/attention/mla_v1.py             | 34 +++++++++++++++--
 vllm_ascend/distributed/parallel_state.py   | 42 +++++++++++++++------
 vllm_ascend/envs.py                         |  5 +++
 vllm_ascend/ops/shared_weight_layer.py      |  2 +-
 vllm_ascend/utils.py                        | 19 +++++++---
 8 files changed, 86 insertions(+), 25 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index e2f74655..97caf4b1 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase

 from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
                                           AscendMLADecodeMetadata,
@@ -845,6 +846,8 @@ class TestAscendMLAImpl(TestBase):
         model_config.dtype = torch.float16
         vllm_config.model_config = model_config
         get_current_vllm_config.return_value = vllm_config
+        vllm_config.additional_config = {"refresh": True}
+        init_ascend_config(vllm_config)

         num_heads = 256
         head_size = 1024
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 4a910916..c69f4449 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -46,6 +46,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_vllm_config.kv_transfer_config.is_kv_producer = True
     mock_envs_ascend = MagicMock()
     mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
+    mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
     mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index dd29d328..e1eaad1e 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -165,9 +165,8 @@ class AscendConfig:
                 "Only support P node tp size lagger then D node tp size")
         self.SLO_limits_for_dynamic_batch = additional_config.get(
             "SLO_limits_for_dynamic_batch", -1)
-        from vllm_ascend.utils import \
-            get_flashcomm2_oproj_tp_size_and_validate_config
-        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+        from vllm_ascend.utils import get_flashcomm2_config_and_validate
+        self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
             self, vllm_config)
         self.enable_npugraph_ex = additional_config.get(
             "enable_npugraph_ex", False)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5e68b52e..e37b0d71 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -34,10 +34,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, weak_ref_tensors)
+                               flashcomm2_o_shared_enabled, is_enable_nz,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
@@ -848,6 +853,19 @@ class AscendMLAImpl(MLAAttentionImpl):
             'q_b_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
+        self.vllm_config = get_current_vllm_config()
+        self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
+
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -858,10 +876,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled

-        vllm_config = get_current_vllm_config()
         self.ring_mla_mask_size = 512

-        self.speculative_config = vllm_config.speculative_config
+        self.speculative_config = self.vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

         self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@ class AscendMLAImpl(MLAAttentionImpl):
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)

+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            post_process_after_loading_for_shared_weight_series(self.o_proj)
+
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 kv_no_split.contiguous(), need_gather_q_kv)

+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            reach_layer_for_shared_weight_series(self.o_proj)
+
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1633,6 +1658,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
+            if self.fc2_o_shared_enable and is_hidden_layer(
+                    self.vllm_config, self.o_proj):
+                reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index b6979e58..c9bff649 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import enable_sp, flashcomm2_enable
+from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+                               flashcomm2_o_shared_enabled)

 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -77,6 +78,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
     backend = torch.distributed.get_backend(get_world_group().device_group)
+    vllm_config = get_current_vllm_config()

     # The layout of all ranks: ExternalDP * EP
     # ExternalDP is the data parallel group that is not part of the model,
@@ -182,6 +184,29 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                           backend,
                                           group_name="lmheadtp")

+    def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
+        #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
+        group_ranks = []
+        for pp_idx in range(global_pp_size):
+            group = []
+            for dp_idx in range(global_dp_size):
+                base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
+                for i in range(global_tp_size):
+                    global_rank = base + i
+                    group.append(global_rank)
+            group_ranks.append(group)
+
+        return init_model_parallel_group(group_ranks,
+                                         get_world_group().local_rank,
+                                         backend,
+                                         group_name=group_name)
+
+    global _SHARED_WEIGHT
+    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    if enable_sp() and is_ds_v32:
+        _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
+
     # TODO: Extract and unify the logic across different communication group.
     if flashcomm2_enable():
         flashcomm2_otp_size = get_ascend_config(
@@ -234,17 +259,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                            backend,
                                            group_name="flashcomm2_odp")

-    vllm_config = get_current_vllm_config()
-    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    if enable_sp() and is_ds_v32:
-        global _SHARED_WEIGHT
-        group_ranks = [list(range(torch.distributed.get_world_size()))]
-        _SHARED_WEIGHT = init_model_parallel_group(
-            group_ranks,
-            get_world_group().local_rank,
-            backend,
-            group_name="CP_shared_weight")
+    # Create shared weight group for flashcomm2 oproj
+    if flashcomm2_o_shared_enabled():
+        assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
+        _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")


 def get_mlp_tensor_model_parallel_world_size():
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 72db1791..7e0480c9 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -97,6 +97,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
+    # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
+    # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
+    # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
+    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
diff --git a/vllm_ascend/ops/shared_weight_layer.py b/vllm_ascend/ops/shared_weight_layer.py
index 99e92439..48a5179f 100644
--- a/vllm_ascend/ops/shared_weight_layer.py
+++ b/vllm_ascend/ops/shared_weight_layer.py
@@ -249,4 +249,4 @@ def reach_layer_for_shared_weight_series(layer: LinearBase):
 def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
     num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
     layer_idx = extract_layer_index(layer.prefix)
-    return layer_idx < num_hidden_layers
\ No newline at end of file
+    return layer_idx < num_hidden_layers
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index d55a2af7..405fd685 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0


-def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
-                                                     vllm_config):
+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+
+def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()

     if not flashcomm2_enable():
-        logger.debug("FLASHCOMM2 not enable.")
-        return flashcomm2_oproj_tp_size
+        flashcomm2_oproj_shared = False
+        logger.info("FLASHCOMM2 not enable.")
+        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared

     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
     )
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
@@ -990,8 +995,10 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
             "FLASHCOMM2 primarily targets P-scenario deployments, "
             "with additional support for hybrid deployment scenarios. "
             "It is not applicable in D-scenario environments.")
+    if flashcomm2_oproj_shared:
+        logger.info("Enable FLASHCOMM2 with oproj_shared.")

-    return flashcomm2_oproj_tp_size
+    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared


 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
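
For readers unfamiliar with the o_shared idea, here is a minimal, self-contained sketch of the layer-by-layer storage/broadcast pattern referenced in the description: each rank keeps the full weight of only a subset of layers and broadcasts it to the other ranks just before that layer runs, instead of TP-splitting the weight. This is a conceptual illustration using plain `torch.distributed` (gloo backend, runnable with `torchrun --nproc_per_node=2`); the helper names, sizes, and ownership rule are made up for the example and are not the vllm-ascend implementation, which handles this via `register_layer_to_shared_weight_series` / `reach_layer_for_shared_weight_series` in `vllm_ascend/ops/shared_weight_layer.py` over the dedicated broadcast group created in `parallel_state.py`.

```python
# Conceptual sketch only: per-layer shared o_proj weights via just-in-time broadcast.
# Not the vllm-ascend implementation. Run with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist


def owner_of(layer_idx: int, world_size: int) -> int:
    # Hypothetical ownership rule: each layer's full weight lives on exactly one rank.
    return layer_idx % world_size


def main():
    dist.init_process_group(backend="gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    num_layers, hidden = 4, 8

    # Materialize the full (unsplit) weight only for owned layers;
    # the other ranks just keep a correctly shaped empty buffer.
    weights = {
        idx: torch.randn(hidden, hidden)
        if owner_of(idx, world_size) == rank else torch.empty(hidden, hidden)
        for idx in range(num_layers)
    }

    x = torch.ones(1, hidden)
    for idx in range(num_layers):
        # "Reach" the layer: the owning rank broadcasts the full weight just in time.
        dist.broadcast(weights[idx], src=owner_of(idx, world_size))
        x = x @ weights[idx].T  # full o_proj matmul, no TP split, no reduce-scatter
        if owner_of(idx, world_size) != rank:
            # Drop the borrowed copy so only one full copy persists per layer.
            weights[idx] = torch.empty(hidden, hidden)

    print(f"rank {rank}: output norm = {x.norm():.3f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

In the actual PR the broadcast group is created per pipeline stage and the broadcast is issued asynchronously `prefetch_step` layers ahead (the o_proj registration above uses `prefetch_step=1`), which is intended to let the weight transfer overlap with the preceding layer's compute.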