提交vllm0.11.0开发分支
This commit is contained in:
@@ -24,7 +24,6 @@ _PARTITION_SIZE = 512
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
@@ -53,18 +52,18 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
"""
|
||||
Get the shape of the KV cache. The shape returned depends on whether the current platform is Kunlun.
|
||||
If running on Kunlun (current_platform.is_kunlun() is True), returns shape (2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
Otherwise, returns shape (2, num_blocks, block_size * num_kv_heads * head_size).
|
||||
|
||||
获取KV缓存的形状,根据是否在芯片上进行计算返回不同的形状。
|
||||
如果在芯片上(is_kunlun()为True),则返回形状(2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
否则,返回形状(2, num_blocks, block_size * num_kv_heads * head_size)。
|
||||
|
||||
Args:
|
||||
num_blocks (int): The number of blocks.
|
||||
block_size (int): The size of each block.
|
||||
num_kv_heads (int): The number of KV heads.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
num_blocks (int): 块数量。
|
||||
block_size (int): 每个块大小。
|
||||
num_kv_heads (int): KV头数量。
|
||||
head_size (int): 每个头大小。
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The shape of the KV cache. The first element is always 2 (index 0 holds the key cache and index 1 holds the value cache); the remaining elements are (num_blocks, num_kv_heads, block_size, head_size) on Kunlun, or (num_blocks, block_size * num_kv_heads * head_size) otherwise.
|
||||
Tuple[int, ...]: KV缓存的形状。第一个元素恒为2(索引0为key缓存,索引1为value缓存);其余元素在Kunlun上为(num_blocks, num_kv_heads, block_size, head_size),否则为(num_blocks, block_size * num_kv_heads * head_size)。
|
||||
"""
|
||||
if current_platform.is_kunlun():
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
@@ -77,20 +76,20 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split a cached tensor (containing key and value) into two parts, each part is a tensor.
|
||||
If running on KUNLUN, the first returned tensor is the key cache, and the second tensor is the value cache.
|
||||
Otherwise, the key cache is returned as a view with shape (num_blocks, num_kv_heads, head_size//x, -1, x), where x = 16 // kv_cache.element_size(),
|
||||
and the value cache is returned as a view with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
|
||||
将一个缓存张量(包含key和value)分成两部分,每个部分是一个张量。
|
||||
如果在KUNLUN上运行,则返回的第一个张量是key缓存,第二个张量是value缓存。
|
||||
否则,返回的key缓存是形状为(num_blocks, num_kv_heads, head_size//x, -1, x)的view,其中x = 16 // kv_cache.element_size(),
|
||||
返回的value缓存是形状为(num_blocks, num_kv_heads, head_size, -1)的view。
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): A tensor containing key and value, with shape (2, num_blocks, kv_cache_size).
|
||||
num_kv_heads (int): The number of KV heads.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
kv_cache (torch.Tensor): 包含key和value的张量,形状为(2, num_blocks, kv_cache_size)。
|
||||
num_kv_heads (int): KV头数量。
|
||||
head_size (int): 每个头的大小。
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- key_cache (torch.Tensor): A tensor containing the key cache, with shape (num_blocks, num_kv_heads, head_size//x, -1, x).
|
||||
- value_cache (torch.Tensor): A tensor containing the value cache, with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
- key_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size//x, -1, x),包含key缓存。
|
||||
- value_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size, -1),包含value缓存。
|
||||
"""
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
@@ -100,7 +99,8 @@ class PagedAttention:
|
||||
value_cache = kv_cache[1]
|
||||
else:
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
@@ -152,17 +152,16 @@ class PagedAttention:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (
|
||||
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
|
||||
), (
|
||||
f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables."
|
||||
)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
@@ -170,10 +169,9 @@ class PagedAttention:
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = max_seq_len <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
@@ -302,4 +300,4 @@ class PagedAttention:
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
Reference in New Issue
Block a user