# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState


# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
    """Executes pooling over model hidden states.

    Only last-token embedding pooling is implemented; see the NOTE above
    about extending this to other pooling tasks and model types.
    """

    def __init__(self, model: nn.Module):
        # The cast is a static-typing assertion only; there is no runtime
        # check that `model` actually implements VllmModelForPooling.
        self.model = cast(VllmModelForPooling, model)

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """Return the pooling tasks this runner can serve.

        Non-pooling models support nothing; pooling models are required
        to support "embed", which is the only task exposed for now.
        """
        if not is_pooling_model(self.model):
            return []
        supported = self.model.pooler.get_supported_tasks()
        assert "embed" in supported
        return ["embed"]

    def pool(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Pool the last-position hidden state of each request.

        Returns the L2-normalized last-token hidden states together with a
        boolean mask marking which outputs are valid.
        """
        # TODO(woosuk): Support different types of pooling tasks.
        # TODO(woosuk): Make normalization optional.
        pooled = F.normalize(
            hidden_states[input_batch.logits_indices], p=2, dim=-1)

        # An output is only valid once the entire prompt has been processed,
        # i.e. the request's sequence length equals its prompt length.
        prompt_lens = req_states.prompt_len.gpu[input_batch.idx_mapping]
        valid_mask = input_batch.seq_lens == prompt_lens
        return pooled, valid_mask

    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        """Warm-up pass: perform the pooling math and discard the result."""
        _ = F.normalize(hidden_states, p=2, dim=-1)
|