update
vllm/model_executor/layers/pooler/seqwise/heads.py (new file, +151 lines)
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias

import torch
import torch.nn as nn

from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .methods import SequencePoolingMethodOutput

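# A head returns either a single stacked tensor or a per-request list of
# tensors (e.g. when requests ask for different matryoshka dimensions).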
SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]


class SequencePoolerHead(nn.Module, ABC):
    """Base class for heads that turn pooled hidden states into task outputs."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        raise NotImplementedError


class EmbeddingPoolerHead(SequencePoolerHead):
    """Embedding head: optional projection, matryoshka truncation, and
    activation of sequence-pooled hidden states."""

    def __init__(
        self,
        projector: ProjectorFn | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        self.projector = projector
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)

        # Apply the Sentence Transformers (ST) projector, if one is configured.
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # Matryoshka representation: truncate embeddings to each request's
        # requested dimensionality.
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # Change the output dimension.
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # All requests share one dimension, so slice the batch at once.
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]
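        # Illustrative example (values are hypothetical, not from the diff):
        # a [4, 768] batch where every request sets dimensions=256 becomes
        # pooled_data[..., :256], i.e. a [4, 256] tensor.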

        # Normalize (or otherwise activate) the embeddings, honoring each
        # request's use_activation flag.
        if self.activation is not None:
            flags = [p.use_activation for p in pooling_params]
            if len(set(flags)) == 1:
                if flags[0]:
                    pooled_data = self.activation(pooled_data)
            else:
                pooled_data = [
                    self.activation(vecs) if f else vecs
                    for vecs, f in zip(pooled_data, flags)
                ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data


class ClassifierPoolerHead(SequencePoolerHead):
    """Classification/scoring head: optional classifier projection, logit
    bias, and activation over sequence-pooled hidden states."""

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        # Shift raw logits before any activation is applied.
        if self.logit_bias is not None:
            pooled_data -= self.logit_bias
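        # Illustrative example (hypothetical values): with logit_bias=2.5 and
        # a sigmoid activation below, a raw logit of 2.5 is shifted to 0.0 and
        # maps to sigmoid(0.0) == 0.5.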

        if self.activation is not None:
            flags = [p.use_activation for p in pooling_params]
            if len(set(flags)) == 1:
                pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
            else:
                pooled_data = [
                    self.activation(vecs) if f else vecs
                    for vecs, f in zip(pooled_data, flags)
                ]

        # pooled_data shape: [batchsize, num_labels]
        return pooled_data
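
A minimal usage sketch (not part of the diff; _Params and _Metadata are
hypothetical stubs exposing only the attributes these heads read from vLLM's
real PoolingParams and PoolingMetadata, and the lambda/Linear modules stand in
for real projector/classifier layers):

import torch
from vllm.model_executor.layers.pooler.seqwise.heads import (
    ClassifierPoolerHead, EmbeddingPoolerHead)

class _Params:
    # Stub: only dimensions and use_activation are consulted by the heads.
    def __init__(self, dimensions=None, use_activation=True):
        self.dimensions = dimensions
        self.use_activation = use_activation

class _Metadata:
    # Stub: the heads only read pooling_metadata.pooling_params.
    def __init__(self, params):
        self.pooling_params = params

# Embedding head: truncate both requests to 256 dims, then L2-normalize.
embed_head = EmbeddingPoolerHead(
    activation=lambda x: torch.nn.functional.normalize(x, dim=-1))
pooled = torch.randn(2, 768)
out = embed_head(pooled, _Metadata([_Params(dimensions=256)] * 2))
assert out.shape == (2, 256)

# Scoring head: a linear classifier followed by sigmoid.
score_head = ClassifierPoolerHead(
    classifier=torch.nn.Linear(768, 1), activation=torch.sigmoid)
scores = score_head(pooled, _Metadata([_Params()] * 2))
assert scores.shape == (2, 1)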