Add minimal vLLM 0.16.1 build repo for BI-V150
133
vllm/model_executor/layers/pooler/tokwise/heads.py
Normal file
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias

import torch
import torch.nn as nn

from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .methods import TokenPoolingMethodOutputItem

TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None


class TokenPoolerHead(nn.Module, ABC):
    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        # Apply the head to each request's chunk; params are parallel to pooled_data.
        return [self.forward_chunk(d, p) for d, p in zip(pooled_data, pooling_params)]


class TokenEmbeddingPoolerHead(TokenPoolerHead):
    def __init__(
        self,
        head_dtype: torch.dtype | str | None = None,
        projector: ProjectorFn | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        self.head_dtype = head_dtype
        self.projector = projector
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # No output yet for an unfinished chunked prefill
        if pooled_data is None:
            return None

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]

        # Apply the Sentence Transformers (ST) projector, if any
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]

        # Truncate for Matryoshka representations
        pooled_data = pooled_data[..., : pooling_param.dimensions]

        # Normalize (or otherwise activate) when requested
        if self.activation is not None and pooling_param.use_activation:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data


class TokenClassifierPoolerHead(TokenPoolerHead):
    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # No output yet for an unfinished chunked prefill
        if pooled_data is None:
            return None

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_tokens, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        if self.activation is not None and pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_tokens, num_labels]
        return scores
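For readers of this diff, a minimal usage sketch (not part of the commit): the heads are plain nn.Modules, and forward() only reads pooling_params from the metadata, so lightweight stand-ins suffice to exercise them. DummyParams and DummyMetadata below are hypothetical illustrations, not vLLM classes, and the import assumes this build is installed.

# Hypothetical usage sketch; DummyParams/DummyMetadata are stand-ins for
# PoolingParams/PoolingMetadata and are NOT part of vLLM.
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.pooler.tokwise.heads import (
    TokenClassifierPoolerHead,
    TokenEmbeddingPoolerHead,
)


class DummyParams:
    # Only the two PoolingParams fields that forward_chunk() reads.
    def __init__(self, dimensions=None, use_activation=True):
        self.dimensions = dimensions
        self.use_activation = use_activation


class DummyMetadata:
    # Only the PoolingMetadata field that forward() reads.
    def __init__(self, pooling_params):
        self.pooling_params = pooling_params


hidden = torch.randn(7, 64)  # [n_tokens, hidden_dimension]

# Token embeddings: truncate to 32 dims (Matryoshka) and L2-normalize each token.
embed_head = TokenEmbeddingPoolerHead(
    head_dtype=torch.float32,
    activation=lambda x: F.normalize(x, p=2, dim=-1),
)
out = embed_head([hidden], DummyMetadata([DummyParams(dimensions=32)]))
print(out[0].shape)  # torch.Size([7, 32])

# Per-token classification over 3 labels with sigmoid scores.
clf_head = TokenClassifierPoolerHead(
    classifier=torch.nn.Linear(64, 3),
    activation=torch.sigmoid,
)
scores = clf_head([hidden], DummyMetadata([DummyParams()]))
print(scores[0].shape)  # torch.Size([7, 3])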