From 6666e5265d40ecafc3cb377233fee840d7fe553b Mon Sep 17 00:00:00 2001
From: Marco Barletta <65128997+marcobarlo@users.noreply.github.com>
Date: Mon, 8 Sep 2025 03:04:22 +0200
Subject: [PATCH] Added support for KV connector v1 (#2039)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
- This PR adds support for the KV connector interface in the V1 architecture, in the same way as vLLM. vllm-ascend currently lacks this support, which is also required for layerwise management of KV caches.
- The connector interface allows external tools to be used and integrated with vLLM (see the sketch after the notes below).

### Notes:
We are aware of issue #684; however, that issue does not modify the attention classes as needed for the layerwise management of KV caches required by connectors such as LMCache. This PR ports the necessary code from vanilla vLLM, and the KV connector API is the same as in vanilla vLLM, supporting the standard KV connector API.

EDIT: this PR originally re-implemented part of the changes to model_runner_v1.py that were merged one hour before this PR was opened. I resolved the conflicts by removing all modifications to the model_runner_v1 file, which are now largely merged in main. This PR is now limited to the modifications to the attention_v1 file.
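For context, this patch does not change how a connector is configured: with it in place, a v1 connector is enabled on vllm-ascend exactly as in vanilla vLLM. A minimal sketch follows; the connector name, role, and model are assumptions taken from the upstream LMCache v1 examples, not something this PR prescribes:

```python
from vllm import LLM
from vllm.config import KVTransferConfig

# Assumed values from the LMCache v1 examples; any registered v1
# connector can be plugged in the same way.
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",  # illustrative model
    kv_transfer_config=ktc,
)

print(llm.generate("The capital of France is")[0].outputs[0].text)
```

With such a connector configured, the hooks added in attention_v1.py call into it once per layer; without one, `has_kv_transfer_group()` is false and both hooks return immediately.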
### Does this PR introduce _any_ user-facing change?
The PR does not modify current APIs, but it extends the behavior of the current worker runner and attention classes to save and load KV caches. In the absence of connectors, the behavior stays untouched.

### How was this patch tested?
- No unit tests implemented yet for the worker.
- Tested together with LMCache using https://github.com/LMCache/LMCache/blob/dev/examples/kv_cache_reuse/local_backends/offload.py with the following models:
  1. Deepseek-R1-Distill-Qwen-1.5B
  2. Qwen3-30B-A3B
  3. Deepseek-v2-lite
  4. Llama-3.1-8B

  LMCache was used in both layerwise and non-layerwise mode.
- Performed LMEval on LMCache integrated with vllm-ascend.

Results without LMCache on Qwen3-8B:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8400|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

Results with LMCache Layerwise:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8385|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.8332|±  |0.0103|

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/50fede6634a997f4e971ecb4eb4cce337340e394

---------

Signed-off-by: marcobarlo
Signed-off-by: marcobarlo <65128997+marcobarlo@users.noreply.github.com>
---
 vllm_ascend/attention/attention_v1.py | 36 +++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 0915cc3..adc6c01 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -26,6 +26,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -37,6 +40,37 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -537,6 +571,7 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -549,6 +584,7 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return
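
For reviewers unfamiliar with the v1 connector API: the two hooks added above call into `wait_for_layer_load` and `save_kv_layer` on the connector returned by `get_kv_transfer_group()`. As a minimal sketch of the connector-side counterpart (assuming vLLM's upstream `KVConnectorBase_V1` interface; the class name and comments are illustrative, not part of this patch):

```python
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1


class LayerwiseConnectorSketch(KVConnectorBase_V1):
    """Sketch showing only the two methods exercised by the new hooks.

    A real connector must also implement the remaining abstract methods
    of KVConnectorBase_V1 (start_load_kv, wait_for_save, and the
    scheduler-side methods).
    """

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Called by wait_for_kv_layer_from_connector() right before this
        # layer's attention kernel runs: block until the layer's KV cache
        # has been loaded into the paged KV buffer.
        ...

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: AttentionMetadata, **kwargs) -> None:
        # Called by maybe_save_kv_layer_to_connector() right after the
        # attention kernel writes this layer's KV cache; on this backend
        # the hook passes the layer's [key_cache, value_cache] tensors.
        # Saving may be asynchronous; wait_for_save() acts as the barrier.
        ...
```

This per-layer pairing is what enables layerwise connectors such as LMCache: loads and saves overlap with the forward pass of other layers instead of happening in one block at the start and end of the step.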