[Feature]cpu offload connector (#1659)

This PR implements cpu offload connector to enable NPU kv cache offload
to host DRAM.

- vLLM version: v0.10.2
- vLLM main:
5aeb925452

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Co-authored-by: AlvisGong <gwly0401@163.com>
This commit is contained in:
lidenghui1110
2025-09-23 14:25:05 +08:00
committed by GitHub
parent 96eb1ed408
commit 0f3939e5a9
10 changed files with 990 additions and 44 deletions

View File

@@ -26,53 +26,21 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True

View File

@@ -18,7 +18,9 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -858,8 +860,8 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv):
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv):
# MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -888,6 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
# Preprocess for decode tokens
if has_decode:
decode_q_c = q_c[:num_decode_tokens]
@@ -934,6 +938,7 @@ class AscendMLAImpl(MLAAttentionImpl):
def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
@@ -960,7 +965,8 @@ class AscendMLAImpl(MLAAttentionImpl):
# MLA Preprocess
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
layer_name, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv)
if decode_preprocess_res is not None:
# MLA Preprocess for decoding
@@ -1018,4 +1024,8 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded

View File

@@ -1,7 +1,11 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, List
import torch
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
@dataclass
@@ -100,3 +104,34 @@ def split_decodes_and_prefills(
num_decode_tokens = query_start_loc[first_prefill].item()
num_prefill_tokens = num_tokens - num_decode_tokens
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)