Added support for KV connector v1 (#2039)

### What this PR does / why we need it?
- This PR adds the support for the KV connector interface in the V1
architecture, in the same way as vllm. Vllm-ascend currently lacks of
this support, required to support also layerwise management of KV
caches.

- The connector interface allows using external tools and integrate them
with vllm

### Notes:
We are aware of Issue #684 , however that issue does not modify the
attention classes as necessary to perform layerwise management of KV
caches required for connectors like LMCache.

The implementation of this PR ported the necessary code from the vanilla
vllm. The KV connector API is the same as vanilla vllm, supporting the
standard KV connector API.

EDIT: this PR was re-implementing part of the changes merged one hour
before this PR was made on the file model_runner_v1.py. I solved the
conflicts by removing any modification to the model_runner_v1 file,
which now are largely already merged in main. Now this PR is left for
the modifications to the attention_v1 file.

### Does this PR introduce _any_ user-facing change?
The PR does not modify current APIs, but it extends the behavior of
current worker runner and attention classes to save and load KV caches.
In absence of connectors, the behavior should stay untouched.

### How was this patch tested?
- No unit test implemented yet for the worker.

- Tested together with LMCache using
https://github.com/LMCache/LMCache/blob/dev/examples/kv_cache_reuse/local_backends/offload.py
with the following models:
1 Deepseek-R1-Distill-Qwen-1.5B
2 Qwen3-30B-A3B
3 Deepseek-v2-lite
4 Llama-3.1-8B
LMCache used in both layerwise and non-layerwise mode.

- Performed LMEval on LMCache integrated with vllm-ascend.

Results without LMCache on Qwen3-8B:
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|

|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8400|± |0.0101|
| | |strict-match | 5|exact_match|↑ |0.8355|± |0.0102|
 
Results with LMCache Layerwise:
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|

|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8385|± |0.0101|
| | |strict-match | 5|exact_match|↑ |0.8332|± |0.0103|


- vLLM version: v0.10.1.1
- vLLM main:
50fede6634

---------

Signed-off-by: marcobarlo <barlettamarco8@gmail.com>
Signed-off-by: marcobarlo <65128997+marcobarlo@users.noreply.github.com>
This commit is contained in:
Marco Barletta
2025-09-08 03:04:22 +02:00
committed by GitHub
parent 2967e5e22a
commit 6666e5265d

View File

@@ -26,6 +26,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
@@ -37,6 +40,37 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
from vllm_ascend.worker.npu_input_batch import InputBatch
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# TODO: assert ascendMetadata
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -537,6 +571,7 @@ def unified_ascend_attention_with_output(
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
@@ -549,6 +584,7 @@ def unified_ascend_attention_with_output(
attn_metadata,
output,
trace_flag=False)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return