# File: vllm/model_executor/layers/pooler/tokwise/heads.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .methods import TokenPoolingMethodOutputItem
# Per-chunk head output: a tensor of per-token results, or None when the
# corresponding request's chunked prefill has not finished yet (see the
# None passthrough in the heads' forward_chunk implementations below).
TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None
class TokenPoolerHead(nn.Module, ABC):
    """Abstract head applied on top of a token-wise pooling method.

    Subclasses declare which pooling tasks they serve and implement the
    per-chunk transformation; :meth:`forward` maps :meth:`forward_chunk`
    over every (pooled chunk, pooling params) pair in the batch.
    """

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling tasks this head supports."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Transform a single pooled chunk according to its pooling params."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        params = pooling_metadata.pooling_params
        # One set of pooling params per pooled chunk.
        assert len(pooled_data) == len(params)
        outputs: list[TokenPoolerHeadOutputItem] = []
        for chunk, param in zip(pooled_data, params):
            outputs.append(self.forward_chunk(chunk, param))
        return outputs
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Head producing per-token embeddings (task: ``token_embed``).

    Optionally casts the pooled tokens to ``head_dtype``, applies a
    SentenceTransformers-style projector, truncates to the requested
    matryoshka dimensionality, and applies an activation.
    """

    def __init__(
        self,
        head_dtype: torch.dtype | str | None = None,
        projector: ProjectorFn | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()
        self.head_dtype = head_dtype
        self.projector = projector
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # An unfinished chunked prefill has no pooled output yet.
        if pooled_data is None:
            return None

        hidden = pooled_data
        if self.head_dtype is not None:
            hidden = hidden.to(self.head_dtype)

        # hidden: [n_tokens, hidden_dimension]
        if self.projector is not None:
            # Apply ST projector -> [n_tokens, embedding_dimension]
            hidden = self.projector(hidden)

        # Matryoshka representation: keep only the leading dimensions
        # (a None `dimensions` keeps the full width).
        hidden = hidden[..., : pooling_param.dimensions]

        # Optional activation, typically normalization.
        if self.activation is not None and pooling_param.use_activation:
            hidden = self.activation(hidden)

        # hidden: [n_tokens, embedding_dimension]
        return hidden
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Head producing per-token classification scores (task: ``token_classify``).

    Optionally casts the pooled tokens to ``head_dtype``, applies a
    classifier, subtracts a constant logit bias, and applies an activation.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()
        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None
        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]
        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]
        if self.logit_bias is not None:
            # BUGFIX: use out-of-place subtraction. When `classifier` is
            # None, `scores` aliases the caller's `pooled_data`, and the
            # original in-place `-=` silently mutated that input tensor.
            scores = scores - self.logit_bias
        if self.activation is not None and pooling_param.use_activation:
            scores = self.activation(scores)
        # scores shape: [n_token, num_labels]
        return scores