Files
2026-03-10 13:31:25 +08:00

43 lines
1.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
    """
    Abstract interface for Mamba-like layers that support the v1 engine.
    Subclass this when implementing a custom layer of that kind; every
    member marked abstract below has to be provided by the subclass.
    """

    # The layer's KV cache, which for these layers is the mamba state.
    # Its tensors follow the shapes reported by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Report the shape of the layer's state, one shape tuple per tensor.
        Mamba layers usually keep a (conv_state, ssm_state) pair; for those
        the result is (conv_state_shape, ssm_state_shape).
        """
        ...

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """
        Return the Mamba type of this layer as a string.
        The valid values are not defined here; each subclass supplies its own.
        """
        ...

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this Mamba layer."""
        ...