# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import NamedTuple, cast, get_args

import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)


class AttentionSelectorConfig(NamedTuple):
    """Parameters that drive attention backend selection.

    Instances are passed as an argument to the memoized
    ``_cached_get_attn_backend``, so every field value must be hashable.
    """

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        # Generated from ``_fields`` so newly added fields are never left out.
        # Values are rendered with ``format()`` (i.e. unquoted), exactly like
        # the previous hand-written f-string version.
        fields = ", ".join(
            f"{name}={value}" for name, value in zip(self._fields, self)
        )
        return f"AttentionSelectorConfig({fields})"


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    The backend preference is read from the current vLLM config
    (``attention_config.backend``); the actual class is chosen by the current
    platform and memoized per distinct set of arguments.

    Args:
        head_size: Attention head size.
        dtype: Model dtype.
        kv_cache_dtype: KV cache dtype name, or ``None``. When not ``None`` it
            must be one of the values of the ``CacheDType`` literal.
        block_size: KV cache block size, or ``None``.
        use_mla: Forwarded to the platform selector.
        has_sink: Forwarded to the platform selector.
        use_sparse: Forwarded to the platform selector.
        use_mm_prefix: Forwarded to the platform selector.
        attn_type: Attention type; ``None`` (or any falsy value) falls back
            to ``AttentionType.DECODER``.

    Returns:
        The selected attention backend class.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not a valid ``CacheDType``.
        ValueError: If the platform does not return a backend.
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    # Imported lazily; NOTE(review): presumably to avoid an import cycle with
    # vllm.config — confirm before hoisting to module level.
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    backend_enum = vllm_config.attention_config.backend

    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        # Validated above (when not None), so the narrowing cast is safe.
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type or AttentionType.DECODER,
    )

    return _cached_get_attn_backend(
        backend=backend_enum,
        attn_selector_config=attn_selector_config,
    )


@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
    """Resolve and import the attention backend class for one selection key.

    Memoized with ``functools.cache``: the body, including the KV cache layout
    side effect below, runs only once per distinct
    ``(backend, attn_selector_config)`` pair.

    Args:
        backend: The configured backend preference (the caller passes
            ``attention_config.backend``); forwarded unchanged to the
            platform.
        attn_selector_config: Selection parameters forwarded to the platform.

    Returns:
        The imported attention backend class.

    Raises:
        ValueError: If the platform returns no backend for this configuration.
    """
    from vllm.platforms import current_platform

    attention_cls = current_platform.get_attn_backend_cls(
        backend,
        attn_selector_config=attn_selector_config,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # Import the class from the qualified name returned by the platform. Use a
    # distinct name rather than rebinding the ``backend`` parameter.
    backend_cls = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend_cls.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend_cls.get_name(),
        )

    return backend_cls


def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it."""
    return _cached_get_mamba_attn_backend(mamba_type)


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Map a mamba layer type to its attention backend class (memoized).

    The lookup is two-step: ``mamba_type`` -> backend name via
    ``MAMBA_TYPE_TO_BACKEND_MAP``, then backend name ->
    ``MambaAttentionBackendEnum`` member, whose ``get_class()`` is returned.

    Raises:
        AssertionError: If ``mamba_type`` is empty or not a ``str``.
        ValueError: If ``mamba_type`` is not in ``MAMBA_TYPE_TO_BACKEND_MAP``
            or the mapped name is not a ``MambaAttentionBackendEnum`` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # Two separate lookups so each failure reports the value that actually
    # failed; previously a miss in the map left ``backend_name`` unbound and
    # the error handler itself crashed with UnboundLocalError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. "
            f"Valid types are: {list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()