# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
from copy import copy

import torch

from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import KVCacheSpec


@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """Wrap ``underlying_attn_backend`` so that its metadata builder always
    produces non-causal (encoder-only) attention metadata."""
    prefix = "EncoderOnlyAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Encoder-only attention is bidirectional: copy the shared
            # metadata and force the causal flag off before delegating to
            # the underlying builder.
            new_common_attn_metadata = copy(common_attn_metadata)
            new_common_attn_metadata.causal = False
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )

    return attn_backend
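

# Note: because the factory above is wrapped in `functools.lru_cache`,
# repeated calls with the same underlying backend class return the same
# generated subclass instead of minting a new type each time, e.g.
# (illustrative sketch; `SomeBackend` stands in for whatever backend
# `get_attn_backend` resolves on the current platform):
#
#     cls_a = create_encoder_only_attention_backend(SomeBackend)
#     cls_b = create_encoder_only_attention_backend(SomeBackend)
#     assert cls_a is cls_b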


class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        # Resolve the backend that would normally serve this head size,
        # dtype, and block size, then wrap it so that attention metadata
        # is always built non-causally.
        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )

        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Encoder-only attention does not need a KV cache.
        return None
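

# Minimal usage sketch (illustrative only: the argument values are made up,
# and instantiation normally happens inside vLLM's model layers, within an
# engine context where an ENCODER_ONLY-capable backend is available):
#
#     attn = EncoderOnlyAttention(
#         num_heads=12,
#         head_size=64,
#         scale=64**-0.5,
#     )
#
# The resulting layer computes bidirectional (non-causal) attention and, per
# `get_kv_cache_spec` above, reserves no KV-cache memory.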