94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Set
|
|
from typing import TypeAlias
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.config.pooler import SequencePoolingType
|
|
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
|
from vllm.tasks import PoolingTask
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
|
|
# Return type of SequencePoolingMethod.forward: either a single tensor or a
# list of tensors. The concrete pooling methods in this module return a single
# tensor produced by indexing/arithmetic on the hidden states.
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
|
|
|
|
|
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract base for strategies that reduce hidden states to one pooled
    representation per sequence.

    Subclasses must implement :meth:`forward`. The defaults defined here
    advertise a fixed set of pooling tasks and request no pooling-parameter
    adjustments.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling task names this method can serve.

        A fresh set is built on every call, so callers may mutate the result.
        """
        tasks = {"token_embed", "token_classify"}
        tasks.update(("embed", "classify", "score"))
        return tasks

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return pooling-parameter adjustments required for ``task``.

        The base implementation ignores ``task`` and always returns a
        default-constructed :class:`PoolingParamsUpdate`.
        """
        no_changes = PoolingParamsUpdate()
        return no_changes

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states`` using the batch layout in ``pooling_metadata``.

        Must be overridden by concrete subclasses.
        """
        raise NotImplementedError
|
|
|
|
|
|
class CLSPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state of its first token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Gather the first-token hidden state of every sequence in the batch.

        Raises:
            AssertionError: if the pooling cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        # Partial prefill is rejected outright; presumably the first token of a
        # sequence is not guaranteed to be in this step's hidden states then.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_rows = cursor.first_token_indices_gpu
        return hidden_states[first_rows]
|
|
|
|
|
|
class LastPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state of its last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Gather the last-token hidden state of every sequence in the batch.

        No partial-prefill check is performed here.
        """
        last_rows = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
        return hidden_states[last_rows]
|
|
|
|
|
|
class MeanPool(SequencePoolingMethod):
    """Pool each sequence by averaging the hidden states of its tokens."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return the per-sequence mean of ``hidden_states``.

        Sequence sums come from one float32 prefix sum along dim 0.

        - For the inclusive row range ``[first, last]``, the sum is
          ``prefix[last] - prefix[first] + hidden_states[first]``.
        - Each sum is divided by the sequence's prompt length.
        - The result dtype follows PyTorch type promotion of the float32
          prefix sums with ``hidden_states``' dtype. It is float32 for
          half/bfloat16 inputs.

        Assumes ``hidden_states`` is 2-D (tokens, hidden), which the
        ``unsqueeze(1)`` broadcast below suggests (TODO confirm). Also
        assumes each prompt length equals ``last - first + 1``.

        Raises:
            AssertionError: if the pooling cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

        # Accumulate in float32: a running sum in the (possibly lower
        # precision) input dtype would lose significant accuracy.
        prefix = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu

        # prefix[last] - prefix[first] excludes row `first`; add it back to
        # obtain the inclusive range sum.
        totals = prefix[last] - prefix[first] + hidden_states[first]
        return totals / lengths.unsqueeze(1)
|
|
|
|
|
|
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
|
|
if pooling_type == "CLS":
|
|
return CLSPool()
|
|
if pooling_type == "LAST":
|
|
return LastPool()
|
|
if pooling_type == "MEAN":
|
|
return MeanPool()
|
|
|
|
raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")
|