enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixins for pooling/embedding models.
This module provides mixins for embedding and sequence classification models:
- EmbeddingMixin: For embedding/sentence similarity models
- SequenceClassificationMixin: For sequence classification/cross-encoding
Following latest vLLM architecture patterns adapted for v0.6.2.
"""
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
if TYPE_CHECKING:
    from vllm.config import VllmConfig
logger = init_logger(__name__)
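
# Cooperative-inheritance sketch: for ``class Model(EmbeddingMixin, Base)``,
# the MRO is [Model, EmbeddingMixin, Base, ...], so ``super().__init__``
# inside the mixin resolves to ``Base.__init__``. A minimal standalone
# illustration (the class names are placeholders, not this module's API):
#
#     class Base:
#         def __init__(self, *, vllm_config, prefix=""):
#             print("Base init")
#
#     class Mixin:
#         def __init__(self, *, vllm_config, prefix=""):
#             super().__init__(vllm_config=vllm_config, prefix=prefix)
#             print("Mixin init")
#
#     class Model(Mixin, Base):
#         pass
#
#     Model(vllm_config=None)  # prints "Base init", then "Mixin init"
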
class EmbeddingMixin:
    """
    Mixin class that adds embedding/pooling functionality.

    This mixin provides:
    - A Pooler layer for extracting embeddings
    - A pooling method for the VllmModelForPooling protocol

    Should be used with a Base class:

        class TransformersForEmbedding(EmbeddingMixin, Base): ...
    """

    # Default pooling configuration
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = True
    default_softmax: bool = False

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Call the next class in the MRO (expected to be Base)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Get the pooler config from the model config
        pooler_config = vllm_config.model_config.pooler_config

        # Set up the pooler
        self.pooler = Pooler.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=self.default_pooling_type,
            normalize=self.default_normalize,
            softmax=self.default_softmax,
        )
        if self.pooler is None:
            # Fall back to a default pooler if the config does not specify one
            self.pooler = Pooler(
                pooling_type=self.default_pooling_type,
                normalize=self.default_normalize,
                softmax=self.default_softmax,
            )
        logger.info(
            "EmbeddingMixin initialized (pooling_type=%s, normalize=%s)",
            self.pooler.pooling_type.name, self.pooler.normalize)

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """
        Apply pooling to hidden states.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with pooled embeddings
        """
        return self.pooler(hidden_states, pooling_metadata)
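
# A rough standalone illustration of what CLS pooling with normalization
# computes (plain PyTorch, independent of the vLLM ``Pooler``; the tensor
# shapes and sequence lengths below are assumptions for the example only):
#
#     import torch
#     import torch.nn.functional as F
#
#     # Flattened hidden states for two sequences of lengths 3 and 2
#     hidden_states = torch.randn(5, 8)            # [total_tokens, hidden_size]
#     first_token_indices = torch.tensor([0, 3])   # CLS position per sequence
#     cls = hidden_states[first_token_indices]     # [num_seqs, hidden_size]
#     cls = F.normalize(cls, p=2, dim=-1)          # normalize=True behavior
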
class SequenceClassificationMixin(EmbeddingMixin):
    """
    Mixin class that adds sequence classification functionality.

    This mixin provides:
    - Discovery of the model's classifier layer
    - A pooling method that returns classification logits

    Should be used with a Base class:

        class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ...
    """

    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = False
    default_softmax: bool = True

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Call EmbeddingMixin.__init__ -> Base.__init__
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Find and set up the classifier layer
        self.classifier = self._find_classifier()
        if self.classifier is not None:
            # Materialize classifier parameters on the target device
            self._init_classifier_params()
            logger.info("SequenceClassificationMixin initialized with classifier")
        else:
            logger.warning("Could not find a classifier layer")

    def _find_classifier(self) -> Optional[nn.Module]:
        """Find the classifier layer in the wrapped model."""
        # Common classifier attribute names in Transformers models
        classifier_names = ['classifier', 'score', 'fc', 'head']
        for name in classifier_names:
            if hasattr(self.model, name):
                return getattr(self.model, name)
        return None

    def _init_classifier_params(self) -> None:
        """Materialize meta-device classifier parameters on the target device."""
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype
        for name, param in list(self.classifier.named_parameters()):
            if param.is_meta:
                new_param = nn.Parameter(
                    torch.empty_like(param.data, dtype=dtype, device=device),
                    requires_grad=False,
                )
                # ``name`` may be dotted (e.g. "dense.weight"), so assign the
                # new parameter on the submodule that owns it rather than on
                # the classifier root.
                module_path, _, param_name = name.rpartition('.')
                owner = (self.classifier.get_submodule(module_path)
                         if module_path else self.classifier)
                setattr(owner, param_name, new_param)
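
    # A rough standalone illustration of the meta-parameter materialization
    # pattern used above (plain PyTorch; the ``nn.Linear`` head, dtype, and
    # device are assumptions for the example, not the actual classifier):
    #
    #     import torch
    #     import torch.nn as nn
    #
    #     with torch.device("meta"):
    #         head = nn.Linear(768, 2)        # parameters allocated on "meta"
    #     assert head.weight.is_meta
    #     for name, p in list(head.named_parameters()):
    #         setattr(head, name, nn.Parameter(
    #             torch.empty_like(p, dtype=torch.float16, device="cpu"),
    #             requires_grad=False))
    #     assert not head.weight.is_meta
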
    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """
        Apply pooling and classification to hidden states.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with classification logits
        """
        # First apply the base pooling (including the pooler's normalize/
        # softmax post-processing, per the class defaults above)
        pooled = self.pooler(hidden_states, pooling_metadata)
        # Then apply the classifier, if one was found, to each pooled output
        if self.classifier is not None and pooled is not None:
            for output in pooled.outputs:
                if hasattr(output, 'data'):
                    output.data = self.classifier(output.data)
        return pooled
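
# Shape sketch of the classification path above (plain PyTorch, with a
# hypothetical 2-label linear classifier; not the actual vLLM execution flow):
#
#     import torch
#     import torch.nn as nn
#
#     # Pooled CLS vectors after the pooler's softmax (default_softmax=True)
#     pooled = torch.softmax(torch.randn(4, 768), dim=-1)  # [num_seqs, hidden]
#     classifier = nn.Linear(768, 2)                       # hypothetical head
#     logits = classifier(pooled)                          # [num_seqs, num_labels]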