[Feature]cpu offload connector (#1659)
This PR implements cpu offload connector to enable NPU kv cache offload
to host DRAM.
- vLLM version: v0.10.2
- vLLM main:
5aeb925452
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Co-authored-by: AlvisGong <gwly0401@163.com>
This commit is contained in:
@@ -18,7 +18,9 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
@@ -858,8 +860,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv):
|
||||
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||
attn_metadata, need_gather_q_kv):
|
||||
# MLA Preprocess:
|
||||
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||
@@ -888,6 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
@@ -934,6 +938,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_name,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: M,
|
||||
@@ -960,7 +965,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
# MLA Preprocess
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
||||
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# MLA Preprocess for decoding
|
||||
@@ -1018,4 +1024,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
del o_proj_input
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
if has_prefill:
|
||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||
return output_padded
|
||||
|
||||
Reference in New Issue
Block a user