# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, runtime_checkable) import torch import torch.nn as nn from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger from vllm.utils import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata logger = init_logger(__name__) # The type of hidden states # Currently, T = torch.Tensor for all models except for Medusa # which has T = list[torch.Tensor] T = TypeVar("T", default=torch.Tensor) T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) # NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags # for the base interfaces to avoid breaking OOT registration for existing models # that don't inherit from the base interface classes @runtime_checkable class VllmModel(Protocol[T_co]): """The interface required for all models in vLLM.""" def __init__( self, vllm_config: "VllmConfig", prefix: str = "", ) -> None: ... def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, ) -> T_co: ... 
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Report whether the model's `__init__` accepts `vllm_config`.

    Works on either a model class or a model instance; the decision is
    delegated to `supports_kw`.
    """
    return supports_kw(model.__init__, "vllm_config")


def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    """Report whether the model's `forward` accepts the vLLM keywords.

    Returns `False` when `model` has no callable `forward` attribute, or
    when `forward` does not support both `input_ids` and `positions`.
    A warning listing the missing keywords is logged only when `model` is
    a class deriving from `nn.Module`.
    """
    forward_fn = getattr(model, "forward", None)
    if not callable(forward_fn):
        return False

    required_kws = ("input_ids", "positions")
    # Kept as a tuple because it is passed to the logger as-is.
    absent_kws = tuple(
        filter(lambda kw: not supports_kw(forward_fn, kw), required_kws))
    if not absent_kws:
        return True

    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            absent_kws,
        )
    return False


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Check that `model` (class or instance) looks like a vLLM model.

    The `__init__` signature is checked first; `forward` is inspected only
    if that check passes.
    """
    init_ok = _check_vllm_model_init(model)
    # Short-circuit: skip the `forward` inspection (and its possible
    # warning) when the constructor check has already failed.
    return _check_vllm_model_forward(model) if init_ok else init_ok


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    Extends `VllmModel` with a `compute_logits` method.
    """

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...


@overload
def is_text_generation_model(
        model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...
def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Check whether `model` satisfies `VllmModelForTextGeneration`.

    Args:
        model: A model class or a model instance.

    Returns:
        `False` if `model` fails `is_vllm_model`; otherwise the result of
        an `isinstance` check against the runtime-checkable
        `VllmModelForTextGeneration` protocol.
    """
    if not is_vllm_model(model):
        return False

    # The same `isinstance` call is used whether `model` is a class or an
    # instance, so no separate branch for classes is needed.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T], Protocol[T]):
    """The interface required for all pooling models in vLLM.

    Extends `VllmModel` with a `pooler` method.
    """

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Only called on TP rank 0."""
        ...


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
    ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
    ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Check whether `model` satisfies `VllmModelForPooling`.

    Args:
        model: A model class or a model instance.

    Returns:
        `False` if `model` fails `is_vllm_model`; otherwise the result of
        an `isinstance` check against the runtime-checkable
        `VllmModelForPooling` protocol.
    """
    if not is_vllm_model(model):
        return False

    # The same `isinstance` call is used whether `model` is a class or an
    # instance, so no separate branch for classes is needed.
    return isinstance(model, VllmModelForPooling)