# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Set
|
|
from typing import TypeAlias
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.tasks import PoolingTask
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
|
|
from .methods import TokenPoolingMethodOutputItem
|
|
|
|
# Per-request output of a pooler head: a tensor of per-token results, or
# ``None`` for a request whose chunked prefill has not finished yet.
TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None
|
|
|
|
|
|
class TokenPoolerHead(nn.Module, ABC):
    """Base class for heads that post-process token-level pooled data.

    Subclasses implement :meth:`forward_chunk` for a single request;
    :meth:`forward` fans it out over a batch of requests.
    """

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Process the pooled output of one request."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        """Apply :meth:`forward_chunk` to each request's pooled output.

        ``pooled_data`` and ``pooling_metadata.pooling_params`` must be
        aligned one-to-one by request.
        """
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        outputs: list[TokenPoolerHeadOutputItem] = []
        for chunk, param in zip(pooled_data, params):
            outputs.append(self.forward_chunk(chunk, param))
        return outputs
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Head that turns per-token hidden states into per-token embeddings.

    Optionally casts to a head dtype, applies a SentenceTransformers-style
    projector, truncates for matryoshka representations, and applies an
    activation (e.g. normalization) controlled per request.
    """

    def __init__(
        self,
        head_dtype: torch.dtype | str | None = None,
        projector: ProjectorFn | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        # Optional dtype cast applied before any head computation.
        self.head_dtype = head_dtype
        # Optional dense projector applied to every token.
        self.projector = projector
        # Optional activation, gated per request by `use_activation`.
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # An unfinished chunked prefill produces no output this step.
        if pooled_data is None:
            return None

        embeds = pooled_data
        if self.head_dtype is not None:
            embeds = embeds.to(self.head_dtype)
        # embeds shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            embeds = self.projector(embeds)
            # embeds shape: [n_tokens, embedding_dimension]

        # for matryoshka representation: truncate trailing dimensions.
        # (a `dimensions` of None keeps the full width)
        embeds = embeds[..., : pooling_param.dimensions]

        # for normalize
        if self.activation is not None and pooling_param.use_activation:
            embeds = self.activation(embeds)

        # embeds shape: [n_tokens, embedding_dimension]
        return embeds
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Head that turns per-token hidden states into classification scores.

    Optionally casts to a head dtype, applies a classifier projection,
    subtracts a logit bias, and applies an activation (e.g. sigmoid)
    controlled per request.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        # Optional projection from hidden states to per-label logits.
        self.classifier = classifier
        # Optional constant subtracted from logits before activation.
        self.logit_bias = logit_bias
        # Optional dtype cast applied before classification.
        self.head_dtype = head_dtype
        # Optional activation, gated per request by `use_activation`.
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Out-of-place subtraction: when `classifier` is None (and the
            # dtype cast is a no-op, since Tensor.to returns `self` when the
            # dtype already matches), `scores` aliases the caller's
            # `pooled_data` tensor, so in-place `-=` would silently corrupt
            # the pooled hidden states.
            scores = scores - self.logit_bias

        if self.activation is not None and pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores