This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract import *
from .common import *
from .special import *

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
import torch
import torch.nn as nn
from vllm.tasks import PoolingTask
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .common import PoolingParamsUpdate
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        # Default: the pooler requests no parameter changes for the task.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the flattened batch of hidden states into per-request outputs."""
        raise NotImplementedError
__all__ = ["Pooler"]

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeVar
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_classification_act_fn(
config: PretrainedConfig,
) -> "PoolerActivation":
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type = getattr(config, "problem_type", "")
if problem_type == "regression":
return PoolerIdentity()
if problem_type == "single_label_classification":
return PoolerClassify()
if problem_type == "multi_label_classification":
return PoolerMultiLabelClassify()
return PoolerClassify()
def get_cross_encoder_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Resolve the activation for cross-encoder (scoring) heads from config."""
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    # No activation configured: default to the classification activation.
    if function_name is None:
        return PoolerClassify()

    # Only allow loading activations by qualified name from torch.nn.modules.
    assert function_name.startswith("torch.nn.modules."), (
        "Loading of activation functions is restricted to "
        "torch.nn.modules for security reasons"
    )
    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
def resolve_classifier_act_fn(
model_config: ModelConfig,
static_num_labels: bool = True,
act_fn: "PoolerActivation | str | None" = None,
):
if isinstance(act_fn, str):
if act_fn == "classify":
return get_classification_act_fn(model_config.hf_config)
if act_fn == "score":
return get_cross_encoder_act_fn(model_config.hf_config)
raise ValueError(f"act_fn [{act_fn=}] not supported.")
if act_fn is None:
return PoolerClassify(static_num_labels=static_num_labels)
assert callable(act_fn)
return act_fn
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
class PoolerActivation(nn.Module, ABC):
    """Activation applied to pooled outputs, batched or as a per-request list."""

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt a plain `nn.Module` activation into a `PoolerActivation`."""
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor chunk."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)
        return [self.forward_chunk(chunk) for chunk in pooled_data]
class PoolerIdentity(PoolerActivation):
    """No-op activation: returns the pooled data unchanged (regression heads)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data
class PoolerNormalize(PoolerActivation):
    """L2-normalizes the pooled data along the last dimension (embeddings)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)
class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid, for multi-label classification heads."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation):
    """Sigmoid/softmax activation for single-label classification heads.

    With `static_num_labels=True`, the label count is read once from the
    current vLLM model config at construction time; otherwise it is inferred
    from the last dimension of each input at call time.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            model_config = vllm_config.model_config
            # Falls back to 0 when the HF config carries no num_labels.
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
        else:
            num_labels = None

        if num_labels == 0:
            # NOTE(review): with num_labels == 0, forward_chunk below takes the
            # sigmoid branch (0 < 2), not softmax as this message claims —
            # confirm which fallback is intended.
            logger.warning(
                "num_labels should be > 0 for classification "
                "models, falling back to softmax. "
                "Please check if the configuration is correct."
            )

        # None means: infer the label count from the input shape per call.
        self.num_labels = num_labels

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = pooled_data.shape[-1]

        # Single-logit (binary) heads use sigmoid; multi-class heads use softmax.
        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
    """Wraps an arbitrary tensor-to-tensor callable as a `PoolerActivation`."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        # The wrapped callable, applied to each chunk.
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return self.fn(pooled_data)

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
import torch
from vllm.pooling_params import PoolingParams
_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
ActivationFn = Callable[[_T], _T]
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Per-task adjustments that a pooler requests on the pooling parameters."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
        """Merge two updates; a flag is set if either side sets it."""
        combined = self.requires_token_ids or other.requires_token_ids
        return PoolingParamsUpdate(requires_token_ids=combined)

    def apply(self, params: PoolingParams) -> None:
        """Write this update onto `params` in place."""
        params.requires_token_ids = self.requires_token_ids
__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output aggregating all tokens in the sequence."""
from .heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
SequencePoolerHead,
SequencePoolerHeadOutput,
)
from .methods import (
CLSPool,
LastPool,
MeanPool,
SequencePoolingMethod,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from .poolers import (
SequencePooler,
SequencePoolerOutput,
SequencePoolingFn,
SequencePoolingHeadFn,
pooler_for_classify,
pooler_for_embed,
)
__all__ = [
"SequencePoolerHead",
"SequencePoolerHeadOutput",
"ClassifierPoolerHead",
"EmbeddingPoolerHead",
"SequencePoolingMethod",
"SequencePoolingMethodOutput",
"CLSPool",
"LastPool",
"MeanPool",
"get_seq_pooling_method",
"SequencePooler",
"SequencePoolingFn",
"SequencePoolingHeadFn",
"SequencePoolerOutput",
"pooler_for_classify",
"pooler_for_embed",
]

View File

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .methods import SequencePoolingMethodOutput
SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
class SequencePoolerHead(nn.Module, ABC):
    """Post-processes sequence-pooled data into the final pooler output."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Transform the pooled batch according to the per-request parameters."""
        raise NotImplementedError
class EmbeddingPoolerHead(SequencePoolerHead):
    """Head for the `embed` task: optional dtype cast and projection,
    Matryoshka dimension truncation, and per-request normalization."""

    def __init__(
        self,
        projector: ProjectorFn | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        # Optional sentence-transformers (ST) projector applied after pooling.
        self.projector = projector
        # Optional dtype the pooled data is cast to before the head runs.
        self.head_dtype = head_dtype
        # Activation (typically normalization); gated per request below.
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        # One pooled row per request is expected.
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed dimensions: truncate each request's vector separately.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        if self.activation is not None:
            flags = [p.use_activation for p in pooling_params]
            if len(set(flags)) == 1:
                # Homogeneous batch: apply (or skip) the activation in one shot.
                if flags[0]:
                    pooled_data = self.activation(pooled_data)
            else:
                pooled_data = [
                    self.activation(vecs) if f else vecs
                    for vecs, f in zip(pooled_data, flags)
                ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
class ClassifierPoolerHead(SequencePoolerHead):
    """Head for `classify`/`score`: dtype cast, classifier, logit bias,
    and per-request activation."""

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()
        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        if isinstance(pooled_data, list):
            scores = torch.stack(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [batchsize, hidden_size]

        if self.head_dtype is not None:
            scores = scores.to(self.head_dtype)

        if self.classifier is not None:
            scores = self.classifier(scores)
        # scores shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        if self.activation is not None:
            flags = [p.use_activation for p in params]
            if len(set(flags)) == 1:
                # Homogeneous batch: apply (or skip) in one shot.
                if flags[0]:
                    scores = self.activation(scores)
            else:
                # Mixed batch: apply the activation per request.
                scores = [
                    self.activation(row) if flag else row
                    for row, flag in zip(scores, flags)
                ]

        # scores shape: [batchsize, num_labels]
        return scores

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config.pooler import SequencePoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
class SequencePoolingMethod(nn.Module, ABC):
    """Aggregates the flattened batch of hidden states into one row per sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Default: no parameter changes required for any task.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return the pooled representation(s), one per request."""
        raise NotImplementedError
class CLSPool(SequencePoolingMethod):
    """Pools the hidden state of each sequence's first (CLS) token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The first token is absent from later chunks of a partial prefill.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        return hidden_states[pooling_cursor.first_token_indices_gpu]
class LastPool(SequencePoolingMethod):
    """Pools the hidden state of each sequence's last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The last token exists even for partial prefills, so no guard here.
        return hidden_states[pooling_cursor.last_token_indices_gpu]
class MeanPool(SequencePoolingMethod):
    """Pools the mean of all token hidden states per sequence."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The mean needs every token of the prompt at once.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu
        # Inclusive-range sum via the inclusive cumsum:
        # sum[start..end] = cumsum[end] - cumsum[start] + row[start].
        # NOTE(review): the result dtype is float32 (from the cumsum), not the
        # input dtype — confirm downstream heads expect this.
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
if pooling_type == "CLS":
return CLSPool()
if pooling_type == "LAST":
return LastPool()
if pooling_type == "MEAN":
return MeanPool()
raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Set
from typing import TypeAlias
import torch
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
SequencePoolerHead,
SequencePoolerHeadOutput,
)
from .methods import (
SequencePoolingMethod,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
SequencePoolingFn: TypeAlias = Callable[
[torch.Tensor, PoolingMetadata],
SequencePoolingMethodOutput,
]
SequencePoolingHeadFn: TypeAlias = Callable[
[SequencePoolingMethodOutput, PoolingMetadata],
SequencePoolerHeadOutput,
]
SequencePoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
class SequencePooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        super().__init__()
        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every task; plain callables constrain nothing.
        supported = set(POOLING_TASKS)
        if isinstance(self.pooling, SequencePoolingMethod):
            supported = supported & self.pooling.get_supported_tasks()
        if isinstance(self.head, SequencePoolerHead):
            supported = supported & self.head.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        if isinstance(self.pooling, SequencePoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)
        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        per_request = self.pooling(hidden_states, pooling_metadata)
        return self.head(per_request, pooling_metadata)
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a `SequencePooler` configured for the `embed` task."""
    pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

    model_config = get_current_vllm_config().model_config
    return SequencePooler(
        pooling=pooling,
        head=EmbeddingPoolerHead(
            head_dtype=model_config.head_dtype,
            projector=_load_st_projector(model_config),
            activation=PoolerNormalize(),
        ),
    )
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a `SequencePooler` configured for the `classify`/`score` tasks."""
    if pooling is None:
        pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

    model_config = get_current_vllm_config().model_config
    activation = resolve_classifier_act_fn(
        model_config, static_num_labels=True, act_fn=act_fn
    )
    return SequencePooler(
        pooling=pooling,
        head=ClassifierPoolerHead(
            head_dtype=model_config.head_dtype,
            classifier=classifier,
            logit_bias=model_config.pooler_config.logit_bias,
            activation=activation,
        ),
    )

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping, Set
from itertools import groupby
import torch
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .abstract import Pooler, PoolerOutput
from .common import ClassifierFn
from .seqwise import (
SequencePoolingFn,
SequencePoolingMethod,
pooler_for_classify,
pooler_for_embed,
)
from .tokwise import AllPool, pooler_for_token_classify, pooler_for_token_embed
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Construct a dispatcher for embedding models (token + sequence embed)."""
        return cls(
            {
                "token_embed": pooler_for_token_embed(pooler_config),
                "embed": pooler_for_embed(pooler_config),
            },
        )

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Construct a dispatcher for sequence-classification models."""
        return cls(
            {
                "token_classify": pooler_for_token_classify(
                    pooler_config,
                    # token_classify always scores every token.
                    pooling=AllPool(),
                    classifier=classifier,
                ),
                "classify": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="classify",
                ),
                "score": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="score",
                ),
            }
        )

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Validate up front that each sub-pooler supports its assigned task.
        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Delegate to the sub-pooler responsible for this task.
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor | None]()
        offset = 0
        # Dispatch each contiguous run of same-task requests to its sub-pooler.
        for task, group in groupby(pooling_metadata.tasks):
            # NOTE(review): truthiness test — a present-but-falsy pooler would
            # also be treated as missing; `is None` would be stricter. Confirm.
            if not (pooler := poolers_by_task.get(task)):
                raise ValueError(
                    f"Unsupported task: {task!r} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s
class IdentityPooler(Pooler):
    """Returns the hidden states unchanged (used for `plugin`/`score` tasks)."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        return hidden_states
class BOSEOSFilter(Pooler):
    """Filters the BOS and EOS token results from outputs."""

    def __init__(
        self,
        pooler: Pooler,
        bos_token_id: int = -1,  # -1 disables the filtering
        eos_token_id: int = -1,
    ) -> None:
        super().__init__()
        # Wrapped pooler whose per-token outputs are filtered.
        self.pooler = pooler
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooler.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Token ids are needed to detect the BOS/EOS positions below.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
        assert isinstance(pooled_outputs, list)

        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
            pooled_data = pooled_outputs[i]
            # Expects one output row per prompt token (token-wise pooler).
            assert (
                isinstance(pooled_data, torch.Tensor)
                and pooled_data.shape[0] == prompt_len
            )
            token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]

            # Drop the first row if it is BOS, the last if it is EOS.
            if token_ids[0] == self.bos_token_id:
                pooled_data = pooled_data[1:]
            if token_ids[-1] == self.eos_token_id:
                pooled_data = pooled_data[:-1]

            # NOTE(review): squeeze(-1) collapses a trailing singleton dim
            # (e.g. single-label scores) unconditionally — confirm callers
            # expect this for multi-column outputs too.
            pooled_outputs[i] = pooled_data.squeeze(-1)

        return pooled_outputs
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output for each token in the sequence."""
from .heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
TokenPoolerHead,
TokenPoolerHeadOutputItem,
)
from .methods import (
AllPool,
StepPool,
TokenPoolingMethod,
TokenPoolingMethodOutputItem,
get_tok_pooling_method,
)
from .poolers import (
TokenPooler,
TokenPoolerOutput,
pooler_for_token_classify,
pooler_for_token_embed,
)
__all__ = [
"TokenPoolerHead",
"TokenPoolerHeadOutputItem",
"TokenClassifierPoolerHead",
"TokenEmbeddingPoolerHead",
"TokenPoolingMethod",
"TokenPoolingMethodOutputItem",
"AllPool",
"StepPool",
"get_tok_pooling_method",
"TokenPooler",
"TokenPoolerOutput",
"pooler_for_token_classify",
"pooler_for_token_embed",
]

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .methods import TokenPoolingMethodOutputItem
TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None
class TokenPoolerHead(nn.Module, ABC):
    """Post-processes per-token pooled data, one request at a time."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Process one request's token-wise data (`None` = unfinished prefill)."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        # Apply the head independently to each request.
        return [self.forward_chunk(d, p) for d, p in zip(pooled_data, pooling_params)]
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Per-token `token_embed` head: dtype cast, projection, Matryoshka
    truncation, and normalization."""

    def __init__(
        self,
        head_dtype: torch.dtype | str | None = None,
        projector: ProjectorFn | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()
        self.head_dtype = head_dtype
        self.projector = projector
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # None marks a request whose chunked prefill has not finished yet.
        if pooled_data is None:
            return None

        embeddings = pooled_data
        if self.head_dtype is not None:
            embeddings = embeddings.to(self.head_dtype)
        # embeddings shape: [n_tokens, hidden_dimension]

        # Apply the sentence-transformers projector, if configured.
        if self.projector is not None:
            embeddings = self.projector(embeddings)
        # embeddings shape: [n_tokens, embedding_dimension]

        # Matryoshka truncation (slicing with None keeps every dimension).
        embeddings = embeddings[..., : pooling_param.dimensions]

        # Normalization, gated per request.
        if self.activation is not None and pooling_param.use_activation:
            embeddings = self.activation(embeddings)

        # embeddings shape: [n_tokens, embedding_dimension]
        return embeddings
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Per-token `token_classify` head: dtype cast, classifier, logit bias,
    and activation."""

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()
        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # None marks a request whose chunked prefill has not finished yet.
        if pooled_data is None:
            return None

        hidden = pooled_data
        if self.head_dtype is not None:
            hidden = hidden.to(self.head_dtype)
        # hidden shape: [n_token, hidden_size]

        scores = self.classifier(hidden) if self.classifier is not None else hidden
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        # Activation, gated per request.
        if self.activation is not None and pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.config.pooler import TokenPoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
class TokenPoolingMethod(nn.Module, ABC):
    """Extracts per-token outputs from the flattened batch of hidden states."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Default: no parameter changes required.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Return one tensor per request (`None` when prefill is unfinished)."""
        raise NotImplementedError
class AllPool(TokenPoolingMethod):
    """Returns the hidden states of every token in each sequence.

    With chunked prefill enabled, partial chunks are cached per request and
    only emitted (concatenated) once that request's prefill has finished;
    unfinished requests yield `None`.
    """

    def __init__(self):
        super().__init__()

        vllm_config = get_current_vllm_config()
        scheduler_config = vllm_config.scheduler_config
        # Decides whether the caching path in forward() is needed at all.
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        pooling_cursor = pooling_metadata.get_pooling_cursor()

        # Split the flattened batch into per-request chunks, then reorder
        # the chunks to match the cursor's request order.
        hidden_states_all = hidden_states.split(
            pooling_cursor.num_scheduled_tokens_cpu.tolist()
        )
        hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]

        if not self.enable_chunked_prefill:
            return hidden_states_lst

        pooling_states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
            p.hidden_states_cache.append(hs_chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        output_list = list[TokenPoolingMethodOutputItem]()
        for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
            if finished:
                hidden_states_cache = p.hidden_states_cache
                if len(hidden_states_cache) == 1:
                    # Avoid a needless concat for the single-chunk case.
                    output_list.append(hidden_states_cache[0])
                else:
                    output_list.append(torch.concat(hidden_states_cache, dim=0))
                p.clean()
            else:
                # Prefill still in progress: no output for this request yet.
                output_list.append(None)

        return output_list
class StepPool(AllPool):
    """Selects step-tag token rows (and requested logit columns) from the
    all-token outputs produced by `AllPool`.

    Used by e.g. process-reward models that score only designated "step"
    positions in the prompt.
    """

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # The prompt token ids are needed to locate step-tag positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        pooled_data_lst = super().forward(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        pooled_data = list[torch.Tensor | None]()
        for data, token_id, pooling_param in zip(
            pooled_data_lst, prompt_token_ids, pooling_params
        ):
            # None marks an unfinished chunked prefill; pass it through as-is.
            # (Positive guard replaces the original `if data is None: pass /
            # else:` no-op branch; behavior is unchanged.)
            if data is not None:
                returned_token_ids = pooling_param.returned_token_ids
                if returned_token_ids is not None and len(returned_token_ids) > 0:
                    # Keep only the logit columns the caller asked for.
                    data = data[:, returned_token_ids]

                step_tag_id = pooling_param.step_tag_id
                if step_tag_id is not None:
                    # Keep only the rows at step-tag token positions.
                    data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
if pooling_type == "ALL":
return AllPool()
if pooling_type == "STEP":
return StepPool()
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Set
from typing import TypeAlias
import torch
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import (
ClassifierFn,
PoolingParamsUpdate,
ProjectorFn,
)
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
TokenPoolerHead,
TokenPoolerHeadOutputItem,
)
from .methods import (
TokenPoolingMethod,
TokenPoolingMethodOutputItem,
get_tok_pooling_method,
)
TokenPoolingFn: TypeAlias = Callable[
[torch.Tensor, PoolingMetadata],
list[TokenPoolingMethodOutputItem],
]
TokenPoolingHeadFn: TypeAlias = Callable[
[list[TokenPoolingMethodOutputItem], PoolingMetadata],
list[TokenPoolerHeadOutputItem],
]
TokenPoolerOutput: TypeAlias = list[torch.Tensor | None]
class TokenPooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: TokenPoolingMethod | TokenPoolingFn,
        head: TokenPoolerHead | TokenPoolingHeadFn,
    ) -> None:
        super().__init__()
        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every task; plain callables constrain nothing.
        supported = set(POOLING_TASKS)
        if isinstance(self.pooling, TokenPoolingMethod):
            supported = supported & self.pooling.get_supported_tasks()
        if isinstance(self.head, TokenPoolerHead):
            supported = supported & self.head.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        if isinstance(self.pooling, TokenPoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)
        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        per_token = self.pooling(hidden_states, pooling_metadata)
        return self.head(per_token, pooling_metadata)
def pooler_for_token_embed(
    pooler_config: PoolerConfig, projector: ProjectorFn | None = None
) -> TokenPooler:
    """Build a `TokenPooler` configured for the `token_embed` task."""
    pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

    model_config = get_current_vllm_config().model_config
    if projector is None:
        # Fall back to the sentence-transformers Dense projector, if any.
        projector = _load_st_projector(model_config)

    head = TokenEmbeddingPoolerHead(
        head_dtype=model_config.head_dtype,
        projector=projector,
        activation=PoolerNormalize(),
    )
    return TokenPooler(pooling=pooling, head=head)
def pooler_for_token_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: TokenPoolingMethod | TokenPoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a `TokenPooler` configured for the `token_classify` task."""
    if pooling is None:
        pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

    model_config = get_current_vllm_config().model_config
    activation = resolve_classifier_act_fn(
        model_config, static_num_labels=False, act_fn=act_fn
    )
    head = TokenClassifierPoolerHead(
        head_dtype=model_config.head_dtype,
        classifier=classifier,
        logit_bias=model_config.pooler_config.logit_bias,
        activation=activation,
    )
    return TokenPooler(pooling=pooling, head=head)