import torch

import vllm.v1.worker.utils as utils
from vllm.v1.worker.utils import defaultdict, extract_layer_index

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("v0.15.0"):
    from vllm.attention.layer import Attention  # type: ignore
else:
    from vllm.model_executor.layers.attention import Attention


# Without this patch, an exception is raised when initializing the KV cache.
# TODO: To remove this patch, we need to check why the original bind_kv_cache
# raises a NotImplementedError.
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, Attention],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int = 1,
) -> None:
    """
    Bind the allocated KV cache to both the ModelRunner and the forward
    context so that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's KV cache list (`runner_kv_caches`) with
         `kv_caches`.
      2) Associates each attention layer in `forward_context` with its
         corresponding KV cache in `kv_caches`.

    Args:
        kv_caches: The allocated KV caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The KV cache list declared by the ModelRunner.
        num_attn_module: Number of attention modules per layer, passed to
            extract_layer_index when parsing layer names.
    """
    # Bind kv_caches to the ModelRunner.
    assert len(runner_kv_caches) == 0

    # Convert the kv_caches dict to a list of tensors ordered by layer index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name,
                                       num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        # The original handling for encoder-decoder models (e.g. BART), where
        # several layer names can share an index, is removed here; only the
        # first layer name per index is used.
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to the forward context.
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use a list because of the v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


utils.bind_kv_cache = bind_kv_cache
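# Usage sketch (illustrative only, not executed by this patch): a minimal
# example of the binding behavior. It assumes layer names of the form
# "model.layers.<i>.attn" that extract_layer_index can parse with the default
# num_attn_module, and uses types.SimpleNamespace as a stand-in for real
# Attention layers.
#
#     from types import SimpleNamespace
#
#     kv_caches = {
#         "model.layers.0.attn": torch.zeros(2, 16),
#         "model.layers.1.attn": torch.zeros(2, 16),
#     }
#     forward_context = {name: SimpleNamespace() for name in kv_caches}
#     runner_kv_caches: list[torch.Tensor] = []
#
#     bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
#
#     # runner_kv_caches now holds one tensor per layer index (in order),
#     # and each forward_context entry exposes its cache as a one-element
#     # list, e.g. forward_context["model.layers.0.attn"].kv_cache[0].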