87 lines
3.5 KiB
Python
87 lines
3.5 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.models.utils import extract_layer_index
|
|
from vllm.platforms import current_platform
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.layer import Attention
|
|
|
|
|
|
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by ascending layer index (one entry per index).
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner. Must be
            empty on entry; it is filled in place.
        num_attn_module: Forwarded unchanged to `extract_layer_index` when
            computing the index of each layer name.

    Raises:
        AssertionError: If `runner_kv_caches` is not empty on entry.
        NotImplementedError: If several layer names map to the same layer
            index and the current platform is not CUDA, XPU or SUPA.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0, (
        "runner_kv_caches must be empty before binding KV caches")

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2names = defaultdict(list)
    for layer_name in kv_caches:
        layer_index = extract_layer_index(layer_name, num_attn_module)
        index2names[layer_index].append(layer_name)

    for layer_index in sorted(index2names):
        layer_names = index2names[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # have different layer names but the same layer_index.

            # TODO - analyze where runner_kv_caches is used and the right
            # way to ensure it properly reflects multiple attention layers
            # in the same decoder block.

            # We know that the GPU runner is not impacted by this
            # case. Some test code depends on runner_kv_caches, but
            # not in a way that's impacted by ignoring this.
            shared_index_tolerated = (current_platform.is_cuda()
                                      or current_platform.is_xpu()
                                      or current_platform.is_supa())
            if not shared_index_tolerated:
                raise NotImplementedError(
                    f"Multiple attention layers {layer_names} share layer "
                    f"index {layer_index}; this is not supported on the "
                    "current platform.")
        # Only the first layer name of each index is exposed to the runner.
        runner_kv_caches.append(kv_caches[layer_names[0]])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
|