提交vllm0.11.0开发分支

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions

View File

@@ -1,55 +1,28 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
# Email: baoqian@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun attention wrapper for context and decode"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from itertools import accumulate
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionType,
)
from .utils import CommonAttentionState, CommonMetadataBuilder
from vllm.attention.backends.utils import (
is_block_tables_empty,
compute_slot_mapping_start_idx,
compute_slot_mapping,
)
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from .utils import (CommonAttentionState, CommonMetadataBuilder)
from vllm.attention.backends.utils import (is_block_tables_empty,
compute_slot_mapping_start_idx, compute_slot_mapping)
from vllm_kunlun.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm.attention.backends.abstract import AttentionLayer
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d
logger = init_logger(__name__)
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
accept_output_buffer = False
@staticmethod
def get_name() -> str:
    """Return the identifier string of this attention backend."""
    return "KUNLUN_ATTENTION"
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def is_all_encoder_attn_metadata_set(self):
    """Whether all attention metadata required for encoder attention is set.

    Returns:
        True only when ``encoder_seq_lens``, ``encoder_seq_lens_tensor`` and
        ``max_encoder_seq_len`` are all not None; False otherwise.
    """
    return ((self.encoder_seq_lens is not None)
            and (self.encoder_seq_lens_tensor is not None)
            and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
    """Whether all metadata required for enc/dec cross-attention is set.

    This is a superset of the encoder-attention requirement: in addition to
    ``is_all_encoder_attn_metadata_set`` being truthy, both
    ``cross_slot_mapping`` and ``cross_block_tables`` must be not None.
    """
    return (self.is_all_encoder_attn_metadata_set
            and (self.cross_slot_mapping is not None)
            and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# metadata structure
return self._cached_prefill_metadata
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (
None
if self.query_start_loc is None
else self.query_start_loc[: self.num_prefills + 1]
)
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
# flash attention needs both lod information on host and device
query_start_loc_host = (
None
if self.query_start_loc_host is None
else self.query_start_loc_host[: self.num_prefills + 1]
)
kv_prefix_start_loc_host = (
None
if self.kv_prefix_start_loc_host is None
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
+ query_start_loc_host
)
kv_prefix_start_loc = (
None
if kv_prefix_start_loc_host is None
else kv_prefix_start_loc_host.cuda()
)
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[: self.num_prefill_tokens]
)
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[: self.num_prefills]
)
context_lens_tensor = (
None
if self.context_lens_tensor is None
else self.context_lens_tensor[: self.num_prefills]
)
query_start_loc_host = (None if self.query_start_loc_host is None else
self.query_start_loc_host[:self.num_prefills + 1])
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
# for prefix cache, block table only contains blocks that hit
# if self.block_tables is None:
# block_tables = None
# elif self.block_tables.shape[1] == 0:
# block_tables = self.block_tables[:self.num_prefills]
# else:
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
block_tables = (
None
if self.block_tables is None
else self.block_tables[: self.num_prefills]
)
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
seq_start_loc=self.seq_start_loc,
)
seq_start_loc=self.seq_start_loc)
return self._cached_prefill_metadata
@property
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[self.num_prefill_tokens :]
)
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[self.num_prefills :]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[self.num_prefills :]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[self.num_prefills :]
)
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
)
enable_kv_scales_calculation=False)
return self._cached_decode_metadata
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
"""KunlunMetadataBuilder"""
_metadata_cls = KunlunMetadata
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
self.prefix_cache_kv_lens: List[int] = []
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
"""prepare"""
super().prepare()
self.prefix_cache_kv_lens = list()
def _add_seq_group(
        self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool):
    """Accumulate one sequence group's attention metadata into the builder.

    For every sequence in ``inter_data`` this method:
      * appends its context length to ``self.context_lens``;
      * updates the prefill counters (prompt) or decode counters (decode);
      * selects and appends the sequence's block table;
      * for prompts, records the context length in
        ``self.prefix_cache_kv_lens``;
      * extends ``self.slot_mapping`` via ``compute_slot_mapping``.

    Args:
        inter_data: Per-sequence-group intermediate data produced by the
            model input builder.
        chunked_prefill_enabled: Must be False; chunked prefill is not
            supported by this backend.

    Raises:
        AssertionError: If chunked prefill is enabled, if a decode sequence
            has ``query_len != 1``, or if a prefix-cache hit has a context
            length that is zero or not a multiple of the block size.
    """
    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables

    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        self.context_lens.append(context_len)
        if is_prompt:
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)
            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)

        # Compute block table.
        block_table = []
        assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
        if inter_data.prefix_cache_hit:
            # NOTE(review): these two preconditions are visible in the change
            # set, but it is unclear whether the latest revision keeps them
            # — confirm before relying on them.
            assert context_len != 0
            assert context_len % self.block_size == 0
            # Only the blocks covering the already-cached context are kept,
            # rather than the full per-sequence block table.
            block_table = block_tables[seq_id][:context_len // self.block_size]
        elif (not is_prompt) and block_tables is not None:
            if curr_sliding_window_block == 0:
                block_table = block_tables[seq_id]
            else:
                # Keep only the trailing blocks inside the sliding window.
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
        self.block_tables.append(block_table)
        if is_prompt:
            self.prefix_cache_kv_lens.append(context_len)

        # Compute slot mapping.
        is_profile_run = is_block_tables_empty(block_tables)
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                             seq_len, context_len, start_idx,
                             self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int):
    """Build attention metadata and attach the extra host-side fields.

    Delegates to the parent builder, then sets on the returned metadata:
      * ``query_start_loc_host``: int32 CPU tensor holding the running sum
        of ``query_lens`` with a leading 0;
      * ``max_kv_len``: maximum over the recorded prefix-cache context
        lengths and ``attn_meta.seq_lens``;
      * ``kv_prefix_start_loc_host``: int32 CPU tensor holding the running
        sum of the prefix-cache context lengths (only when at least one of
        them is non-zero);
      * ``seq_lens_tensor_cpu``: CPU copy of ``attn_meta.seq_lens_tensor``.

    Args:
        seq_lens: Forwarded to the parent ``build``.
        query_lens: Per-sequence query lengths.
        cuda_graph_pad_size: Forwarded to the parent ``build``.
        batch_size: Forwarded to the parent ``build``.

    Returns:
        The metadata object produced by the parent builder, with the fields
        above populated.
    """
    attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                              batch_size)
    query_start_loc = list(accumulate(query_lens, initial=0))
    query_start_loc_host = torch.tensor(query_start_loc,
                                        dtype=torch.int32,
                                        device='cpu')
    attn_meta.query_start_loc_host = query_start_loc_host
    # max_kv_len = max(query_lens + prefix_cache_kv_lens)
    attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)

    # If kv cache is included and there is a hit
    if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
        # The per-sequence lengths are overwritten in place by their
        # cumulative offsets (leading 0 included).
        self.prefix_cache_kv_lens = list(
            accumulate(self.prefix_cache_kv_lens, initial=0))
        prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens,
                                                   dtype=torch.int32,
                                                   device="cpu")
        attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
    # NOTE(review): assumed to run unconditionally (outside the prefix-hit
    # branch) — confirm against the original indentation.
    attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
    return attn_meta
def _get_seq_len_block_table_args(
attn_metadata: KunlunMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
"""
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables,
)
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
None,
)
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError("kunlunAttention does not support block-sparse attention.")
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}."
)
f"Supported head sizes are: {suppored_head_sizes}.")
def forward(
self,
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if attn_type == AttentionType.ENCODER and (
not attn_metadata.is_all_encoder_attn_metadata_set
):
raise AttributeError(
"Encoder attention requires setting " "encoder metadata attributes."
)
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif attn_type == AttentionType.ENCODER_DECODER and (
not attn_metadata.is_all_cross_attn_metadata_set
):
raise AttributeError(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
else:
updated_slot_mapping = attn_metadata.slot_mapping
value = value.contiguous()
KunlunOps.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
)
KunlunOps.reshape_and_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype)
if attn_type == AttentionType.ENCODER:
# Encoder attention - chunked prefill is not applicable;
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
out = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,
prefill_meta.query_start_loc_host,
query,
key,
value,
alibi_slopes=self.alibi_slopes,
).view_as(query)
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
alibi_slopes=self.alibi_slopes).view_as(query)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
assert (
attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
(
seq_lens_arg,
max_seq_len_arg,
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
return output.view(-1, self.num_heads * self.head_size)

View File

@@ -4,13 +4,12 @@ import torch
import torch.nn.functional as F
from typing import Optional, List, Dict, Any
from vllm.attention import AttentionType
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
from vllm.platforms import _Backend
class Attention(VllmAttention):
"""Attention"""
def __init__(
self,
num_heads: int,
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output)
else:
torch.ops.vllm.unified_attention_with_output_kunlun(
query, key, value, output, self.layer_name
)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
else:
return unified_attention(query, key, value, self.layer_name)
return unified_attention(
query, key, value, self.layer_name)
#
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
class MultiHeadAttention(VllmMultiHeadAttention):
def __init__(
self,
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
num_kv_heads: Optional[int] = None,
):
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
num_heads = num_heads,
head_size = head_size,
scale = scale,
num_kv_heads = num_kv_heads,
)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
self.attn_backend = _Backend.FLASH_ATTN
def forward(
self,
query: torch.Tensor,
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_func
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
"""wait_for_kv_layer_from_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str, kv_cache_layer: List[torch.Tensor]
):
layer_name: str,
kv_cache_layer: List[torch.Tensor]):
"""maybe_save_kv_layer_to_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
def unified_attention_with_output_kunlun(
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def _fake_unified_attention_with_output_kunlun(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
return None
# Register the no-op function as the fake implementation of the custom op
# (exactly once).
unified_attention_with_output_kunlun.register_fake(
    _fake_unified_attention_with_output_kunlun)
def unified_attention(
query: torch.Tensor,
@@ -268,7 +269,8 @@ def unified_attention(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
return output