Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
651
vllm/entrypoints/pooling/score/serving.py
Normal file
651
vllm/entrypoints/pooling/score/serving.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankDocument,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResult,
|
||||
RerankUsage,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
parse_score_data_single,
|
||||
validate_score_input,
|
||||
)
|
||||
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
score_template: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.score_template = score_template
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.is_cross_encoder = self.model_config.is_cross_encoder
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.architecture = self.model_config.architecture
|
||||
self.is_late_interaction = self.model_config.is_late_interaction
|
||||
|
||||
if self.is_cross_encoder:
|
||||
self._score_func = self._cross_encoding_score
|
||||
elif self.is_late_interaction:
|
||||
self._score_func = self._late_interaction_score
|
||||
else:
|
||||
self._score_func = self._embedding_score
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
input_texts: list[str] = []
|
||||
for text in data_1 + data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Embedding scores currently do not support multimodal input."
|
||||
)
|
||||
input_texts.append(text)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
encode_async = make_async(
|
||||
tokenizer.encode,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[ProcessorInputs] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
token_inputs(
|
||||
text_token_prompt["prompt_token_ids"],
|
||||
prompt=input_text,
|
||||
)
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params("embed")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(
|
||||
tokenizer=tokenizer, embed_1=emb_data_1, embed_2=emb_data_2
|
||||
)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
def _preprocess_late_interaction_item(
|
||||
self,
|
||||
data: ScoreData,
|
||||
role: str,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
"""Parse a single ScoreData into a text + optional multimodal
|
||||
TokensPrompt for late-interaction encoding.
|
||||
|
||||
For plain strings, tokenises directly.
|
||||
For multimodal content parts, extracts text and multi_modal_data.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
|
||||
if isinstance(data, str):
|
||||
text, mm_data, mm_uuids = data, None, None
|
||||
else:
|
||||
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
|
||||
|
||||
prompt_inputs = tokenizer(text, **tokenization_kwargs)
|
||||
self._validate_input(request, prompt_inputs["input_ids"], text)
|
||||
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
)
|
||||
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return text, engine_prompt
|
||||
|
||||
async def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
|
||||
all_data = data_1 + data_2
|
||||
roles = ["query"] * len(data_1) + ["document"] * len(data_2)
|
||||
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_late_interaction_item,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
preprocessed = await asyncio.gather(
|
||||
*(
|
||||
preprocess_async(
|
||||
data=d,
|
||||
role=r,
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
for d, r in zip(all_data, roles)
|
||||
)
|
||||
)
|
||||
|
||||
input_texts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for text, engine_prompt in preprocessed:
|
||||
input_texts.append(text)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
pooling_params = request.to_pooling_params("token_embed")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Collect token embeddings
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
# Split into query and document embeddings
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
# Expand queries if 1:N scoring
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError("MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_score,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
*(
|
||||
preprocess_async(
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tok_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2,
|
||||
)
|
||||
for t1, t2 in input_pairs
|
||||
)
|
||||
)
|
||||
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for full_prompt, engine_prompt in preprocessed_prompts:
|
||||
request_prompts.append(full_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
default_pooling_params = request.to_pooling_params("score")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
request_prompts[i],
|
||||
params=default_pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
pooling_params = default_pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
else:
|
||||
pooling_params = default_pooling_params
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
|
||||
engine_prompts
|
||||
)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: ScoreData,
|
||||
data_2: ScoreData,
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
model_config = self.model_config
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=data_1,
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
score_template=self.score_template,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
data_1: ScoreInputs,
|
||||
data_2: ScoreInputs,
|
||||
request: ScoreRequest | RerankRequest,
|
||||
request_id: str,
|
||||
raw_request: Request | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
score_data_1, score_data_2 = validate_score_input(
|
||||
data_1,
|
||||
data_2,
|
||||
is_multimodal_model=self.is_multimodal_model,
|
||||
architecture=self.architecture,
|
||||
)
|
||||
|
||||
return await self._score_func(
|
||||
data_1=score_data_1,
|
||||
data_2=score_data_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> ScoreResponse | ErrorResponse:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.data_1,
|
||||
request.data_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self.models.model_name(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def do_rerank(
|
||||
self, request: RerankRequest, raw_request: Request | None = None
|
||||
) -> RerankResponse | ErrorResponse:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
top_n = request.top_n if request.top_n > 0 else len(final_res_batch)
|
||||
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self.models.model_name(),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
documents: ScoreInputs,
|
||||
top_n: int,
|
||||
) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
|
||||
if not isinstance(documents, list):
|
||||
documents = [documents]
|
||||
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
document = documents[idx]
|
||||
if isinstance(document, str):
|
||||
rerank_document = RerankDocument(text=document)
|
||||
else:
|
||||
rerank_document = RerankDocument(
|
||||
multi_modal=document.get("content", [])
|
||||
)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=rerank_document,
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(
|
||||
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
|
||||
),
|
||||
)
|
||||
Reference in New Issue
Block a user