55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
|
|
import torch
|
|
|
|
from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
|
|
from vllm import _custom_ops as ops
|
|
from vllm.attention.ops.paged_attn import PagedAttention
|
|
|
|
def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Copy KV-cache entries from source slots to destination slots in every layer.

    For each row of ``src_to_dists``, the key/value entries at the source index
    (column 0) are read from every layer's cache and written to the destination
    index (column 1) of the same layer.

    Args:
        backend: Attention backend; only its ``get_name()`` is consulted.
        kv_caches: One KV-cache tensor per layer. Each is split into key and
            value caches via ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor. Column 0 holds source indices and column 1
            holds destination indices. Presumably these are cache slot indices
            as expected by ``ops.read_cache`` / ``ops.write_cache_multi_layers``
            (TODO confirm).
        kv_cache_dtype: Cache dtype string, forwarded unchanged to the custom ops.
        num_kv_heads: Number of KV heads, used to split caches and size the
            staging buffer.
        head_size: Per-head size, used to split caches and size the staging
            buffer.

    Raises:
        NotImplementedError: If the backend is not one of the supported ones.
    """
    # NOTE(review): BlocksparseFlashAttention is imported by this module but is
    # not accepted here. Add its backend name once its cache layout is verified
    # to work with these ops.
    supported_backends = ("rocm-flash-attn", "xformers")
    backend_name = backend.get_name()
    if backend_name not in supported_backends:
        raise NotImplementedError(
            "Only ROCmFlash/XFormers backends support move cache for now! "
            f"Got backend: {backend_name!r}")

    token_num = src_to_dists.shape[0]
    # Nothing to move: avoid indexing kv_caches[0] on an empty list and
    # avoid launching the custom ops with empty work.
    if not kv_caches or token_num == 0:
        return

    num_layers = len(kv_caches)
    reference_cache = kv_caches[0]

    # Staging buffer with shape (2, num_layers, token_num, num_kv_heads,
    # head_size); index 0 holds keys and index 1 holds values. Every source
    # entry is read into it before any destination is written, so
    # overlapping source/destination indices do not clobber each other.
    # (This assumes the two ops run in order, e.g. on the same stream —
    # confirm.)
    tmp_store_kv = torch.empty(
        (2, num_layers, token_num, num_kv_heads, head_size),
        dtype=reference_cache.dtype,
        device=reference_cache.device)
    # First-axis slices of a fresh contiguous tensor are already contiguous.
    keys = tmp_store_kv[0]
    values = tmp_store_kv[1]

    split_caches = [
        PagedAttention.split_kv_cache(kv_cache, num_kv_heads, head_size)
        for kv_cache in kv_caches
    ]
    key_caches = [key_cache for key_cache, _ in split_caches]
    value_caches = [value_cache for _, value_cache in split_caches]

    # Column slices of a 2-D tensor are strided, so make them contiguous
    # before handing them to the custom kernels.
    ops.read_cache(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )

    ops.write_cache_multi_layers(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )