# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec


class MambaBase(AttentionLayerBase):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.
        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).
        """
        pass

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """
        Returns the mamba variant of this layer (e.g. "mamba1", "mamba2"),
        used to select the matching attention backend.
        """
        pass

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """
        Returns the dtype of each state tensor, in the same order as the
        shapes returned by `self.get_state_shape`.
        """
        pass

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Speculative decoding is only supported for an allowlist of
        # mamba-based architectures; reject it for all others.
        if (
            vllm_config.speculative_config is not None
            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
        ):
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet."
            )
        mamba_block_size = vllm_config.cache_config.mamba_block_size
        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
        return MambaSpec(
            shapes=self.get_state_shape(),
            dtypes=self.get_state_dtype(),
            block_size=mamba_block_size,
            page_size_padded=page_size_padded,
            mamba_type=self.mamba_type,
            num_speculative_blocks=(
                vllm_config.speculative_config.num_speculative_tokens
                if vllm_config.speculative_config
                else 0
            ),
        )

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Get the attention backend class for this Mamba layer."""
        return get_mamba_attn_backend(self.mamba_type)
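

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of vLLM): a minimal, hypothetical
# subclass showing how a custom layer satisfies the MambaBase contract.
# `ToyMambaLayer` and all dimensions below are made up for the example;
# real layers derive them from the model config. The key points: shapes
# and dtypes are returned one entry per state tensor, in matching order,
# and `mamba_type` is the string passed to get_mamba_attn_backend() to
# pick the attention backend.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyMambaLayer(MambaBase):
        """Hypothetical layer holding a (conv_state, ssm_state) pair."""

        def __init__(self) -> None:
            self.conv_dim = 1536  # assumed conv channel count
            self.conv_kernel = 4  # assumed conv window size
            self.num_heads = 24  # assumed SSM head count
            self.head_dim = 64
            self.state_size = 128

        def get_state_shape(self) -> Iterable[tuple[int, ...]]:
            # Per-request shapes: the conv state caches the last
            # (conv_kernel - 1) inputs; the ssm state is one matrix per head.
            return (
                (self.conv_dim, self.conv_kernel - 1),
                (self.num_heads, self.head_dim, self.state_size),
            )

        def get_state_dtype(self) -> tuple[torch.dtype, ...]:
            # One dtype per state tensor, same order as get_state_shape().
            return (torch.float16, torch.float32)

        @property
        def mamba_type(self) -> str:
            return "mamba2"

    layer = ToyMambaLayer()
    print(list(layer.get_state_shape()), layer.get_state_dtype())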