update
This commit is contained in:
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.pooler import SequencePoolingType
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
# Return type used by the sequence pooling methods' ``forward``: either a
# single tensor or a list of tensors.
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract base class for sequence-level pooling modules.

    Concrete subclasses must implement :meth:`forward`. The other methods
    provide defaults that subclasses are free to override.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling task names this method supports."""
        supported_tasks = {
            "token_embed",
            "token_classify",
            "embed",
            "classify",
            "score",
        }
        return supported_tasks

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default-constructed ``PoolingParamsUpdate``.

        This base implementation does not consult ``task``.
        """
        default_update = PoolingParamsUpdate()
        return default_update

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states`` using ``pooling_metadata``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(SequencePoolingMethod):
    """Pool by picking the hidden state at each first-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return the rows of ``hidden_states`` at the first-token indices.

        The indices come from the pooling cursor of ``pooling_metadata``.

        Raises:
            AssertionError: If the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        # The check stays inside the assert so that it is skipped entirely
        # when assertions are disabled, exactly as before.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_rows = cursor.first_token_indices_gpu
        return hidden_states[first_rows]
|
||||
|
||||
|
||||
class LastPool(SequencePoolingMethod):
    """Pool by picking the hidden state at each last-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return the rows of ``hidden_states`` at the last-token indices.

        The indices come from the pooling cursor of ``pooling_metadata``.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        last_rows = cursor.last_token_indices_gpu
        return hidden_states[last_rows]
|
||||
|
||||
|
||||
class MeanPool(SequencePoolingMethod):
    """Pool by averaging the hidden states between first and last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return per-sequence sums of ``hidden_states`` divided by prompt length.

        Sums are taken over the rows from each first-token index to the
        matching last-token index (inclusive), using a float32 prefix sum
        along dim 0.

        Raises:
            AssertionError: If the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

        # Accumulate in float32: a prefix sum in a lower-precision dtype
        # would lose precision significantly.
        running_total = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first_rows = cursor.first_token_indices_gpu
        last_rows = cursor.last_token_indices_gpu

        # The prefix sum is inclusive, so total[last] - total[first] covers
        # rows (first, last]; adding the first row back gives [first, last].
        # NOTE(review): presumably last - first + 1 equals the prompt length
        # for every sequence when there is no partial prefill -- confirm.
        segment_sums = (
            running_total[last_rows] - running_total[first_rows] + hidden_states[first_rows]
        )
        return segment_sums / lengths.unsqueeze(1)
|
||||
|
||||
|
||||
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
|
||||
if pooling_type == "CLS":
|
||||
return CLSPool()
|
||||
if pooling_type == "LAST":
|
||||
return LastPool()
|
||||
if pooling_type == "MEAN":
|
||||
return MeanPool()
|
||||
|
||||
raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")
|
||||
Reference in New Issue
Block a user