[Feature]cpu offload connector (#1659)

This PR implements cpu offload connector to enable NPU kv cache offload
to host DRAM.

- vLLM version: v0.10.2
- vLLM main:
5aeb925452

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Co-authored-by: AlvisGong <gwly0401@163.com>
This commit is contained in:
lidenghui1110
2025-09-23 14:25:05 +08:00
committed by GitHub
parent 96eb1ed408
commit 0f3939e5a9
10 changed files with 990 additions and 44 deletions

View File

@@ -18,7 +18,9 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -858,8 +860,8 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv):
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv):
# MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -888,6 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
# Preprocess for decode tokens
if has_decode:
decode_q_c = q_c[:num_decode_tokens]
@@ -934,6 +938,7 @@ class AscendMLAImpl(MLAAttentionImpl):
def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
@@ -960,7 +965,8 @@ class AscendMLAImpl(MLAAttentionImpl):
# MLA Preprocess
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
layer_name, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv)
if decode_preprocess_res is not None:
# MLA Preprocess for decoding
@@ -1018,4 +1024,8 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded