123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Set
|
|
from typing import TypeAlias
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.config.pooler import TokenPoolingType
|
|
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
|
from vllm.tasks import PoolingTask
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
|
|
# One output entry per pooled request: the request's token-level hidden
# states, or `None` when the request's chunked prefill has not finished
# yet (see AllPool.forward, which emits None for unfinished requests).
TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
|
|
|
|
|
class TokenPoolingMethod(nn.Module, ABC):
    """Base class for poolers that produce one output per *token*.

    Subclasses implement :meth:`forward`, which maps the flattened batch of
    hidden states to a per-request list of token-level outputs.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Token-wise poolers serve only the token-level pooling tasks."""
        return {"token_classify", "token_embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """No extra pooling-parameter requirements by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Pool *hidden_states* into one entry per request.

        A ``None`` entry marks a request whose output is not ready yet
        (e.g. an unfinished chunked prefill).
        """
        raise NotImplementedError
|
|
|
|
|
|
class AllPool(TokenPoolingMethod):
    """Return every token's hidden state for each request, unpooled."""

    def __init__(self):
        super().__init__()

        # With chunked prefill, one request's hidden states arrive across
        # several forward passes and must be accumulated before returning.
        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Split the flattened batch into per-request token tensors.

        Without chunked prefill every entry is a complete tensor; with it,
        unfinished requests yield ``None`` until their last chunk arrives.
        """
        cursor = pooling_metadata.get_pooling_cursor()

        # Slice the flat batch into per-request chunks, then reorder the
        # chunks into the cursor's request order.
        chunks = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        ordered_chunks = [chunks[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return ordered_chunks

        states = pooling_metadata.pooling_states

        # Chunked prefill, step 1: stash each request's newest chunk in its
        # per-request hidden_states_cache.
        for state, chunk in zip(states, ordered_chunks):
            state.hidden_states_cache.append(chunk)

        # Step 2: once a request's prefill is finished, hand the full cache
        # to the pooler head; unfinished requests report None.
        outputs = list[TokenPoolingMethodOutputItem]()
        for state, finished in zip(states, cursor.is_finished()):
            if not finished:
                outputs.append(None)
                continue

            cache = state.hidden_states_cache
            # Skip the concat when there is only a single cached chunk.
            outputs.append(cache[0] if len(cache) == 1 else torch.concat(cache, dim=0))
            state.clean()

        return outputs
|
|
|
|
|
|
class StepPool(AllPool):
    """Token pooling that filters each request's tokens down to "step"
    positions and/or a subset of output columns, as selected by the
    request's pooling params (``step_tag_id`` / ``returned_token_ids``).
    """

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Locating step positions requires the prompt token ids.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Pool all tokens via :class:`AllPool`, then apply per-request
        step/column filtering.

        ``None`` entries (unfinished chunked prefills) pass through
        untouched.
        """
        pooled_data_lst = super().forward(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        # Use the file's output-item alias (was list[torch.Tensor | None]).
        pooled_data = list[TokenPoolingMethodOutputItem]()
        for data, token_id, pooling_param in zip(
            pooled_data_lst, prompt_token_ids, pooling_params
        ):
            # Guard clause replaces the original `if data is None: pass`
            # anti-pattern; None (unfinished chunked prefill) is forwarded.
            if data is not None:
                step_tag_id = pooling_param.step_tag_id
                returned_token_ids = pooling_param.returned_token_ids

                # Keep only the requested output columns, if any were given.
                if returned_token_ids is not None and len(returned_token_ids) > 0:
                    data = data[:, returned_token_ids]

                # Keep only rows whose prompt token is the step tag.
                if step_tag_id is not None:
                    data = data[token_id == step_tag_id]

            pooled_data.append(data)

        return pooled_data
|
|
|
|
|
|
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
|
|
if pooling_type == "ALL":
|
|
return AllPool()
|
|
if pooling_type == "STEP":
|
|
return StepPool()
|
|
|
|
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
|