first commit
This commit is contained in:
42
vllm/model_executor/layers/mamba/abstract.py
Normal file
42
vllm/model_executor/layers/mamba/abstract.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
class MambaBase(AttentionLayerBase):
    """
    Abstract base for Mamba-like layers that support the v1 engine.

    Subclass this when implementing a custom layer, and provide every
    abstract member declared below.
    """

    # Holds this layer's KV cache (the mamba state). The tensors follow
    # the shapes reported by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Describe the shape of the layer's state.

        Mamba layers usually keep a (conv_state, ssm_state) tuple; for
        those, the result is (conv_state_shape, ssm_state_shape).
        """
        ...

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """
        String supplied by the subclass to identify its Mamba variant.

        NOTE(review): the set of accepted values is not defined in this
        file; confirm against the code that consumes this property.
        """
        ...

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this Mamba layer."""
        ...
|
||||
Reference in New Issue
Block a user