初始化项目,由ModelHub XC社区提供模型
Model: iic/ERank-4B Source: Original Platform
This commit is contained in:
97
examples/ERank_vLLM.py
Normal file
97
examples/ERank_vLLM.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import torch
|
||||
import math
|
||||
from vllm import LLM, SamplingParams
|
||||
from utils import prompt_template, truncate
|
||||
|
||||
|
||||
class ERank_vLLM:
|
||||
|
||||
def __init__(self, model_name_or_path: str):
|
||||
"""
|
||||
Initializes the ERank_vLLM reranker.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The name or path of the model to be loaded.
|
||||
This can be a Hugging Face model ID or a local path.
|
||||
"""
|
||||
num_gpu = torch.cuda.device_count()
|
||||
self.ranker = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=num_gpu,
|
||||
gpu_memory_utilization=0.95,
|
||||
enable_prefix_caching=True
|
||||
)
|
||||
self.tokenizer = self.ranker.get_tokenizer()
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=4096,
|
||||
logprobs=20
|
||||
)
|
||||
|
||||
def rerank(self, query: str, docs: list, instruction: str, truncate_length: int=None) -> list:
|
||||
"""
|
||||
Reranks a list of documents based on a query and a specific instruction.
|
||||
|
||||
Args:
|
||||
query (str): The search query provided by the user.
|
||||
docs (list): A list of dictionaries, where each dictionary represents a document
|
||||
and must contain a "content" key.
|
||||
instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
|
||||
truncate_length (int, optional): The maximum length to truncate the query and document content to. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: A new list of document dictionaries, sorted by their "rank_score" in descending order.
|
||||
"""
|
||||
|
||||
# prepare messages
|
||||
messages = [
|
||||
[{
|
||||
"role": "user",
|
||||
"content": prompt_template.format(
|
||||
query=truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query,
|
||||
doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
|
||||
instruction=instruction
|
||||
)
|
||||
}] for doc in docs
|
||||
]
|
||||
|
||||
# LLM generate
|
||||
outputs = self.ranker.chat(messages, self.sampling_params)
|
||||
|
||||
# extract and organize results
|
||||
results = []
|
||||
for doc, output in zip(docs, outputs):
|
||||
|
||||
# extract the answer and its probability
|
||||
cur = ""
|
||||
answer = ""
|
||||
is_ans = False
|
||||
prob = 1.0
|
||||
for each in output.outputs[0].logprobs[-10:]:
|
||||
_, detail = next(iter(each.items()))
|
||||
token = detail.decoded_token
|
||||
logprob = detail.logprob
|
||||
if is_ans and token.isdigit():
|
||||
answer += token
|
||||
prob *= math.exp(logprob)
|
||||
else:
|
||||
cur += token
|
||||
if cur.endswith("<answer>"):
|
||||
is_ans = True
|
||||
|
||||
# in case the answer is not a digit or exceeds 10
|
||||
try:
|
||||
answer = int(answer)
|
||||
assert answer <= 10
|
||||
except:
|
||||
answer = -1
|
||||
|
||||
# append to the final results
|
||||
results.append({
|
||||
**doc,
|
||||
"rank_score": answer * prob
|
||||
})
|
||||
|
||||
# sort the reranking results for the query
|
||||
results.sort(key=lambda x:x["rank_score"], reverse=True)
|
||||
return results
|
||||
Reference in New Issue
Block a user