Upgrade to vllm 0.17.0 corex v4.1 overlay

Date: 2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
     def to_pooling_params(self, task: PoolingTask = "score"):
         return PoolingParams(
             task=task,
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
         )
@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
     def to_pooling_params(self, task: PoolingTask = "score"):
         return PoolingParams(
             task=task,
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
         )

View File

@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
     ScoreInputs,
     _cosine_similarity,
     compress_token_type_ids,
-    compute_maxsim_score,
+    compute_maxsim_scores,
     get_score_prompt,
     parse_score_data_single,
     validate_score_input,
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
         request_logger: RequestLogger | None,
         score_template: str | None = None,
         log_error_stack: bool = False,
+        use_gpu_for_pooling_score: bool = False,
     ) -> None:
         super().__init__(
             engine_client=engine_client,
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
             log_error_stack=log_error_stack,
         )
         self.score_template = score_template
+        self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
@@ -311,19 +313,18 @@
         # Compute MaxSim scores
         from vllm.outputs import PoolingOutput
 
+        maxsim_scores = compute_maxsim_scores(
+            [emb.outputs.data for emb in emb_data_1],
+            [emb.outputs.data for emb in emb_data_2],
+            use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
+        )
         scores: list[PoolingRequestOutput] = []
         padding: list[int] = []
         if (pad_token_id := tokenizer.pad_token_id) is not None:
             padding = [pad_token_id]
-        for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
-            # emb_1.outputs.data: [query_len, dim]
-            # emb_2.outputs.data: [doc_len, dim]
-            q_emb = emb_1.outputs.data
-            d_emb = emb_2.outputs.data
-            maxsim_score = compute_maxsim_score(q_emb, d_emb)
+        for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
             tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
             scores.append(
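For reference, the removed per-pair loop computed one ColBERT MaxSim score at a time. A minimal sketch of that single-pair computation, assuming token_scores in compute_maxsim_score is the plain dot-product matrix between query and document token embeddings (the tensors below are illustrative):

import torch

q_emb = torch.randn(5, 128)   # [query_len, dim]
d_emb = torch.randn(40, 128)  # [doc_len, dim]

# MaxSim: best-matching document token per query token, summed over query tokens.
maxsim = (q_emb @ d_emb.T).amax(dim=-1).sum()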

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 from typing import Any, TypeAlias, cast
 
 import torch
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
 from vllm.model_executor.models.interfaces import supports_score_template
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
 from vllm.outputs import PoolingRequestOutput
+from vllm.platforms import current_platform
 from vllm.renderers.hf import safe_apply_chat_template
 from vllm.tokenizers import TokenizerLike
@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
     return token_scores.amax(dim=-1).sum()
 
 
+def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
+    return use_gpu_for_pooling_score and not current_platform.is_cpu()
+
+
+def compute_maxsim_scores(
+    q_embs: Sequence[torch.Tensor],
+    d_embs: Sequence[torch.Tensor],
+    max_batch_size: int = 16,
+    max_score_matrix_elements: int = 16_000_000,
+    use_gpu_for_pooling_score: bool = False,
+) -> list[torch.Tensor]:
+    """Compute ColBERT MaxSim scores in padded mini-batches."""
+    if len(q_embs) != len(d_embs):
+        raise ValueError("q_embs and d_embs must have the same length")
+
+    num_pairs = len(q_embs)
+    if num_pairs == 0:
+        return []
+
+    for q_emb, d_emb in zip(q_embs, d_embs):
+        if q_emb.ndim != 2 or d_emb.ndim != 2:
+            raise ValueError("Each embedding tensor must be 2-D")
+        if q_emb.shape[1] != d_emb.shape[1]:
+            raise ValueError("Query and document embeddings must have the same dim")
+
+    compute_device = torch.device(
+        current_platform.device_type
+        if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
+        else "cpu"
+    )
+
+    scores: list[torch.Tensor] = []
+    start = 0
+    while start < num_pairs:
+        end = min(start + max_batch_size, num_pairs)
+        max_q = max(int(x.shape[0]) for x in q_embs[start:end])
+        max_d = max(int(x.shape[0]) for x in d_embs[start:end])
+        # Keep the padded score matrix bounded to avoid oversized allocations.
+        while (
+            end - start > 1
+            and (end - start) * max_q * max_d > max_score_matrix_elements
+        ):
+            end -= 1
+            max_q = max(int(x.shape[0]) for x in q_embs[start:end])
+            max_d = max(int(x.shape[0]) for x in d_embs[start:end])
+
+        batch_q = q_embs[start:end]
+        batch_d = d_embs[start:end]
+        batch_size = end - start
+        dim = int(batch_q[0].shape[1])
+        dtype = batch_q[0].dtype
+
+        q_batch = torch.zeros(
+            (batch_size, max_q, dim), dtype=dtype, device=compute_device
+        )
+        d_batch = torch.zeros(
+            (batch_size, max_d, dim), dtype=dtype, device=compute_device
+        )
+        q_mask = torch.zeros(
+            (batch_size, max_q), dtype=torch.bool, device=compute_device
+        )
+        d_mask = torch.zeros(
+            (batch_size, max_d), dtype=torch.bool, device=compute_device
+        )
+
+        # Copy each variable-length pair into the padded batch tensors.
+        for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
+            q_len = int(q_emb.shape[0])
+            d_len = int(d_emb.shape[0])
+            q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
+            d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
+            q_mask[i, :q_len] = True
+            d_mask[i, :d_len] = True
+
+        # [B, Q, D] token similarities; mask document padding before the max
+        # and query padding before the sum.
+        token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
+        token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
+        max_per_query = token_scores.amax(dim=-1)
+        max_per_query.masked_fill_(~q_mask, 0)
+        batch_scores = max_per_query.sum(dim=-1).to("cpu")
+        scores.extend(batch_scores.unbind(0))
+        start = end
+
+    return [cast(torch.Tensor, score) for score in scores]
+
+
 class ScoreMultiModalParam(TypedDict, total=False):
     """
     A specialized parameter type for scoring multimodal content
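A minimal usage sketch for the new batched helper, assuming the padded batch path stays numerically equivalent to the per-pair compute_maxsim_score (tensor shapes are illustrative):

import torch
from vllm.entrypoints.pooling.score.utils import (
    compute_maxsim_score,
    compute_maxsim_scores,
)

q_embs = [torch.randn(n, 64) for n in (4, 9, 3)]
d_embs = [torch.randn(m, 64) for m in (50, 12, 70)]

# Batched, padded computation; stays on CPU unless
# use_gpu_for_pooling_score=True and the current platform is not CPU-only.
batched = compute_maxsim_scores(q_embs, d_embs)

for q, d, s in zip(q_embs, d_embs, batched):
    assert torch.allclose(s, compute_maxsim_score(q, d), atol=1e-5)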