[Feat] Flashcomm2: use o_shared linear (#4188)
### What this PR does / why we need it?
The [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
mentions that FC2 introduces fully redundant storage of the o_proj matrix,
which puts pressure on memory. The report therefore proposes a compromise,
otp2, but that in turn introduces additional reduce-scatter communication.
We propose a shared linear feature (#2931) that distributes weights layer
by layer across cards, avoiding the need for TP splitting and thereby
resolving the memory issue.
This PR depends on #3232 and #2931.
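To make the trade-off concrete: with full redundancy every card stores every layer's o_proj; with the shared linear approach each card permanently stores only the layers it owns and receives the rest via asynchronous broadcast just before use. A minimal sketch of the ownership math (illustrative numbers and names only, not the actual #2931 API):

```python
# Round-robin ownership: each rank keeps ~num_layers / world_size full
# o_proj matrices resident instead of all of them (illustrative sketch).
def owner_rank(layer_idx: int, world_size: int) -> int:
    return layer_idx % world_size

num_layers, world_size = 61, 16
owned_by_rank0 = [l for l in range(num_layers) if owner_rank(l, world_size) == 0]
print(len(owned_by_rank0))  # 4 resident layers instead of 61
```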
### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>
### Does this PR introduce _any_ user-facing change?
Yes. Enable it with the following environment variables:
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
```diff
@@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
                                           AscendMLADecodeMetadata,
@@ -845,6 +846,8 @@ class TestAscendMLAImpl(TestBase):
         model_config.dtype = torch.float16
         vllm_config.model_config = model_config
         get_current_vllm_config.return_value = vllm_config
+        vllm_config.additional_config = {"refresh": True}
+        init_ascend_config(vllm_config)
 
         num_heads = 256
         head_size = 1024
```
```diff
@@ -46,6 +46,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_vllm_config.kv_transfer_config.is_kv_producer = True
     mock_envs_ascend = MagicMock()
     mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
+    mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
          patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
```
```diff
@@ -165,9 +165,8 @@ class AscendConfig:
                 "Only support P node tp size lagger then D node tp size")
         self.SLO_limits_for_dynamic_batch = additional_config.get(
             "SLO_limits_for_dynamic_batch", -1)
-        from vllm_ascend.utils import \
-            get_flashcomm2_oproj_tp_size_and_validate_config
-        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+        from vllm_ascend.utils import get_flashcomm2_config_and_validate
+        self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
             self, vllm_config)
         self.enable_npugraph_ex = additional_config.get(
             "enable_npugraph_ex", False)
```
```diff
@@ -34,10 +34,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, weak_ref_tensors)
+                               flashcomm2_o_shared_enabled, is_enable_nz,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
```
```diff
@@ -848,6 +853,19 @@ class AscendMLAImpl(MLAAttentionImpl):
             'q_b_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
+        self.vllm_config = get_current_vllm_config()
+        self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
+
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
```
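The registration above is the first of three touch points this PR adds to `AscendMLAImpl`: register during `__init__`, redistribute after weight loading (`post_process_after_loading_for_shared_weight_series`, below), and trigger the asynchronous prefetch each forward pass (`reach_layer_for_shared_weight_series`). A self-contained toy of that lifecycle (stub class only; the real implementation lives in `vllm_ascend.ops.shared_weight_layer`, and the prefetch semantics here are an assumption based on `prefetch_step=1`):

```python
# Toy stand-in for the shared-weight series (sketch, not the real API).
class ToySharedWeightSeries:
    def __init__(self, prefetch_step: int):
        self.layers: list[str] = []
        self.prefetch_step = prefetch_step

    def register(self, layer: str) -> None:  # at module __init__
        self.layers.append(layer)

    def post_process(self) -> None:  # after weight loading
        print(f"scatter {len(self.layers)} o_proj weights round-robin across ranks")

    def reach(self, layer: str) -> None:  # once per forward pass
        i = self.layers.index(layer)
        target = min(i + self.prefetch_step, len(self.layers) - 1)
        print(f"async-broadcast weights for layer {target} ahead of use")

series = ToySharedWeightSeries(prefetch_step=1)
for name in ("o_proj.0", "o_proj.1", "o_proj.2"):
    series.register(name)
series.post_process()
series.reach("o_proj.0")  # prefetches o_proj.1 while layer 0 computes
```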
```diff
@@ -858,10 +876,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
 
-        vllm_config = get_current_vllm_config()
         self.ring_mla_mask_size = 512
 
-        self.speculative_config = vllm_config.speculative_config
+        self.speculative_config = self.vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
         self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@ class AscendMLAImpl(MLAAttentionImpl):
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
 
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            post_process_after_loading_for_shared_weight_series(self.o_proj)
+
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@ class AscendMLAImpl(MLAAttentionImpl):
            kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                kv_no_split.contiguous(), need_gather_q_kv)
 
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            reach_layer_for_shared_weight_series(self.o_proj)
+
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1633,6 +1658,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
+            if self.fc2_o_shared_enable and is_hidden_layer(
+                    self.vllm_config, self.o_proj):
+                reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
```
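(Note that the profiling early-return above still calls `reach_layer_for_shared_weight_series`: the broadcast runs as a collective over the shared-weight group, so every rank presumably has to advance the series in lockstep even on dummy runs; a rank that skipped the call would leave its peers blocked in the collective.)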
```diff
@@ -9,7 +9,8 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import enable_sp, flashcomm2_enable
+from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+                               flashcomm2_o_shared_enabled)
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -77,6 +78,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
     backend = torch.distributed.get_backend(get_world_group().device_group)
+    vllm_config = get_current_vllm_config()
 
     # The layout of all ranks: ExternalDP * EP
     # ExternalDP is the data parallel group that is not part of the model,
```
```diff
@@ -182,6 +184,29 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                              backend,
                                              group_name="lmheadtp")
 
+    def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
+        #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
+        group_ranks = []
+        for pp_idx in range(global_pp_size):
+            group = []
+            for dp_idx in range(global_dp_size):
+                base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
+                for i in range(global_tp_size):
+                    global_rank = base + i
+                    group.append(global_rank)
+            group_ranks.append(group)
+
+        return init_model_parallel_group(group_ranks,
+                                         get_world_group().local_rank,
+                                         backend,
+                                         group_name=group_name)
+
+    global _SHARED_WEIGHT
+    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    if enable_sp() and is_ds_v32:
+        _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
+
     # TODO: Extract and unify the logic across different communication group.
     if flashcomm2_enable():
         flashcomm2_otp_size = get_ascend_config(
```
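For a concrete picture of what `_create_shared_weight_group` builds, here is the same loop run standalone with illustrative sizes: each pipeline stage gets one group spanning every DP and TP rank, which is what lets a layer's owner broadcast whole o_proj weights to all of its peers.

```python
global_dp_size, global_pp_size, global_tp_size = 2, 1, 4  # illustrative sizes

group_ranks = []
for pp_idx in range(global_pp_size):
    group = []
    for dp_idx in range(global_dp_size):
        base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
        group.extend(base + i for i in range(global_tp_size))
    group_ranks.append(group)

print(group_ranks)  # [[0, 1, 2, 3, 4, 5, 6, 7]] -- one group per PP stage
```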
```diff
@@ -234,17 +259,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                              backend,
                                              group_name="flashcomm2_odp")
 
-    vllm_config = get_current_vllm_config()
-    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    if enable_sp() and is_ds_v32:
-        global _SHARED_WEIGHT
-        group_ranks = [list(range(torch.distributed.get_world_size()))]
-        _SHARED_WEIGHT = init_model_parallel_group(
-            group_ranks,
-            get_world_group().local_rank,
-            backend,
-            group_name="CP_shared_weight")
+    # Create shared weight group for flashcomm2 oproj
+    if flashcomm2_o_shared_enabled():
+        assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
+        _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
 
 
 def get_mlp_tensor_model_parallel_world_size():
```
```diff
@@ -97,6 +97,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
+    # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
+    # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
+    # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
+    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
```
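The parsing rules mirror the two lambdas above: the parallel size is an integer (0 disables Flashcomm2) and the o_shared switch is a 0/1 flag. A standalone sanity check using plain `os.getenv` (a sketch; the real lookups go through `vllm_ascend.envs`):

```python
import os

fc2_size = int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0))
o_shared = bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0")))

# The o_shared flag is documented as bound to Flashcomm2, so flag the
# combination early rather than silently ignoring it (illustrative check).
if o_shared and fc2_size == 0:
    raise ValueError("OSHARED requires VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0")
```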
```diff
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
-def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
-                                                     vllm_config):
+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+
+def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
 
     if not flashcomm2_enable():
-        logger.debug("FLASHCOMM2 not enable.")
-        return flashcomm2_oproj_tp_size
+        flashcomm2_oproj_shared = False
+        logger.info("FLASHCOMM2 not enable.")
+        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
     )
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
```
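Callers such as the `AscendConfig` hunk above now unpack a pair instead of a single size. A usage sketch of the new contract (assumes `ascend_config` and `vllm_config` come from an already-initialized engine; the expected values are inferred from the branches above):

```python
from vllm_ascend.utils import get_flashcomm2_config_and_validate

# Flashcomm2 disabled (PARALLEL_SIZE unset/0) -> (0, False): sharing forced off.
# PARALLEL_SIZE=1 with OSHARED=1              -> (1, True):  sharing kept.
oproj_tp_size, oproj_shared = get_flashcomm2_config_and_validate(
    ascend_config, vllm_config)
```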
```diff
@@ -990,8 +995,10 @@ get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
             "FLASHCOMM2 primarily targets P-scenario deployments, "
             "with additional support for hybrid deployment scenarios. "
             "It is not applicable in D-scenario environments.")
+    if flashcomm2_oproj_shared:
+        logger.info("Enable FLASHCOMM2 with oproj_shared.")
 
-    return flashcomm2_oproj_tp_size
+    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
 
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
```