# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union

import torch

from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import (Pooler, PoolerActivation,
                                               get_pooling_params)


class ClassifierPooler(Pooler):
    """Pooler that runs an optional classification head over pooled states.

    Pipeline: pool hidden states -> cast to the head dtype -> apply the
    classifier head (if any) -> subtract the logit bias (if any) -> apply
    the per-request activation function where requested.
    """

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled, list):
            pooled = torch.stack(pooled)
        # pooled: [batch_size, hidden_size]

        # Match the classification head's expected dtype before projecting.
        if pooled.dtype != self.head_dtype:
            pooled = pooled.to(self.head_dtype)

        if self.classifier is not None:
            pooled = self.classifier(pooled)
        # pooled: [batch_size, num_labels]

        if self.logit_bias is not None:
            pooled -= self.logit_bias

        # Each request may independently ask for the activation function.
        wants_act = [p.activation for p in get_pooling_params(pooling_metadata)]

        if len(set(wants_act)) == 1:
            # Homogeneous batch: apply (or skip) the activation in one shot.
            scores = self.act_fn(pooled) if wants_act[0] else pooled
        else:
            # Mixed batch: decide row by row.
            scores = [
                self.act_fn(row) if want else row
                for row, want in zip(pooled, wants_act)
            ]

        # scores: [batch_size, num_labels]
        return scores


class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes each pooled vector.

    NOTE(review): `torch.vacc.l2_norm` appears to be an op registered by an
    out-of-tree accelerator backend (not core PyTorch) — confirm the plugin
    is loaded before this path is exercised.
    """

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # epsilon guards against division by zero for all-zero vectors.
        return torch.vacc.l2_norm(pooled_data, epsilon=1e-12)