# SPDX-License-Identifier: Apache-2.0
|
|
# Copyright 2024 The vLLM team.
|
|
"""Transformers modeling backend mixins for pooling/embedding models.
|
|
|
|
This module provides mixins for embedding and sequence classification models:
|
|
- EmbeddingMixin: For embedding/sentence similarity models
|
|
- SequenceClassificationMixin: For sequence classification/cross-encoding
|
|
|
|
Following latest vLLM architecture patterns adapted for v0.6.2.
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING, List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|
from vllm.sequence import PoolerOutput
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class EmbeddingMixin:
    """Mixin that adds embedding/pooling support to a Transformers backend model.

    Provides:
    - a ``Pooler`` layer for extracting embeddings, and
    - a ``pooling`` method satisfying the ``VllmModelForPooling`` protocol.

    Intended to be combined with a base model class via the MRO:

        class TransformersForEmbedding(EmbeddingMixin, Base): ...
    """

    # Default pooling configuration; subclasses may override these class
    # attributes to change the fallback behavior.
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = True
    default_softmax: bool = False

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (expected to be the Base model).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Build the pooler from the model's pooler config, using this
        # mixin's class-level attributes as defaults.
        config = vllm_config.model_config.pooler_config
        self.pooler = Pooler.from_config_with_defaults(
            pooler_config=config,
            pooling_type=self.default_pooling_type,
            normalize=self.default_normalize,
            softmax=self.default_softmax,
        )

        if self.pooler is None:
            # The config did not yield a pooler; fall back to one built
            # purely from the defaults above.
            self.pooler = Pooler(
                pooling_type=self.default_pooling_type,
                normalize=self.default_normalize,
                softmax=self.default_softmax,
            )

        logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)",
                    self.pooler.pooling_type.name, self.pooler.normalize)

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool hidden states into embeddings.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Metadata describing how to pool each sequence.

        Returns:
            PoolerOutput with pooled embeddings
        """
        return self.pooler(hidden_states, pooling_metadata)
|
|
|
|
|
|
class SequenceClassificationMixin(EmbeddingMixin):
    """Mixin that adds sequence classification / cross-encoding support.

    This mixin provides:
    - discovery and device initialization of the model's classifier head, and
    - a ``pooling`` method that returns classification logits.

    Intended to be combined with a base model class via the MRO:

        class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ...
    """

    # Classification defaults: feed raw (unnormalized) pooled features to the
    # classifier and apply softmax to produce class probabilities.
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = False
    default_softmax: bool = True

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Call EmbeddingMixin.__init__ -> Base.__init__ (sets up self.pooler).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Find and set up the classifier head, if the model has one.
        self.classifier = self._find_classifier()

        if self.classifier is not None:
            # Materialize any meta-device classifier parameters on the
            # target device so the head can actually run.
            self._init_classifier_params()
            logger.info("SequenceClassificationMixin initialized with classifier")
        else:
            logger.warning("Could not find classifier layer")

    def _find_classifier(self) -> Optional[nn.Module]:
        """Return the model's classifier head, or None if not found.

        Probes attribute names commonly used by HF Transformers
        sequence-classification models, in priority order.
        """
        for name in ('classifier', 'score', 'fc', 'head'):
            if hasattr(self.model, name):
                return getattr(self.model, name)
        return None

    def _init_classifier_params(self) -> None:
        """Initialize classifier parameters on the target device.

        Any parameter still on the "meta" device is replaced in place with an
        uninitialized tensor of the configured dtype on the target device
        (real weights are presumably loaded into it afterwards — TODO confirm
        against the weight-loading path).
        """
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype

        # Snapshot the parameter list first since we mutate modules below.
        for name, param in list(self.classifier.named_parameters()):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data, dtype=dtype, device=device),
                    requires_grad=False,
                )
                # Walk down to the submodule that actually owns the
                # parameter: `named_parameters()` yields dotted names like
                # "dense.weight", and setting the last component directly on
                # `self.classifier` would create a stray attribute while
                # leaving the real meta-device parameter untouched.
                owner = self.classifier
                path, _, param_name = name.rpartition('.')
                if path:
                    for attr in path.split('.'):
                        owner = getattr(owner, attr)
                setattr(owner, param_name, new_param)

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply pooling and classification to hidden states.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with classification logits
        """
        # First apply base pooling.
        pooled = self.pooler(hidden_states, pooling_metadata)

        # Then run the classifier head over each pooled output in place.
        if self.classifier is not None and pooled is not None:
            for output in pooled.outputs:
                if hasattr(output, 'data'):
                    output.data = self.classifier(output.data)

        return pooled