# SPDX-License-Identifier: Apache-2.0 # Copyright 2024 The vLLM team. """Transformers modeling backend mixins for pooling/embedding models. This module provides mixins for embedding and sequence classification models: - EmbeddingMixin: For embedding/sentence similarity models - SequenceClassificationMixin: For sequence classification/cross-encoding Following latest vLLM architecture patterns adapted for v0.6.2. """ from typing import TYPE_CHECKING, List, Optional import torch import torch.nn as nn from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput if TYPE_CHECKING: from vllm.config import VllmConfig logger = init_logger(__name__) class EmbeddingMixin: """ Mixin class that adds embedding/pooling functionality. This mixin provides: - Pooler layer for extracting embeddings - pooling method for VllmModelForPooling protocol Should be used with Base class: class TransformersForEmbedding(EmbeddingMixin, Base): ... """ # Default pooling configuration default_pooling_type: PoolingType = PoolingType.CLS default_normalize: bool = True default_softmax: bool = False def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: # Call next class in MRO (should be Base) super().__init__(vllm_config=vllm_config, prefix=prefix) # Get pooler config from model config pooler_config = vllm_config.model_config.pooler_config # Setup pooler self.pooler = Pooler.from_config_with_defaults( pooler_config=pooler_config, pooling_type=self.default_pooling_type, normalize=self.default_normalize, softmax=self.default_softmax, ) if self.pooler is None: # Create default pooler if config doesn't specify self.pooler = Pooler( pooling_type=self.default_pooling_type, normalize=self.default_normalize, softmax=self.default_softmax, ) logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)", self.pooler.pooling_type.name, self.pooler.normalize) def pooling( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: """ Apply pooling to hidden states. Args: hidden_states: Hidden states from the model [seq_len, hidden_size] pooling_metadata: Pooling metadata Returns: PoolerOutput with pooled embeddings """ return self.pooler(hidden_states, pooling_metadata) class SequenceClassificationMixin(EmbeddingMixin): """ Mixin class that adds sequence classification functionality. This mixin provides: - Classifier layer for sequence classification - pooling method with classification logits Should be used with Base class: class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ... """ default_pooling_type: PoolingType = PoolingType.CLS default_normalize: bool = False default_softmax: bool = True def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: # Call EmbeddingMixin.__init__ -> Base.__init__ super().__init__(vllm_config=vllm_config, prefix=prefix) # Find and setup classifier layer self.classifier = self._find_classifier() if self.classifier is not None: # Initialize classifier parameters on device self._init_classifier_params() logger.info("SequenceClassificationMixin initialized with classifier") else: logger.warning("Could not find classifier layer") def _find_classifier(self) -> Optional[nn.Module]: """Find the classifier layer in the model.""" # Common classifier layer names classifier_names = ['classifier', 'score', 'fc', 'head'] for name in classifier_names: if hasattr(self.model, name): return getattr(self.model, name) return None def _init_classifier_params(self) -> None: """Initialize classifier parameters on target device.""" device = self.device_config.device if device is None: device = torch.device("cpu") dtype = self.model_config.dtype for name, param in list(self.classifier.named_parameters()): if param.device == torch.device("meta"): new_param = nn.Parameter( torch.empty_like(param.data, dtype=dtype, device=device), requires_grad=False, ) setattr(self.classifier, name.split('.')[-1], new_param) def pooling( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: """ Apply pooling and classification to hidden states. Args: hidden_states: Hidden states from the model [seq_len, hidden_size] pooling_metadata: Pooling metadata Returns: PoolerOutput with classification logits """ # First apply base pooling pooled = self.pooler(hidden_states, pooling_metadata) # Apply classifier if available if self.classifier is not None and pooled is not None: # Apply classifier to each pooled output for i, output in enumerate(pooled.outputs): if hasattr(output, 'data'): output.data = self.classifier(output.data) return pooled