提交vllm0.11.0开发分支
This commit is contained in:
@@ -1,55 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun attention wrapper for context and decode"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from itertools import accumulate
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from .utils import CommonAttentionState, CommonMetadataBuilder
|
||||
from vllm.attention.backends.utils import (
|
||||
is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx,
|
||||
compute_slot_mapping,
|
||||
)
|
||||
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from .utils import (CommonAttentionState, CommonMetadataBuilder)
|
||||
from vllm.attention.backends.utils import (is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx, compute_slot_mapping)
|
||||
from vllm_kunlun.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
|
||||
"""KunlunAttentionBackend"""
|
||||
|
||||
accept_output_buffer = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "KUNLUN_ATTENTION"
|
||||
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(
|
||||
num_blocks, block_size, num_kv_heads, head_size
|
||||
)
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
"""
|
||||
return (
|
||||
(self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None)
|
||||
)
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
"""
|
||||
return (
|
||||
self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None)
|
||||
)
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (
|
||||
None
|
||||
if self.query_start_loc is None
|
||||
else self.query_start_loc[: self.num_prefills + 1]
|
||||
)
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
# flash attention needs both lod information on host and device
|
||||
query_start_loc_host = (
|
||||
None
|
||||
if self.query_start_loc_host is None
|
||||
else self.query_start_loc_host[: self.num_prefills + 1]
|
||||
)
|
||||
kv_prefix_start_loc_host = (
|
||||
None
|
||||
if self.kv_prefix_start_loc_host is None
|
||||
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
|
||||
+ query_start_loc_host
|
||||
)
|
||||
kv_prefix_start_loc = (
|
||||
None
|
||||
if kv_prefix_start_loc_host is None
|
||||
else kv_prefix_start_loc_host.cuda()
|
||||
)
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[: self.num_prefill_tokens]
|
||||
)
|
||||
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
context_lens_tensor = (
|
||||
None
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
query_start_loc_host = (None if self.query_start_loc_host is None else
|
||||
self.query_start_loc_host[:self.num_prefills + 1])
|
||||
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
|
||||
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
|
||||
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[: self.num_prefills]
|
||||
)
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_start_loc=self.seq_start_loc,
|
||||
)
|
||||
seq_start_loc=self.seq_start_loc)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[self.num_prefill_tokens :]
|
||||
)
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[self.num_prefills :]
|
||||
)
|
||||
seq_lens_tensor_cpu = (
|
||||
None
|
||||
if self.seq_lens_tensor_cpu is None
|
||||
else self.seq_lens_tensor_cpu[self.num_prefills :]
|
||||
)
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[self.num_prefills :]
|
||||
)
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
||||
self.seq_lens_tensor_cpu[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""KunlunMetadataBuilder"""
|
||||
|
||||
_metadata_cls = KunlunMetadata
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
super().__init__(input_builder)
|
||||
self.prefix_cache_kv_lens: List[int] = []
|
||||
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""prepare"""
|
||||
super().prepare()
|
||||
self.prefix_cache_kv_lens = list()
|
||||
|
||||
def _add_seq_group(
|
||||
self,
|
||||
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool,
|
||||
):
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (
|
||||
seq_id,
|
||||
token_len,
|
||||
seq_len,
|
||||
curr_seq_len,
|
||||
query_len,
|
||||
context_len,
|
||||
curr_sliding_window_block,
|
||||
) in zip(
|
||||
inter_data.seq_ids,
|
||||
[len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens,
|
||||
inter_data.seq_lens,
|
||||
inter_data.query_lens,
|
||||
inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
):
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(placeholders)
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert (
|
||||
query_len == 1
|
||||
), "seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len
|
||||
)
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
block_table = []
|
||||
assert (
|
||||
not chunked_prefill_enabled
|
||||
), "chunk prefill not supported for kunlun attention"
|
||||
assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
|
||||
if inter_data.prefix_cache_hit:
|
||||
assert context_len != 0
|
||||
assert context_len % self.block_size == 0
|
||||
block_table = block_tables[seq_id][: context_len // self.block_size]
|
||||
elif (not is_prompt) and block_tables is not None:
|
||||
# block_table = block_tables[seq_id]
|
||||
block_table = block_tables[seq_id][:context_len // self.block_size]
|
||||
elif ((not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][-curr_sliding_window_block:]
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
if is_prompt:
|
||||
self.prefix_cache_kv_lens.append(context_len)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(
|
||||
is_prompt, query_len, context_len, self.sliding_window
|
||||
)
|
||||
compute_slot_mapping(
|
||||
is_profile_run,
|
||||
self.slot_mapping,
|
||||
seq_id,
|
||||
seq_len,
|
||||
context_len,
|
||||
start_idx,
|
||||
self.block_size,
|
||||
inter_data.block_tables,
|
||||
)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
cuda_graph_pad_size: int,
|
||||
batch_size: int,
|
||||
):
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""build"""
|
||||
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
query_start_loc_host = torch.tensor(
|
||||
query_start_loc, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
query_start_loc_host = torch.tensor(query_start_loc, dtype=torch.int32, device='cpu')
|
||||
attn_meta.query_start_loc_host = query_start_loc_host
|
||||
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
|
||||
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
|
||||
|
||||
# If kv cache is included and there is a hit
|
||||
# 包含kv cache ,且存在命中的情况
|
||||
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
|
||||
self.prefix_cache_kv_lens = list(
|
||||
accumulate(self.prefix_cache_kv_lens, initial=0)
|
||||
)
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(
|
||||
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.prefix_cache_kv_lens = list(accumulate(self.prefix_cache_kv_lens, initial=0))
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu")
|
||||
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
|
||||
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
|
||||
return attn_meta
|
||||
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: KunlunMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
"""
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
"""
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
None,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
|
||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
"""KunlunAttentionImpl"""
|
||||
|
||||
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("kunlunAttention does not support block-sparse attention.")
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if head_size not in suppored_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {suppored_head_sizes}."
|
||||
)
|
||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Check that appropriate attention metadata attributes are
|
||||
# selected for the desired attention type
|
||||
if attn_type == AttentionType.ENCODER and (
|
||||
not attn_metadata.is_all_encoder_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder attention requires setting " "encoder metadata attributes."
|
||||
)
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER and (
|
||||
not attn_metadata.is_all_cross_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes."
|
||||
)
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
|
||||
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
value = value.contiguous()
|
||||
KunlunOps.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
)
|
||||
KunlunOps.reshape_and_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype)
|
||||
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
out = KunlunOps.multi_query_kv_attention(
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.query_start_loc_host,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
).view_as(query)
|
||||
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
|
||||
alibi_slopes=self.alibi_slopes).view_as(query)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert (
|
||||
attn_type != AttentionType.ENCODER_ONLY
|
||||
), "Encoder-only models should not have decode metadata."
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
@@ -4,13 +4,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Dict, Any
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
|
||||
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
|
||||
class Attention(VllmAttention):
|
||||
"""Attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata, output=output
|
||||
)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output_kunlun(
|
||||
query, key, value, output, self.layer_name
|
||||
)
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata
|
||||
)
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
return unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
|
||||
#
|
||||
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
|
||||
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
|
||||
class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads = num_heads,
|
||||
head_size = head_size,
|
||||
scale = scale,
|
||||
num_kv_heads = num_kv_heads,
|
||||
)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query, key, value, scale=self.scale
|
||||
)
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
"""wait_for_kv_layer_from_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str, kv_cache_layer: List[torch.Tensor]
|
||||
):
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor]):
|
||||
"""maybe_save_kv_layer_to_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
|
||||
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||
attn_metadata[layer_name])
|
||||
|
||||
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
|
||||
def unified_attention_with_output_kunlun(
|
||||
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def _fake_unified_attention_with_output_kunlun(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
return None
|
||||
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(
|
||||
_fake_unified_attention_with_output_kunlun
|
||||
)
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -268,7 +269,8 @@ def unified_attention(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return output
|
||||
return output
|
||||
Reference in New Issue
Block a user