[Feat] Flashcomm2 use o_shared linear (#4188)
### What this PR does / why we need it?
It is mentioned in the [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
that FC2 will introduce full redundant storage of the o_proj matrix,
which will put pressure on the memory. Therefore, the technical report
proposed a compromise solution using otp2, but it will introduce
additional reduce-scatter communication.
We propose a shared linear feature (#2931 ) that supports distributing
weights layer by layer to each card, avoiding the need for TP splitting,
and can solve the memory issue.
This PR depends on #3232 and #2931
### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>
### Does this PR introduce _any_ user-facing change?
Use environment variables
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -34,10 +34,15 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
get_mtp_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.shared_weight_layer import (
|
||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
||||
reach_layer_for_shared_weight_series,
|
||||
register_layer_to_shared_weight_series)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz, weak_ref_tensors)
|
||||
flashcomm2_o_shared_enabled, is_enable_nz,
|
||||
weak_ref_tensors)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -848,6 +853,19 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
'q_b_proj']
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
from vllm_ascend.distributed.parallel_state import \
|
||||
get_shared_weight_group
|
||||
register_layer_to_shared_weight_series(
|
||||
series_name="o_proj",
|
||||
group=get_shared_weight_group(),
|
||||
layer=self.o_proj,
|
||||
prefetch_step=1)
|
||||
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
@@ -858,10 +876,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.speculative_config = self.vllm_config.speculative_config
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
@@ -995,6 +1012,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if self.enable_mlapo:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
post_process_after_loading_for_shared_weight_series(self.o_proj)
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., self.q_lora_rank:].contiguous()
|
||||
@@ -1515,6 +1536,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
@@ -1633,6 +1658,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
return output.fill_(0)
|
||||
if self.pcp_size > 1:
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
|
||||
Reference in New Issue
Block a user