# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend

logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    # At this point the attention backend has already been resolved, and any
    # overriding logic lives in the platform-specific implementation, so we
    # don't need to override the backend here. Just return the matching
    # flash_attn_varlen_func.
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    else:
        # Any other backend (e.g. TORCH_SDPA) has no varlen flash-attention
        # function, so None is returned.
        flash_attn_varlen_func = None

    return flash_attn_varlen_func


@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

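    # Illustrative usage sketch (hyperparameters and shapes below are
    # assumptions for the example, not taken from a particular model):
    #
    #     attn = MMEncoderAttention(
    #         num_heads=16,
    #         head_size=64,
    #         scale=64**-0.5,
    #         multimodal_config=multimodal_config,
    #     )
    #     # query/key/value come from the encoder's QKV projection, e.g.
    #     # (batch_size, seq_len, num_heads * head_size) tensors.
    #     # cu_seqlens holds cumulative sequence lengths of the packed batch,
    #     # e.g. [0, s0, s0 + s1, ...]; max_seqlen is the longest sequence.
    #     out = attn(query, key, value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
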
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor.
            num_kv_heads: number of kv heads.
            prefix: This has no effect; it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
            multimodal_config: configs for multi-modal models.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get the vision attention backend override from
        # multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get the device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
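        # For example (hypothetical sizes), with num_heads=16, num_kv_heads=8,
        # and head_size=64:
        #   query: (bsz, q_len, 1024) -> (bsz, q_len, 16, 64)
        #   key:   (bsz, kv_len, 512) -> (bsz, kv_len, 8, 64), then repeated
        #          to (bsz, kv_len, 16, 64) so every query head has a KV head.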
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def reshape_qkv_to_3d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 3D tensors:
        (batch_size * seq_len, num_heads, head_size)
        """
        query = query.view(bsz * q_len, self.num_heads, self.head_size)
        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=1)
            value = torch.repeat_interleave(value, num_repeat, dim=1)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query, key, value = self.reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
        )
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.flash_attn_varlen_func is not None, (
            "Flash attention function is not set."
        )
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None and max_seqlen is not None

        bsz = query.shape[0]

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        )
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                "Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.attn_backend == AttentionBackendEnum.PALLAS, (
            "MMEncoderAttention on TPU only supports PALLAS backend, "
            f"but got {self.attn_backend}."
        )
        if cu_seqlens is None:
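            # torch_xla's Pallas flash_attention kernel expects
            # (batch, num_heads, seq_len, head_size) inputs, so the seq_len
            # and num_heads dims are swapped here and swapped back on the
            # output below.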
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
            return out
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. "
            "Falling back to SDPA implementation."
        )
        return self._forward_sdpa(query, key, value, cu_seqlens)