提交vllm0.11.0开发分支
This commit is contained in:
@@ -1,55 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun attention wrapper for context and decode"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from itertools import accumulate
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from .utils import CommonAttentionState, CommonMetadataBuilder
|
||||
from vllm.attention.backends.utils import (
|
||||
is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx,
|
||||
compute_slot_mapping,
|
||||
)
|
||||
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from .utils import (CommonAttentionState, CommonMetadataBuilder)
|
||||
from vllm.attention.backends.utils import (is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx, compute_slot_mapping)
|
||||
from vllm_kunlun.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
|
||||
"""KunlunAttentionBackend"""
|
||||
|
||||
accept_output_buffer = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "KUNLUN_ATTENTION"
|
||||
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(
|
||||
num_blocks, block_size, num_kv_heads, head_size
|
||||
)
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
"""
|
||||
return (
|
||||
(self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None)
|
||||
)
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
"""
|
||||
return (
|
||||
self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None)
|
||||
)
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (
|
||||
None
|
||||
if self.query_start_loc is None
|
||||
else self.query_start_loc[: self.num_prefills + 1]
|
||||
)
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
# flash attention needs both lod information on host and device
|
||||
query_start_loc_host = (
|
||||
None
|
||||
if self.query_start_loc_host is None
|
||||
else self.query_start_loc_host[: self.num_prefills + 1]
|
||||
)
|
||||
kv_prefix_start_loc_host = (
|
||||
None
|
||||
if self.kv_prefix_start_loc_host is None
|
||||
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
|
||||
+ query_start_loc_host
|
||||
)
|
||||
kv_prefix_start_loc = (
|
||||
None
|
||||
if kv_prefix_start_loc_host is None
|
||||
else kv_prefix_start_loc_host.cuda()
|
||||
)
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[: self.num_prefill_tokens]
|
||||
)
|
||||
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
context_lens_tensor = (
|
||||
None
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
query_start_loc_host = (None if self.query_start_loc_host is None else
|
||||
self.query_start_loc_host[:self.num_prefills + 1])
|
||||
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
|
||||
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
|
||||
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[: self.num_prefills]
|
||||
)
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_start_loc=self.seq_start_loc,
|
||||
)
|
||||
seq_start_loc=self.seq_start_loc)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[self.num_prefill_tokens :]
|
||||
)
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[self.num_prefills :]
|
||||
)
|
||||
seq_lens_tensor_cpu = (
|
||||
None
|
||||
if self.seq_lens_tensor_cpu is None
|
||||
else self.seq_lens_tensor_cpu[self.num_prefills :]
|
||||
)
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[self.num_prefills :]
|
||||
)
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
||||
self.seq_lens_tensor_cpu[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""KunlunMetadataBuilder"""
|
||||
|
||||
_metadata_cls = KunlunMetadata
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
super().__init__(input_builder)
|
||||
self.prefix_cache_kv_lens: List[int] = []
|
||||
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""prepare"""
|
||||
super().prepare()
|
||||
self.prefix_cache_kv_lens = list()
|
||||
|
||||
def _add_seq_group(
|
||||
self,
|
||||
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool,
|
||||
):
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (
|
||||
seq_id,
|
||||
token_len,
|
||||
seq_len,
|
||||
curr_seq_len,
|
||||
query_len,
|
||||
context_len,
|
||||
curr_sliding_window_block,
|
||||
) in zip(
|
||||
inter_data.seq_ids,
|
||||
[len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens,
|
||||
inter_data.seq_lens,
|
||||
inter_data.query_lens,
|
||||
inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
):
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(placeholders)
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert (
|
||||
query_len == 1
|
||||
), "seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len
|
||||
)
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
block_table = []
|
||||
assert (
|
||||
not chunked_prefill_enabled
|
||||
), "chunk prefill not supported for kunlun attention"
|
||||
assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
|
||||
if inter_data.prefix_cache_hit:
|
||||
assert context_len != 0
|
||||
assert context_len % self.block_size == 0
|
||||
block_table = block_tables[seq_id][: context_len // self.block_size]
|
||||
elif (not is_prompt) and block_tables is not None:
|
||||
# block_table = block_tables[seq_id]
|
||||
block_table = block_tables[seq_id][:context_len // self.block_size]
|
||||
elif ((not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][-curr_sliding_window_block:]
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
if is_prompt:
|
||||
self.prefix_cache_kv_lens.append(context_len)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(
|
||||
is_prompt, query_len, context_len, self.sliding_window
|
||||
)
|
||||
compute_slot_mapping(
|
||||
is_profile_run,
|
||||
self.slot_mapping,
|
||||
seq_id,
|
||||
seq_len,
|
||||
context_len,
|
||||
start_idx,
|
||||
self.block_size,
|
||||
inter_data.block_tables,
|
||||
)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
cuda_graph_pad_size: int,
|
||||
batch_size: int,
|
||||
):
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""build"""
|
||||
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
query_start_loc_host = torch.tensor(
|
||||
query_start_loc, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
query_start_loc_host = torch.tensor(query_start_loc, dtype=torch.int32, device='cpu')
|
||||
attn_meta.query_start_loc_host = query_start_loc_host
|
||||
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
|
||||
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
|
||||
|
||||
# If kv cache is included and there is a hit
|
||||
# 包含kv cache ,且存在命中的情况
|
||||
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
|
||||
self.prefix_cache_kv_lens = list(
|
||||
accumulate(self.prefix_cache_kv_lens, initial=0)
|
||||
)
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(
|
||||
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.prefix_cache_kv_lens = list(accumulate(self.prefix_cache_kv_lens, initial=0))
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu")
|
||||
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
|
||||
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
|
||||
return attn_meta
|
||||
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: KunlunMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
"""
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
"""
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
None,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
|
||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
"""KunlunAttentionImpl"""
|
||||
|
||||
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("kunlunAttention does not support block-sparse attention.")
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if head_size not in suppored_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {suppored_head_sizes}."
|
||||
)
|
||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Check that appropriate attention metadata attributes are
|
||||
# selected for the desired attention type
|
||||
if attn_type == AttentionType.ENCODER and (
|
||||
not attn_metadata.is_all_encoder_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder attention requires setting " "encoder metadata attributes."
|
||||
)
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER and (
|
||||
not attn_metadata.is_all_cross_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes."
|
||||
)
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
|
||||
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
value = value.contiguous()
|
||||
KunlunOps.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
)
|
||||
KunlunOps.reshape_and_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype)
|
||||
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
out = KunlunOps.multi_query_kv_attention(
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.query_start_loc_host,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
).view_as(query)
|
||||
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
|
||||
alibi_slopes=self.alibi_slopes).view_as(query)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert (
|
||||
attn_type != AttentionType.ENCODER_ONLY
|
||||
), "Encoder-only models should not have decode metadata."
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
Reference in New Issue
Block a user