################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from typing import TYPE_CHECKING, Optional

import torch
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform

if TYPE_CHECKING:
    from vllm.attention.layer import Attention


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Attach the allocated KV caches to the ModelRunner and the forward context.

    Two things happen, in this order:
      1) `runner_kv_caches` is filled in place with one cache per distinct
         layer index, in ascending layer-index order.
      2) Every layer named in `kv_caches` gets its cache stored on the
         matching `forward_context` entry as a one-element list.

    Args:
        kv_caches: Allocated KV caches keyed by layer name.
        forward_context: The attention layers keyed by layer name; each entry
            named in `kv_caches` has its `kv_cache` attribute overwritten.
        runner_kv_caches: The ModelRunner's cache list. Must be empty on
            entry; it is appended to in place.
        num_attn_module: Passed straight through to `extract_layer_index`
            as its second argument.

    Raises:
        NotImplementedError: If several layer names map to the same layer
            index and the current platform is none of CUDA, XPU or SUPA.
    """
    # The runner-side list is built from scratch here, so it has to start
    # out empty.
    assert len(runner_kv_caches) == 0

    # Group the layer names by the layer index derived from each name.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        idx = extract_layer_index(name, num_attn_module)
        names_by_index[idx].append(name)

    # Walk the groups in ascending layer-index order and hand the runner the
    # cache of the first name in each group.
    for idx in sorted(names_by_index):
        grouped_names = names_by_index[idx]
        # Several names sharing one index typically means an encoder-decoder
        # model (e.g. bart), where the cross attention and self attention of
        # one decoder layer have different names but the same index.
        # TODO - analyze where runner_kv_caches is used and the right
        # way to ensure it properly reflects multiple attention layers
        # in the same decoder block.
        # The GPU runner is known not to be affected by this case; some test
        # code relies on runner_kv_caches, but not in a way that breaks when
        # the extra layers are ignored. Any other platform is rejected.
        if len(grouped_names) > 1 and not (
            current_platform.is_cuda()
            or current_platform.is_xpu()
            or current_platform.is_supa()
        ):
            raise NotImplementedError
        runner_kv_caches.append(kv_caches[grouped_names[0]])

    # Finally expose each cache to its attention layer.
    for name, cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[name].kv_cache = [cache]