# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from copy import copy from typing import Optional import torch from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, subclass_attention_backend) @functools.lru_cache def create_encoder_only_attention_backend( underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: prefix = "EncoderOnlyAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> AttentionMetadata: new_common_attn_metadata = copy(common_attn_metadata) new_common_attn_metadata.causal = False return super().build(common_prefix_len, new_common_attn_metadata, fast_build) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, builder_cls=EncoderOnlyAttentionBuilder) return attn_backend class EncoderOnlyAttention(Attention): """ Encoder attention is a special case that doesn't need a KV Cache. """ def __init__(self, num_heads: int, head_size: int, scale: float, cache_config: Optional[CacheConfig] = None, attn_type: Optional[str] = None, **kwargs): dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size else: kv_cache_dtype = "auto" block_size = 16 if envs.VLLM_USE_V1: underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, block_size) attn_backend = create_encoder_only_attention_backend( underlying_attn_backend) else: # in v0 encoder only attention is handled inside the backends attn_backend = None if attn_type is not None: assert attn_type == AttentionType.ENCODER_ONLY, \ "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" super().__init__(num_heads=num_heads, head_size=head_size, scale=scale, cache_config=cache_config, attn_backend=attn_backend, attn_type=AttentionType.ENCODER_ONLY, **kwargs)