# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING

import torch

from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class MambaBase(AttentionLayerBase):
    """Common parent for Mamba-like layers that support the v1 engine.

    Custom Mamba-style layers should derive from this class and provide
    an implementation for every member declared abstract below.
    """

    # Mamba state for this layer, stored under the KV-cache name. The
    # tensors follow the shapes reported by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """Describe the shape of each state tensor held by the layer.

        Mamba layers usually keep a (conv_state, ssm_state) pair; for
        such layers the result is (conv_state_shape, ssm_state_shape).

        Returns:
            An iterable with one shape tuple per state tensor.
        """

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """Return the mamba type of this layer as a string.

        NOTE(review): the set of valid values is not defined in this
        file -- confirm against the concrete subclasses and callers.
        """

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this Mamba layer.

        Returns:
            The `AttentionBackend` subclass (the class, not an instance).
        """