Testing dynamic registration
This commit is contained in:
170
vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py
Normal file
170
vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright 2024 The vLLM team.
|
||||
"""Transformers modeling backend mixins for pooling/embedding models.
|
||||
|
||||
This module provides mixins for embedding and sequence classification models:
|
||||
- EmbeddingMixin: For embedding/sentence similarity models
|
||||
- SequenceClassificationMixin: For sequence classification/cross-encoding
|
||||
|
||||
Following latest vLLM architecture patterns adapted for v0.6.2.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import PoolerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EmbeddingMixin:
    """
    Mixin that equips a Transformers backend model with embedding/pooling
    support.

    Provides:
    - a ``Pooler`` layer for extracting sentence/sequence embeddings
    - a ``pooling`` method satisfying the VllmModelForPooling protocol

    Intended to be combined with a Base class via cooperative inheritance:
        class TransformersForEmbedding(EmbeddingMixin, Base): ...
    """

    # Class-level pooling defaults; subclasses may override these.
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = True
    default_softmax: bool = False

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Cooperative init: hand off to the next class in the MRO
        # (expected to be the Base model class).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Build the pooler from the model's pooler config, falling back to
        # this mixin's class-level defaults where the config is silent.
        self.pooler = Pooler.from_config_with_defaults(
            pooler_config=vllm_config.model_config.pooler_config,
            pooling_type=self.default_pooling_type,
            normalize=self.default_normalize,
            softmax=self.default_softmax,
        )

        # The factory may yield None when the config specifies nothing;
        # construct a pooler directly from the defaults in that case.
        if self.pooler is None:
            self.pooler = Pooler(
                pooling_type=self.default_pooling_type,
                normalize=self.default_normalize,
                softmax=self.default_softmax,
            )

        logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)",
                    self.pooler.pooling_type.name, self.pooler.normalize)

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """
        Pool the model's hidden states into embeddings.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with pooled embeddings
        """
        return self.pooler(hidden_states, pooling_metadata)
|
||||
|
||||
|
||||
class SequenceClassificationMixin(EmbeddingMixin):
    """
    Mixin class that adds sequence classification functionality.

    This mixin provides:
    - Classifier layer for sequence classification
    - pooling method with classification logits

    Should be used with Base class:
        class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ...
    """

    # Classification defaults: no L2-normalization, softmax over logits.
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = False
    default_softmax: bool = True

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Cooperative init: EmbeddingMixin.__init__ -> Base.__init__.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Locate the classification head on the underlying HF model.
        self.classifier = self._find_classifier()

        if self.classifier is not None:
            # Materialize any meta-device classifier parameters on the
            # target device (weights are expected to be loaded afterwards).
            self._init_classifier_params()
            logger.info("SequenceClassificationMixin initialized with classifier")
        else:
            logger.warning("Could not find classifier layer")

    def _find_classifier(self) -> Optional[nn.Module]:
        """Return the first classifier-like submodule of ``self.model``.

        Probes a list of attribute names commonly used by HF models for
        their classification head; returns None if none is present.
        """
        classifier_names = ['classifier', 'score', 'fc', 'head']

        for name in classifier_names:
            if hasattr(self.model, name):
                return getattr(self.model, name)

        return None

    def _init_classifier_params(self) -> None:
        """Initialize classifier parameters on the target device.

        Replaces any parameter still on the 'meta' device with a freshly
        allocated (uninitialized) parameter of the configured dtype on the
        configured device. Values are expected to be filled in by the
        weight loader afterwards.
        """
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype

        for name, param in list(self.classifier.named_parameters()):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data, dtype=dtype, device=device),
                    requires_grad=False,
                )
                # BUG FIX: for nested parameter names such as
                # "dense.weight", setattr on self.classifier with the leaf
                # name would attach the parameter to the WRONG module,
                # leaving the nested submodule's meta parameter untouched.
                # Walk down to the owning submodule before assigning.
                *submodule_path, leaf = name.split('.')
                owner = self.classifier
                for part in submodule_path:
                    owner = getattr(owner, part)
                setattr(owner, leaf, new_param)

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """
        Apply pooling and classification to hidden states.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with classification logits
        """
        # First apply base pooling.
        pooled = self.pooler(hidden_states, pooling_metadata)

        # Then run the classification head over each pooled embedding,
        # replacing the embedding with logits in place.
        if self.classifier is not None and pooled is not None:
            for output in pooled.outputs:
                if hasattr(output, 'data'):
                    output.data = self.classifier(output.data)

        return pooled
|
||||
Reference in New Issue
Block a user