Files

94 lines
3.0 KiB
Python
Raw Permalink Normal View History

2026-04-09 11:23:47 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config.pooler import SequencePoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract base for poolers that reduce token hidden states per sequence.

    Concrete subclasses implement :meth:`forward`. The supported-task list and
    the pooling-params update defined here are shared defaults.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks a sequence pooling method can serve.

        A fresh ``set`` is built on every call.
        """
        tasks = ("token_embed", "token_classify", "embed", "classify", "score")
        return set(tasks)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-params adjustment required for ``task``.

        The default ignores ``task`` and returns an empty (default-constructed)
        ``PoolingParamsUpdate``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states`` into per-sequence outputs.

        Args:
            hidden_states: Hidden states of the scheduled tokens. The
                implementations in this module index its first dimension
                with token indices taken from the pooling cursor.
            pooling_metadata: Batch metadata exposing ``get_pooling_cursor()``.

        Raises:
            NotImplementedError: always, in the abstract base.
        """
        raise NotImplementedError
class CLSPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state at its first token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Gather ``hidden_states`` rows at the cursor's first-token indices.

        Raises:
            AssertionError: if the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        # Only complete prefills are handled. Presumably a chunk's first index
        # would not be the sequence's first (CLS) token otherwise.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )
        first_positions = cursor.first_token_indices_gpu
        return hidden_states[first_positions]
class LastPool(SequencePoolingMethod):
    """Pool each sequence by taking the hidden state at its last token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Gather ``hidden_states`` rows at the cursor's last-token indices.

        Unlike the CLS and MEAN poolers in this module, no partial-prefill
        check is performed here.
        """
        last_positions = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
        return hidden_states[last_positions]
class MeanPool(SequencePoolingMethod):
    """Pool each sequence by averaging its tokens' hidden states."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Average ``hidden_states`` over each sequence's token span.

        The span of a sequence runs from its first-token index to its
        last-token index (both inclusive). The sum over that span is divided
        by the cursor's prompt length for that sequence.

        Raises:
            AssertionError: if the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        # Start the host-to-device copy of the lengths early (non-blocking).
        lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

        # Accumulate in float32 so the running sum does not lose significant
        # precision (relevant when hidden_states is a lower-precision dtype).
        running = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu
        # running[last] - running[first] excludes the first token itself, so
        # add its hidden state back to obtain the inclusive span sum.
        span_sums = running[last] - running[first] + hidden_states[first]
        return span_sums / lengths.unsqueeze(1)
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
if pooling_type == "CLS":
return CLSPool()
if pooling_type == "LAST":
return LastPool()
if pooling_type == "MEAN":
return MeanPool()
raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")