Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
45
vllm/v1/worker/gpu/pool/pooling_runner.py
Normal file
45
vllm/v1/worker/gpu/pool/pooling_runner.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
|
||||
# on decoder-only models. How to support other pooling tasks and models
|
||||
# is to be determined.
|
||||
class PoolingRunner:
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = cast(VllmModelForPooling, model)
|
||||
|
||||
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||
if not is_pooling_model(self.model):
|
||||
return []
|
||||
assert "embed" in self.model.pooler.get_supported_tasks()
|
||||
return ["embed"]
|
||||
|
||||
def pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
req_states: RequestState,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# TODO(woosuk): Support different types of pooling tasks.
|
||||
last_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
# TODO(woosuk): Make normalization optional.
|
||||
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
|
||||
|
||||
prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
|
||||
is_valid = input_batch.seq_lens == prompt_len
|
||||
return last_hidden_states, is_valid
|
||||
|
||||
def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
|
||||
F.normalize(hidden_states, p=2, dim=-1)
|
||||
return
|
||||
Reference in New Issue
Block a user