68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.v1.core.sched.output import NewRequestData
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.worker.gpu.input_batch import InputBatch
|
|
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
|
from vllm.v1.worker.gpu.states import RequestState
|
|
from vllm.v1.worker.utils import AttentionGroup
|
|
|
|
|
|
class ModelState(ABC):
    """Abstract interface for model-specific state.

    Every method below, ``__init__`` included, is declared with
    ``@abstractmethod``, so a subclass cannot be instantiated until it
    overrides all of them. Each base implementation only raises
    ``NotImplementedError``; that is what a subclass gets if it delegates
    to ``super()``.
    """

    @abstractmethod
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Construct the state object.

        Args:
            vllm_config: The vLLM configuration object.
            model: The ``nn.Module`` this state is associated with.
            encoder_cache: An ``EncoderCache``, or ``None``.
                NOTE(review): presumably ``None`` for models without a
                multimodal encoder -- confirm against callers.
            device: Torch device the state is associated with.
        """
        raise NotImplementedError()

    @abstractmethod
    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Record a newly added request.

        Args:
            req_index: Integer index for the request.
                NOTE(review): presumably its row/slot in the batch -- confirm.
            new_req_data: The ``NewRequestData`` describing the request.
        """
        raise NotImplementedError()

    @abstractmethod
    def apply_staged_writes(self) -> None:
        """Apply pending staged writes.

        NOTE(review): the name suggests other methods (e.g. ``add_request``)
        buffer updates that are committed here -- confirm against concrete
        implementations.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        """Return multimodal embeddings for the batch as a single tensor.

        Args:
            scheduled_encoder_inputs: Mapping from string keys to lists of
                ints. NOTE(review): presumably request id -> indices of the
                encoder inputs scheduled this step -- confirm.
            input_batch: The batch being prepared.
            req_states: The ``RequestState`` object.
        """
        raise NotImplementedError()

    @abstractmethod
    def prepare_inputs(
        self,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> dict[str, torch.Tensor | None]:
        """Build model-specific inputs for ``input_batch``.

        Returns:
            Mapping from string keys to tensors; values may be ``None``.
            NOTE(review): presumably forwarded as extra keyword arguments
            to the model -- confirm against callers.
        """
        raise NotImplementedError()

    @abstractmethod
    def prepare_dummy_inputs(
        self,
        num_reqs: int,
        num_tokens: int,
    ) -> dict[str, torch.Tensor | None]:
        """Build placeholder inputs sized by ``num_reqs`` and ``num_tokens``.

        Returns the same mapping type as ``prepare_inputs``.
        NOTE(review): presumably used for profiling / warm-up runs where no
        real batch exists -- confirm.
        """
        raise NotImplementedError()

    @abstractmethod
    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        """Build attention-related data for ``input_batch``.

        Args:
            input_batch: The batch being prepared.
            block_tables: Tuple of block-table tensors. NOTE(review):
                presumably one per KV-cache group -- confirm.
            slot_mappings: Slot-mapping tensor.
            attn_groups: Nested lists of ``AttentionGroup`` objects.
            kv_cache_config: The ``KVCacheConfig``.

        Returns:
            Mapping from string keys to arbitrary values. NOTE(review):
            presumably attention metadata keyed by layer name -- confirm.
        """
        raise NotImplementedError()
|