From eac72f5f23ac271849bee934859bb8e1c1f16d39 Mon Sep 17 00:00:00 2001
From: zzhxxx <2783294813@qq.com>
Date: Thu, 11 Dec 2025 12:43:04 +0800
Subject: [PATCH] [Feat] Flashcomm2 use o_shared linear (#4188)
### What this PR does / why we need it?
It is mentioned in the [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
that FC2 will introduce full redundant storage of the o_proj matrix,
which will put pressure on the memory. Therefore, the technical report
proposed a compromise solution using otp2, but it will introduce
additional reduce-scatter communication.
We propose a shared linear feature (#2931 ) that supports distributing
weights layer by layer to each card, avoiding the need for TP splitting,
and can solve the memory issue.
This PR depends on #3232 and #2931
### Flashcomm2 flowchart
### Does this PR introduce _any_ user-facing change?
Use environment variables
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
- vLLM version: v0.12.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
---------
Signed-off-by: zzhx1
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ
---
tests/ut/attention/test_mla_v1.py | 3 ++
tests/ut/distributed/test_parallel_state.py | 1 +
vllm_ascend/ascend_config.py | 5 +--
vllm_ascend/attention/mla_v1.py | 34 +++++++++++++++--
vllm_ascend/distributed/parallel_state.py | 42 +++++++++++++++------
vllm_ascend/envs.py | 5 +++
vllm_ascend/ops/shared_weight_layer.py | 2 +-
vllm_ascend/utils.py | 19 +++++++---
8 files changed, 86 insertions(+), 25 deletions(-)
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index e2f74655..97caf4b1 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLADecodeMetadata,
@@ -845,6 +846,8 @@ class TestAscendMLAImpl(TestBase):
model_config.dtype = torch.float16
vllm_config.model_config = model_config
get_current_vllm_config.return_value = vllm_config
+ vllm_config.additional_config = {"refresh": True}
+ init_ascend_config(vllm_config)
num_heads = 256
head_size = 1024
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 4a910916..c69f4449 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -46,6 +46,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_vllm_config.kv_transfer_config.is_kv_producer = True
mock_envs_ascend = MagicMock()
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
+ mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index dd29d328..e1eaad1e 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -165,9 +165,8 @@ class AscendConfig:
"Only support P node tp size lagger then D node tp size")
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
- from vllm_ascend.utils import \
- get_flashcomm2_oproj_tp_size_and_validate_config
- self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+ from vllm_ascend.utils import get_flashcomm2_config_and_validate
+ self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
self, vllm_config)
self.enable_npugraph_ex = additional_config.get(
"enable_npugraph_ex", False)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5e68b52e..e37b0d71 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -34,10 +34,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+ is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+ reach_layer_for_shared_weight_series,
+ register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
- is_enable_nz, weak_ref_tensors)
+ flashcomm2_o_shared_enabled, is_enable_nz,
+ weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@@ -848,6 +853,19 @@ class AscendMLAImpl(MLAAttentionImpl):
'q_b_proj']
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
+ self.vllm_config = get_current_vllm_config()
+ self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+
+ if self.fc2_o_shared_enable and is_hidden_layer(
+ self.vllm_config, self.o_proj):
+ from vllm_ascend.distributed.parallel_state import \
+ get_shared_weight_group
+ register_layer_to_shared_weight_series(
+ series_name="o_proj",
+ group=get_shared_weight_group(),
+ layer=self.o_proj,
+ prefetch_step=1)
+
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -858,10 +876,9 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
- vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
- self.speculative_config = vllm_config.speculative_config
+ self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@ class AscendMLAImpl(MLAAttentionImpl):
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
+ if self.fc2_o_shared_enable and is_hidden_layer(
+ self.vllm_config, self.o_proj):
+ post_process_after_loading_for_shared_weight_series(self.o_proj)
+
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
+ if self.fc2_o_shared_enable and is_hidden_layer(
+ self.vllm_config, self.o_proj):
+ reach_layer_for_shared_weight_series(self.o_proj)
+
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@@ -1633,6 +1658,9 @@ class AscendMLAImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
+ if self.fc2_o_shared_enable and is_hidden_layer(
+ self.vllm_config, self.o_proj):
+ reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0)
if self.pcp_size > 1:
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index b6979e58..c9bff649 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import enable_sp, flashcomm2_enable
+from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+ flashcomm2_o_shared_enabled)
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -77,6 +78,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
+ vllm_config = get_current_vllm_config()
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
@@ -182,6 +184,29 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="lmheadtp")
+ def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
+ #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
+ group_ranks = []
+ for pp_idx in range(global_pp_size):
+ group = []
+ for dp_idx in range(global_dp_size):
+ base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
+ for i in range(global_tp_size):
+ global_rank = base + i
+ group.append(global_rank)
+ group_ranks.append(group)
+
+ return init_model_parallel_group(group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name=group_name)
+
+ global _SHARED_WEIGHT
+ # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
+ is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+ if enable_sp() and is_ds_v32:
+ _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
+
# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
@@ -234,17 +259,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="flashcomm2_odp")
- vllm_config = get_current_vllm_config()
- # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
- is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
- if enable_sp() and is_ds_v32:
- global _SHARED_WEIGHT
- group_ranks = [list(range(torch.distributed.get_world_size()))]
- _SHARED_WEIGHT = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- group_name="CP_shared_weight")
+ # Create shared weight group for flashcomm2 oproj
+ if flashcomm2_o_shared_enabled():
+ assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
+ _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
def get_mlp_tensor_model_parallel_world_size():
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 72db1791..7e0480c9 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -97,6 +97,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
+ # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
+ # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
+ # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
+ "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
+ lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
diff --git a/vllm_ascend/ops/shared_weight_layer.py b/vllm_ascend/ops/shared_weight_layer.py
index 99e92439..48a5179f 100644
--- a/vllm_ascend/ops/shared_weight_layer.py
+++ b/vllm_ascend/ops/shared_weight_layer.py
@@ -249,4 +249,4 @@ def reach_layer_for_shared_weight_series(layer: LinearBase):
def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
layer_idx = extract_layer_index(layer.prefix)
- return layer_idx < num_hidden_layers
\ No newline at end of file
+ return layer_idx < num_hidden_layers
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index d55a2af7..405fd685 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
-def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
- vllm_config):
+def flashcomm2_o_shared_enabled() -> bool:
+ return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+
+def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+ flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
if not flashcomm2_enable():
- logger.debug("FLASHCOMM2 not enable.")
- return flashcomm2_oproj_tp_size
+ flashcomm2_oproj_shared = False
+ logger.info("FLASHCOMM2 not enable.")
+ return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
logger.info(
- f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+ f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
)
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once(
@@ -990,8 +995,10 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
"FLASHCOMM2 primarily targets P-scenario deployments, "
"with additional support for hybrid deployment scenarios. "
"It is not applicable in D-scenario environments.")
+ if flashcomm2_oproj_shared:
+ logger.info("Enable FLASHCOMM2 with oproj_shared.")
- return flashcomm2_oproj_tp_size
+ return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]: