[Feat] Flashcomm2 use o_shared linear (#4188)
### What this PR does / why we need it?
It is mentioned in the [flashcomm2 technical
report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
that FC2 will introduce full redundant storage of the o_proj matrix,
which will put pressure on the memory. Therefore, the technical report
proposed a compromise solution using otp2, but it will introduce
additional reduce-scatter communication.
We propose a shared linear feature (#2931 ) that supports distributing
weights layer by layer to each card, avoiding the need for TP splitting,
and can solve the memory issue.
This PR depends on #3232 and #2931
### Flashcomm2 flowchart
<img width="1142" height="878" alt="PixPin_2025-11-14_13-37-39"
src="https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0"
/>
### Does this PR introduce _any_ user-facing change?
Use environment variables
```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <2783294813@qq.com>
Co-authored-by: zzh02232027 <zzh02232027@antgroup.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
|
||||
AscendMLADecodeMetadata,
|
||||
@@ -845,6 +846,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
model_config.dtype = torch.float16
|
||||
vllm_config.model_config = model_config
|
||||
get_current_vllm_config.return_value = vllm_config
|
||||
vllm_config.additional_config = {"refresh": True}
|
||||
init_ascend_config(vllm_config)
|
||||
|
||||
num_heads = 256
|
||||
head_size = 1024
|
||||
|
||||
@@ -46,6 +46,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
||||
mock_vllm_config.kv_transfer_config.is_kv_producer = True
|
||||
mock_envs_ascend = MagicMock()
|
||||
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
|
||||
mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
|
||||
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
|
||||
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
||||
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
||||
|
||||
Reference in New Issue
Block a user