# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Union
|
|
|
|
import torch
|
|
|
|
from vllm.v1.outputs import PoolerOutput
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params
|
|
|
|
|
|
class ClassifierPooler(Pooler):
    """Pooler that turns hidden states into per-sequence classification scores.

    Pipeline: pool hidden states, cast to the head dtype, run the optional
    classifier head, subtract the optional logit bias, then apply the
    activation function according to each request's pooling parameters.
    """

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Compute classification scores for every sequence in the batch.

        Args:
            hidden_states: Model hidden states, either a single tensor or
                one tensor per sequence.
            pooling_metadata: Per-request pooling configuration.

        Returns:
            Scores of shape [batchsize, num_labels] (a single tensor when
            all requests share the same activation flag, otherwise a list
            of per-sequence tensors).
        """
        # Pool per-sequence representations; a list result is stacked into
        # a single [batchsize, hidden_size] tensor.
        pooled = self.pooling(hidden_states, pooling_metadata)
        pooled = torch.stack(pooled) if isinstance(pooled, list) else pooled

        # Cast to the classifier head's dtype only when it differs.
        if pooled.dtype != self.head_dtype:
            pooled = pooled.to(self.head_dtype)

        # Optional head: [batchsize, hidden_size] -> [batchsize, num_labels].
        if self.classifier is not None:
            pooled = self.classifier(pooled)

        if self.logit_bias is not None:
            pooled -= self.logit_bias

        # Each request carries a flag saying whether to apply the activation.
        activation_flags = [
            params.activation for params in get_pooling_params(pooling_metadata)
        ]

        if len(set(activation_flags)) == 1:
            # Homogeneous batch: one vectorized call (or none at all).
            return self.act_fn(pooled) if activation_flags[0] else pooled

        # Mixed batch: apply the activation row by row.
        # scores shape: [batchsize, num_labels]
        return [
            self.act_fn(row) if apply_act else row
            for row, apply_act in zip(pooled, activation_flags)
        ]
|
|
|
|
|
|
class PoolerNormalize(PoolerActivation):
    """Pooling activation that L2-normalizes pooled vectors."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return *pooled_data* L2-normalized along the last dimension.

        FIX: the previous implementation called ``torch.vacc.l2_norm``,
        which is not part of PyTorch's public API (there is no ``torch.vacc``
        namespace in stock PyTorch — if it came from a vendor plugin,
        confirm before merging). The standard equivalent is
        ``torch.nn.functional.normalize`` with p=2; ``eps=1e-12`` matches
        the original ``epsilon`` and PyTorch's default, guarding against
        division by zero for all-zero vectors.
        """
        return torch.nn.functional.normalize(pooled_data, p=2.0, dim=-1, eps=1e-12)