# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import ClassVar, List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, CommonAttentionMetadata,
    make_local_attention_virtual_batches, subclass_attention_backend)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Wrap ``underlying_attn_backend`` so that its metadata builder first
    splits each request into virtual batches of at most
    ``attention_chunk_size`` tokens, yielding chunked local attention on top
    of the underlying backend. Cached per (backend, chunk size, block size)
    via ``functools.lru_cache``."""
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    underlying_builder = underlying_attn_backend.get_builder_cls()

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.NEVER

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Rewrite the metadata into per-chunk virtual batches before
            # delegating to the underlying builder.
            common_attn_metadata = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size)
            return super().build(common_prefix_len, common_attn_metadata,
                                 fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder)

    return attn_backend


class ChunkedLocalAttention(Attention):
    """Attention layer applying chunked local attention: each query token
    attends only to keys within its ``attention_chunk_size``-token chunk."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # in v0 the local attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
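
# Usage sketch (illustrative only, not part of the module): how a model
# definition might construct this layer. The head counts, chunk size, scale,
# and prefix below are hypothetical values chosen purely for the example;
# real models pass their own configuration.
#
#     attn = ChunkedLocalAttention(
#         num_heads=32,
#         head_size=128,
#         scale=128 ** -0.5,
#         attention_chunk_size=8192,
#         num_kv_heads=8,
#         cache_config=cache_config,
#         quant_config=quant_config,
#         prefix=f"{prefix}.attn",
#     )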