174 lines
5.3 KiB
Python
174 lines
5.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Mapping, Set
|
|
from itertools import groupby
|
|
|
|
import torch
|
|
|
|
from vllm.config import PoolerConfig
|
|
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
|
from vllm.tasks import PoolingTask
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
|
|
from .abstract import Pooler, PoolerOutput
|
|
from .common import ClassifierFn
|
|
from .seqwise import (
|
|
SequencePoolingFn,
|
|
SequencePoolingMethod,
|
|
pooler_for_classify,
|
|
pooler_for_embed,
|
|
)
|
|
from .tokwise import AllPool, pooler_for_token_classify, pooler_for_token_embed
|
|
|
|
|
|
class DispatchPooler(Pooler):
    """Route each pooling request to the sub-pooler registered for its task.

    The instance holds a mapping from pooling task to `Pooler`. Every entry
    is checked at construction time, and `forward` hands each run of
    consecutive same-task requests to the matching sub-pooler.
    """

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Build a dispatcher serving the "token_embed" and "embed" tasks."""
        token_embed_pooler = pooler_for_token_embed(pooler_config)
        embed_pooler = pooler_for_embed(pooler_config)

        return cls(
            {
                "token_embed": token_embed_pooler,
                "embed": embed_pooler,
            },
        )

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Build a dispatcher serving "token_classify", "classify" and "score".

        Args:
            pooler_config: Configuration forwarded to every sub-pooler factory.
            pooling: Sequence pooling forwarded to the "classify" and "score"
                poolers. The "token_classify" pooler always uses `AllPool()`.
            classifier: Classifier forwarded to all three sub-poolers.
        """
        token_classify_pooler = pooler_for_token_classify(
            pooler_config,
            pooling=AllPool(),
            classifier=classifier,
        )
        classify_pooler = pooler_for_classify(
            pooler_config,
            pooling=pooling,
            classifier=classifier,
            act_fn="classify",
        )
        score_pooler = pooler_for_classify(
            pooler_config,
            pooling=pooling,
            classifier=classifier,
            act_fn="score",
        )

        return cls(
            {
                "token_classify": token_classify_pooler,
                "classify": classify_pooler,
                "score": score_pooler,
            }
        )

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task-to-pooler mapping.

        Raises:
            ValueError: If a pooler is registered under a task that it does
                not list in its own `get_supported_tasks()`.
        """
        super().__init__()

        # NOTE: the loop names `pooler` and `task` appear verbatim in the
        # error text through the f-string `=` specifier.
        for task, pooler in poolers_by_task.items():
            supported = pooler.get_supported_tasks()
            if task in supported:
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {supported}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of tasks that have a registered sub-pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the sub-pooler registered for `task`."""
        delegate = self.poolers_by_task[task]
        return delegate.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each run of consecutive same-task requests with its sub-pooler.

        `hidden_states` is passed unchanged to every sub-pooler; only
        `pooling_metadata` is sliced down to the requests of the current run.

        Raises:
            ValueError: If a task in `pooling_metadata.tasks` has no
                (truthy) sub-pooler registered.
        """
        dispatch_table = self.poolers_by_task

        results: list[torch.Tensor | None] = []
        start = 0
        # groupby only merges adjacent equal tasks, so request order is kept.
        for task, run in groupby(pooling_metadata.tasks):
            pooler = dispatch_table.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task!r} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            stop = start + sum(1 for _ in run)
            run_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:stop],
            )

            results.extend(run_output)
            start = stop

        return results

    def extra_repr(self) -> str:
        """Return the supported tasks for display in the module repr."""
        return f"supported_task={self.get_supported_tasks()}"
|
|
|
|
|
|
class IdentityPooler(Pooler):
    """Pooler that hands back the hidden states without modifying them."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks served by this pooler: "plugin" and "score"."""
        supported: Set[PoolingTask] = {"plugin", "score"}
        return supported

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return `hidden_states` as-is; `pooling_metadata` is not used."""
        return hidden_states
|
|
|
|
|
|
class BOSEOSFilter(Pooler):
    """Wrap a pooler and drop the BOS / EOS token rows from its outputs."""

    def __init__(
        self,
        pooler: Pooler,
        bos_token_id: int = -1,  # -1 disables the filtering
        eos_token_id: int = -1,
    ) -> None:
        """Store the wrapped pooler and the token ids to filter out.

        Args:
            pooler: Pooler whose per-prompt outputs are filtered.
            bos_token_id: Id compared against the first prompt token.
            eos_token_id: Id compared against the last prompt token.
        """
        super().__init__()

        self.pooler = pooler
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the wrapped pooler."""
        inner = self.pooler
        return inner.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request token ids for every task.

        `forward` reads `pooling_metadata.prompt_token_ids`, so the update
        always sets `requires_token_ids=True`. The wrapped pooler is not
        consulted here.
        """
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run the wrapped pooler, then strip BOS / EOS rows per prompt.

        The wrapped pooler must return a list holding, for each prompt, a
        tensor whose first dimension equals that prompt's length. The list is
        updated in place and returned.
        """
        results = self.pooler(hidden_states, pooling_metadata)
        assert isinstance(results, list)

        bos_id = self.bos_token_id
        eos_id = self.eos_token_id

        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
            rows = results[i]
            assert (
                isinstance(rows, torch.Tensor)
                and rows.shape[0] == prompt_len
            )

            # Only the first `prompt_len` ids of row i are inspected.
            prompt_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]

            # Drop the leading row when the prompt starts with BOS ...
            if prompt_ids[0] == bos_id:
                rows = rows[1:]
            # ... and the trailing row when it ends with EOS.
            if prompt_ids[-1] == eos_id:
                rows = rows[:-1]

            results[i] = rows.squeeze(-1)

        return results
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    "BOSEOSFilter",
    "DispatchPooler",
    "IdentityPooler",
]
|