# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from typing import Any import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.v1.core.sched.output import NewRequestData from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup class ModelState(ABC): @abstractmethod def __init__( self, vllm_config: VllmConfig, model: nn.Module, encoder_cache: EncoderCache | None, device: torch.device, ) -> None: raise NotImplementedError @abstractmethod def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: raise NotImplementedError @abstractmethod def apply_staged_writes(self) -> None: raise NotImplementedError @abstractmethod def get_mm_embeddings( self, scheduled_encoder_inputs: dict[str, list[int]], input_batch: InputBatch, req_states: RequestState, ) -> torch.Tensor: raise NotImplementedError @abstractmethod def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState ) -> dict[str, torch.Tensor | None]: raise NotImplementedError @abstractmethod def prepare_dummy_inputs( self, num_reqs: int, num_tokens: int ) -> dict[str, torch.Tensor | None]: raise NotImplementedError @abstractmethod def prepare_attn( self, input_batch: InputBatch, block_tables: tuple[torch.Tensor, ...], slot_mappings: torch.Tensor, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, ) -> dict[str, Any]: raise NotImplementedError