from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.blocksparse_attn import (
    BlocksparseFlashAttentionImpl)
from vllm.attention.ops.paged_attn import PagedAttention


def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between positions across all layers.

    ``src_to_dists`` is a ``[num_tokens, 2]`` tensor: column 0 holds the
    source positions and column 1 the destination positions. The entries
    are first gathered from every layer into a temporary buffer and then
    written back to the destination positions.
    """
    if backend.get_name() == "rocm-flash-attn" or \
            backend.get_name() == "xformers":
        key_caches = []
        value_caches = []
        num_layers = len(kv_caches)
        token_num = src_to_dists.shape[0]

        # Staging buffer that holds the keys and values of all layers for
        # the tokens being moved.
        tmp_store_kv = torch.empty(
            (2, num_layers, token_num, num_kv_heads, head_size),
            dtype=kv_caches[0].dtype,
            device=kv_caches[0].device)
        keys = tmp_store_kv[0].contiguous()
        values = tmp_store_kv[1].contiguous()

        # Split each layer's cache into its key and value halves.
        for kv_cache in kv_caches:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, num_kv_heads, head_size)
            key_caches.append(key_cache)
            value_caches.append(value_cache)

        # Gather the source entries into the staging buffer, then scatter
        # them to the destination positions of every layer.
        ops.read_cache(keys, values, key_caches, value_caches,
                       src_to_dists[:, 0].contiguous(), kv_cache_dtype)
        ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                     src_to_dists[:, 1].contiguous(),
                                     kv_cache_dtype)
    else:
        raise NotImplementedError(
            "Only the BlocksparseFlashAttention/ROCmFlash/XFormers backends "
            "support move_cache for now!")
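
# --- Illustrative usage (a sketch, not part of the original module) ---------
# A minimal example of how move_cache might be called, assuming a backend
# object whose get_name() returns "xformers". All sizes, the `backend`
# object, and the source/destination pairs below are hypothetical; the
# per-layer cache shape follows PagedAttention.get_kv_cache_shape, i.e.
# (2, num_blocks, block_size * num_kv_heads * head_size).
#
#     num_layers, num_blocks, block_size = 2, 8, 16
#     num_kv_heads, head_size = 4, 64
#     kv_caches = [
#         torch.zeros(2, num_blocks, block_size * num_kv_heads * head_size,
#                     dtype=torch.float16, device="cuda")
#         for _ in range(num_layers)
#     ]
#     # Each row is (source_position, destination_position).
#     src_to_dists = torch.tensor([[0, 17], [1, 18], [2, 19]],
#                                 dtype=torch.int64, device="cuda")
#     move_cache(backend, kv_caches, src_to_dists, kv_cache_dtype="auto",
#                num_kv_heads=num_kv_heads, head_size=head_size)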