Sync from v0.13
This commit is contained in:
0
vllm/attention/layers/__init__.py
Normal file
0
vllm/attention/layers/__init__.py
Normal file
120
vllm/attention/layers/chunked_local_attention.py
Normal file
120
vllm/attention/layers/chunked_local_attention.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
KVCacheSpec,
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_chunked_local_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
attention_chunk_size: int,
|
||||
block_size: int,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
assert issubclass(underlying_builder, AttentionMetadataBuilder)
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AttentionMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.NEVER
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
common_attn_metadata = make_local_attention_virtual_batches(
|
||||
attention_chunk_size, common_attn_metadata, block_size
|
||||
)
|
||||
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=ChunkedLocalAttentionBuilder,
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class ChunkedLocalAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
attention_chunk_size: int,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
self.attention_chunk_size = attention_chunk_size
|
||||
dtype = torch.get_default_dtype()
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_chunked_local_attention_backend(
|
||||
underlying_attn_backend, attention_chunk_size, block_size
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
attn_backend=attn_backend,
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
assert self.attention_chunk_size
|
||||
return ChunkedLocalAttentionSpec(
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
)
|
||||
178
vllm/attention/layers/cross_attention.py
Normal file
178
vllm/attention/layers/cross_attention.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_cross_slot_mapping(
|
||||
encoder_seq_lens: np.ndarray,
|
||||
block_table_tensor: torch.Tensor,
|
||||
kv_cache_spec: CrossAttentionSpec,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Get cross-attention slot mappings."""
|
||||
|
||||
block_size = kv_cache_spec.block_size
|
||||
slot_mappings = []
|
||||
|
||||
# Find indices with non-zero encoder sequence lengths
|
||||
# The majority of parallel requests will be running the
|
||||
# decoder, so this list should be relatively small.
|
||||
active_indices = np.nonzero(encoder_seq_lens)[0]
|
||||
|
||||
for req_index in active_indices:
|
||||
encoder_seq_len = encoder_seq_lens[req_index].item()
|
||||
|
||||
# Calculate the number of blocks needed for this request
|
||||
num_blocks_needed = cdiv(encoder_seq_len, block_size)
|
||||
|
||||
# Get the block IDs for this request from the tensor
|
||||
req_block_ids = block_table_tensor[req_index]
|
||||
|
||||
# Get only the blocks we need (first num_blocks_needed blocks)
|
||||
needed_block_ids = req_block_ids[:num_blocks_needed]
|
||||
|
||||
# All needed blocks are allocated
|
||||
i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device)
|
||||
block_indices = i_values // block_size
|
||||
block_offsets = i_values % block_size
|
||||
block_numbers = needed_block_ids[block_indices]
|
||||
slot_mapping = block_numbers * block_size + block_offsets
|
||||
|
||||
slot_mappings.append(slot_mapping)
|
||||
|
||||
if slot_mappings:
|
||||
return torch.cat(slot_mappings)
|
||||
else:
|
||||
return torch.empty(0, dtype=torch.int64, device=device)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_cross_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "CrossAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class CrossAttentionBuilder(underlying_builder): # type: ignore
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
new_metadata = copy(common_attn_metadata)
|
||||
new_metadata.causal = False
|
||||
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
|
||||
new_metadata.max_seq_len = max_encoder_len
|
||||
# Any computed tokens indicated decode step>1 (no chunked prefill)
|
||||
num_cache_decodes = (
|
||||
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
|
||||
)
|
||||
if num_cache_decodes > 0:
|
||||
# CrossAttn KV cache has already been populated on first decoder step,
|
||||
# skip slot_mapping calculation for requests that do not need
|
||||
# reshape_and_cache.
|
||||
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
|
||||
new_metadata.encoder_seq_lens_cpu = np.where(
|
||||
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
|
||||
)
|
||||
|
||||
# seq_lens is provided by model runner: initial encoder input length is
|
||||
# needed here to know how many tokens to attend to from the cached
|
||||
# cross-attention KV cache.
|
||||
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
|
||||
new_metadata._seq_lens_cpu = torch.from_numpy(
|
||||
common_attn_metadata.encoder_seq_lens_cpu
|
||||
)
|
||||
|
||||
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
|
||||
new_metadata.slot_mapping = _get_cross_slot_mapping(
|
||||
new_metadata.encoder_seq_lens_cpu,
|
||||
new_metadata.block_table_tensor,
|
||||
self.kv_cache_spec,
|
||||
self.device,
|
||||
)
|
||||
return super().build(common_prefix_len, new_metadata, fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=CrossAttentionBuilder,
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
"""
|
||||
Cross-attention for encoder-decoder models.
|
||||
Handles attention between decoder queries and encoder keys/values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
cache_config: CacheConfig | None = None,
|
||||
attn_type: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_cross_attention_backend(underlying_attn_backend)
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_DECODER, (
|
||||
"CrossAttention only supports AttentionType.ENCODER_DECODER"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
attn_type=AttentionType.ENCODER_DECODER,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
return CrossAttentionSpec(
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
103
vllm/attention/layers/encoder_only_attention.py
Normal file
103
vllm/attention/layers/encoder_only_attention.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_encoder_only_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "EncoderOnlyAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
new_common_attn_metadata = copy(common_attn_metadata)
|
||||
new_common_attn_metadata.causal = False
|
||||
return super().build(
|
||||
common_prefix_len, new_common_attn_metadata, fast_build
|
||||
)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=EncoderOnlyAttentionBuilder,
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class EncoderOnlyAttention(Attention):
|
||||
"""
|
||||
Encoder attention is a special case that doesn't need a KV Cache.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
cache_config: CacheConfig | None = None,
|
||||
attn_type: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
)
|
||||
|
||||
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_ONLY, (
|
||||
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Does not need KV cache
|
||||
return None
|
||||
284
vllm/attention/layers/mm_encoder_attention.py
Normal file
284
vllm/attention/layers/mm_encoder_attention.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
)
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def maybe_get_vit_flash_attn_backend(
|
||||
attn_backend: AttentionBackendEnum | None,
|
||||
) -> Callable | None:
|
||||
# At this point,
|
||||
# we already have the attn_backend,
|
||||
# overriding logic is done in the platform-specific implementation.
|
||||
# so we don't need to override backend here.
|
||||
# Just return the attn_backend and flash_attn_varlen_func.
|
||||
|
||||
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
# if attn_backend is TORCH_SDPA,
|
||||
# it will reach here and the flash_attn_varlen_func will be None.
|
||||
return flash_attn_varlen_func
|
||||
|
||||
|
||||
@CustomOp.register("mm_encoder_attn")
|
||||
class MMEncoderAttention(CustomOp):
|
||||
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float | None = None,
|
||||
num_kv_heads: int | None = None,
|
||||
prefix: str = "",
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
num_heads: number of attention heads per partition.
|
||||
head_size: hidden_size per attention head.
|
||||
scale: scale factor.
|
||||
num_kv_heads: number of kv heads.
|
||||
prefix: This has no effect, it is only here to make it easier to
|
||||
swap between Attention and MultiHeadAttention
|
||||
multimodal_config: configs for multi-modal.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.layer_name = prefix
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0, (
|
||||
f"num_heads ({self.num_heads}) is not "
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
# Try to get vision attention backend from multimodal_config.
|
||||
attn_backend_override = None
|
||||
if multimodal_config is not None:
|
||||
attn_backend_override = multimodal_config.mm_encoder_attn_backend
|
||||
|
||||
# Get device-specific vision attention backend.
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def reshape_qkv_to_4d(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
bsz: int,
|
||||
q_len: int,
|
||||
kv_len: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Reshape query, key, value to 4D tensors:
|
||||
(batch_size, seq_len, num_heads, head_size)
|
||||
"""
|
||||
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
return query, key, value
|
||||
|
||||
def reshape_qkv_to_3d(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
bsz: int,
|
||||
q_len: int,
|
||||
kv_len: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Reshape query, key, value to 3D tensors:
|
||||
(batch_size * seq_len, num_heads, head_size)
|
||||
"""
|
||||
query = query.view(bsz * q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=1)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=1)
|
||||
|
||||
return query, key, value
|
||||
|
||||
def _forward_sdpa(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# TODO(Isotr0py): Migrate MultiHeadAttention
|
||||
assert cu_seqlens is not None
|
||||
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
query, key, value = self.reshape_qkv_to_4d(
|
||||
query, key, value, bsz, q_len, kv_len
|
||||
)
|
||||
|
||||
output = vit_torch_sdpa_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
return output
|
||||
|
||||
def _forward_fa(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
assert self.flash_attn_varlen_func is not None, (
|
||||
"Flash attention function is not set."
|
||||
)
|
||||
# # TODO(Isotr0py): Migrate MultiHeadAttention
|
||||
assert cu_seqlens is not None and max_seqlen is not None
|
||||
|
||||
bsz = query.shape[0]
|
||||
|
||||
output = vit_flash_attn_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=bsz,
|
||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
if self.is_flash_attn_backend:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported multi-modal encoder attention backend for CUDA: "
|
||||
f"{self.attn_backend}."
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
assert self.is_flash_attn_backend, (
|
||||
"XPU only supports FLASH_ATTN for vision attention."
|
||||
)
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
|
||||
f"MMEncoderAttention on TPU only supports PALLAS backend, "
|
||||
f"but got {self.attn_backend}."
|
||||
)
|
||||
if cu_seqlens is None:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
return out
|
||||
logger.warning_once(
|
||||
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
|
||||
"Falling back to SDPA implementation.",
|
||||
)
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
Reference in New Issue
Block a user