[Feat] flashcomm2+oshard Generalized (#4723)
### What this PR does / why we need it?
[FlashComm2](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf)
introduces redundant storage of the o_proj matrix, which puts pressure
on GPU memory. We propose the FlashComm2+Oshard approach, which
integrates the shared linear layer feature (#2931): the o_proj weights
are distributed layer by layer across GPUs, and each layer's o_proj is
fetched through asynchronous broadcast operations, alleviating memory
pressure while staying nearly lossless in performance compared to the
original FlashComm2. This PR implements a generalized FlashComm2+Oshard
solution.
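As a minimal sketch of the idea, assuming a round-robin layer-to-rank
assignment (the helper names below, `owner_rank` and `prefetch_o_proj`,
are illustrative and are not the PR's actual `flashcomm2_oshard_manager`
API):
```python
import torch
import torch.distributed as dist


def owner_rank(layer_idx: int, world_size: int) -> int:
    # Layer-by-layer round-robin distribution of o_proj across ranks.
    return layer_idx % world_size


def prefetch_o_proj(layer_idx: int, buf: torch.Tensor):
    # The owning rank already holds the real weight in `buf`; every other
    # rank receives into it. async_op=True lets the transfer overlap with
    # earlier layers' compute; call .wait() just before applying o_proj.
    src = owner_rank(layer_idx, dist.get_world_size())
    return dist.broadcast(buf, src=src, async_op=True)
```
Because each broadcast is issued ahead of its use, the communication
hides behind computation, which is why the approach can stay nearly
lossless relative to the original FlashComm2.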
Use the following environment variable, together with the
`--additional-config` server option, to enable FlashComm2 with Oshard:
```shell
# enable FlashComm2
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
# shard o_proj layer by layer (Oshard), passed on the server command line
--additional-config '{
  "layer_sharding": ["o_proj"]
}'
```
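For example, combined into a single serve command (the model name and
tensor-parallel size below are placeholders, not values from this PR):
```shell
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve Qwen/Qwen2.5-7B-Instruct \
    --tensor-parallel-size 4 \
    --additional-config '{"layer_sharding": ["o_proj"]}'
```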
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Diff excerpt:
```diff
@@ -43,6 +43,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.compilation.acl_graph import (
     get_draft_graph_params, get_graph_params,
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
+from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                weak_ref_tensors)
 
```
```diff
@@ -349,6 +350,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.value_cache = None
         self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        super().process_weights_after_loading(act_dtype)
+        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
+            flashcomm2_oshard_manager.post_process_after_loading()
+
     def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
                        value: torch.Tensor, attn_metadata: AscendMetadata,
                        output: torch.Tensor) -> torch.Tensor:
```
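The new `process_weights_after_loading` hook hands control to the oshard
manager once all weights are materialized. As a rough sketch of what such
a post-loading step could amount to (the function body and data layout
below are assumptions for illustration, not the manager's actual code):
```python
import torch.distributed as dist


def post_process_after_loading_sketch(o_proj_weights: list) -> None:
    """Hypothetical sketch: every rank initially loads all layers' o_proj;
    afterwards each rank keeps only the layers it owns (round-robin) and
    frees the rest, leaving one replica per layer cluster-wide."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for idx in range(len(o_proj_weights)):
        if idx % world_size != rank:
            # Freed here; refetched on demand via the asynchronous
            # broadcast shown earlier when this layer executes.
            o_proj_weights[idx] = None
```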