Files
2026-04-02 04:55:00 +00:00

51 lines
1.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolerActivation, get_pooling_params
class ClassifierPooler(Pooler):
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool hidden states, apply the classification head, and run the
        per-request activation.

        Returns one [batchsize, num_labels] tensor when every request in the
        batch shares the same activation flag; otherwise a list holding one
        [num_labels] tensor per request.
        """
        pooled = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled, list):
            pooled = torch.stack(pooled)
        # pooled shape: [batchsize, hidden_size]

        # Cast to the head's dtype only when needed to avoid a no-op copy.
        if pooled.dtype != self.head_dtype:
            pooled = pooled.to(self.head_dtype)

        if self.classifier is not None:
            pooled = self.classifier(pooled)
        # pooled shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled -= self.logit_bias

        act_flags = [p.activation for p in get_pooling_params(pooling_metadata)]
        if len(set(act_flags)) != 1:
            # Mixed batch: activate each request's logits individually.
            return [
                self.act_fn(row) if wants_act else row
                for row, wants_act in zip(pooled, act_flags)
            ]
        # Homogeneous batch: a single fused activation (or none) suffices.
        # scores shape: [batchsize, num_labels]
        return self.act_fn(pooled) if act_flags[0] else pooled
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes pooled embeddings chunk by chunk."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # NOTE(review): torch.vacc.l2_norm is a device-plugin extension, not
        # stock PyTorch; presumably it matches F.normalize(p=2, eps=1e-12)
        # semantics — confirm against the vacc plugin docs.
        norm_eps = 1e-12  # guards against division by zero on all-zero rows
        return torch.vacc.l2_norm(pooled_data, epsilon=norm_eps)