init src 0.9.2
55
vllm/attention/backends/tree_decoding_utils.py
Normal file
@@ -0,0 +1,55 @@
from typing import List

import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention


def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
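    """Move KV-cache entries between slots across all layers.

    src_to_dists is read here as a (num_tokens, 2) tensor of
    (source_slot, destination_slot) index pairs; that interpretation
    is inferred from how its two columns are used below.
    """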
    if backend.get_name() == "rocm-flash-attn" or \
            backend.get_name() == "xformers":

        key_caches = []
        value_caches = []

        num_layers = len(kv_caches)
        token_num = src_to_dists.shape[0]

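        # Stage every layer's K/V rows for the moved tokens in one
        # temporary buffer, so the move becomes a single gather
        # followed by a single scatter instead of per-layer copies.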
        tmp_store_kv = torch.empty(
            (2, num_layers, token_num, num_kv_heads, head_size),
            dtype=kv_caches[0].dtype, device=kv_caches[0].device)
        keys = tmp_store_kv[0].contiguous()
        values = tmp_store_kv[1].contiguous()

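        # Split each layer's combined cache tensor into its key and
        # value views so the custom ops can address them per layer.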
        for kv_cache in kv_caches:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, num_kv_heads, head_size)
            key_caches.append(key_cache)
            value_caches.append(value_cache)

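        # Gather: read the source slots (column 0) from every layer
        # into the staging buffer.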
        ops.read_cache(
            keys,
            values,
            key_caches,
            value_caches,
            src_to_dists[:, 0].contiguous(),
            kv_cache_dtype
        )

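        # Scatter: write the staged entries back out to the
        # destination slots (column 1) across all layers.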
        ops.write_cache_multi_layers(
            keys,
            values,
            key_caches,
            value_caches,
            src_to_dists[:, 1].contiguous(),
            kv_cache_dtype
        )
    else:
        raise NotImplementedError(
            "Only the rocm-flash-attn and xformers backends support "
            "move_cache for now!")
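
Usage sketch (not part of the commit): a minimal, hypothetical driver showing how move_cache might be called. The backend stub, the cache-tensor layout, the shapes, and the (source, destination) column convention are all assumptions inferred from the code above; read_cache and write_cache_multi_layers are custom ops from this fork's compiled extension, so this only runs with that extension built and a CUDA device available.

import torch


class _StubBackend:
    # Stand-in for an attention backend; move_cache only calls get_name().
    def get_name(self) -> str:
        return "xformers"


num_layers, num_kv_heads, head_size = 2, 8, 128
num_blocks, block_size = 64, 16

# One combined cache tensor per layer, in the (2, num_blocks,
# block_size * num_kv_heads * head_size) layout that
# PagedAttention.split_kv_cache is assumed to split into K and V here.
kv_caches = [
    torch.empty(2, num_blocks, block_size * num_kv_heads * head_size,
                dtype=torch.float16, device="cuda")
    for _ in range(num_layers)
]

# Three moves: each row is an assumed (source_slot, destination_slot) pair.
src_to_dists = torch.tensor([[0, 7], [1, 8], [2, 9]],
                            dtype=torch.int64, device="cuda")

move_cache(_StubBackend(), kv_caches, src_to_dists, "auto",
           num_kv_heads, head_size)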