# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import ClassVar, List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, CommonAttentionMetadata,
    make_local_attention_virtual_batches, subclass_attention_backend)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    """Wrap ``underlying_attn_backend`` so that attention is restricted to
    non-overlapping chunks of ``attention_chunk_size`` tokens.

    ``lru_cache`` makes this a cached factory: one subclass per
    (backend, chunk_size, block_size) combination.
    """
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    underlying_builder = underlying_attn_backend.get_builder_cls()

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        # Chunking re-batches requests on every step, so the metadata layout
        # is too dynamic for CUDA graph capture.
        cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.NEVER

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Split each request into chunk-aligned "virtual batches" so the
            # underlying backend never attends across a chunk boundary.
            common_attn_metadata = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size)
            return super().build(common_prefix_len, common_attn_metadata,
                                 fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder)

    return attn_backend
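
# Illustrative usage sketch (not executed; the names below are assumptions,
# with FlashAttentionBackend standing in for whichever v1 backend
# get_attn_backend() resolves on the running platform):
#
#   from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
#
#   backend_cls = create_chunked_local_attention_backend(
#       FlashAttentionBackend, attention_chunk_size=8192, block_size=16)
#   # The subclass keeps the underlying kernels but swaps in the builder:
#   assert backend_cls.get_builder_cls().cudagraph_support \
#       is AttentionCGSupport.NEVER
#   # Cached factory: identical arguments return the identical class object.
#   assert backend_cls is create_chunked_local_attention_backend(
#       FlashAttentionBackend, 8192, 16)
#
# Conceptually, with attention_chunk_size=4, a 10-token request occupying
# positions 0..9 is carved into virtual batches of 4, 4, and 2 tokens, so
# no token attends outside its own chunk.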


class ChunkedLocalAttention(Attention):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # In v0, local attention is handled inside the backends
            # themselves, so no wrapper backend is needed here.
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
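
# Illustrative construction sketch (hypothetical values, not executed):
# roughly how a model with interleaved chunked-local layers, e.g. a
# Llama-4-style architecture, might instantiate this class. `vllm_config`
# and `prefix` are assumed to come from the surrounding model code.
#
#   attn = ChunkedLocalAttention(
#       num_heads=32,
#       head_size=128,
#       scale=128**-0.5,
#       attention_chunk_size=8192,
#       num_kv_heads=8,
#       cache_config=vllm_config.cache_config,
#       quant_config=vllm_config.quant_config,
#       prefix=f"{prefix}.attn")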