530 lines
20 KiB
Python
530 lines
20 KiB
Python
|
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||
|
|
#
|
||
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||
|
|
# and associated documentation files (the "Software"), to deal in the Software without
|
||
|
|
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
|
||
|
|
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
|
||
|
|
# Software is furnished to do so, subject to the following conditions:
|
||
|
|
#
|
||
|
|
# The above copyright notice and this permission notice shall be included in all copies or
|
||
|
|
# substantial portions of the Software.
|
||
|
|
#
|
||
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||
|
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||
|
|
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||
|
|
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||
|
|
# OTHER DEALINGS IN THE SOFTWARE.
|
||
|
|
from __future__ import annotations
|
||
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
from torch.nn import functional as F
|
||
|
|
from transformers import (
|
||
|
|
PreTrainedModel,
|
||
|
|
BertModel,
|
||
|
|
AutoTokenizer,
|
||
|
|
)
|
||
|
|
import os
|
||
|
|
from transformers.modeling_outputs import SequenceClassifierOutput
|
||
|
|
from typing import Union, List, Optional
|
||
|
|
from collections import defaultdict
|
||
|
|
import numpy as np
|
||
|
|
import math
|
||
|
|
from huggingface_hub import hf_hub_download
|
||
|
|
from .configuration_listconranker import ListConRankerConfig
|
||
|
|
|
||
|
|
|
||
|
|
class QueryEmbedding(nn.Module):
    """Add a learned per-position tag embedding to list features, then LayerNorm.

    The embedding table has two rows, so ``tags`` must hold indices 0 or 1.
    In this file the ListTransformer tags the query slot with 1 and every
    other slot with 0.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # Attribute names form state-dict keys of the pretrained checkpoint
        # (loaded into the enclosing ListTransformer); keep them unchanged.
        self.query_embedding = nn.Embedding(2, config.list_con_hidden_size)
        self.layerNorm = nn.LayerNorm(config.list_con_hidden_size)

    def forward(self, x, tags):
        """Return ``LayerNorm(x + embedding(tags))`` without modifying ``x``.

        Args:
            x: feature tensor whose last dimension is ``list_con_hidden_size``.
            tags: integer tensor of 0/1 indices; presumably shaped like
                ``x.shape[:-1]`` so the looked-up embeddings line up with ``x``
                (TODO confirm for other callers).
        """
        # Out-of-place add so the caller's tensor is left untouched.
        return self.layerNorm(x + self.query_embedding(tags))
|
||
|
|
|
||
|
|
|
||
|
|
class ListTransformer(nn.Module):
    """Listwise transformer that scores every passage in the context of its group.

    ``forward`` receives one feature row per sequence, laid out as consecutive
    groups ``[query, passage_1, ..., passage_k]`` (group sizes in
    ``pair_nums``). Groups are padded into one batch, the query slot is tagged
    with a learned embedding, the batch runs through ``nn.TransformerEncoder``,
    and each passage gets one logit from both its pre- and post-transformer
    features concatenated with its query's features.
    """

    def __init__(self, num_layer, config) -> None:
        super().__init__()
        self.config = config
        # NOTE: submodule names below are state-dict keys of the pretrained
        # ``list_transformer.pt`` checkpoint; do not rename them.
        self.list_transformer_layer = nn.TransformerEncoderLayer(
            config.list_con_hidden_size,
            self.config.num_attention_heads,
            batch_first=True,
            activation=F.gelu,
            norm_first=False,
        )
        self.list_transformer = nn.TransformerEncoder(
            self.list_transformer_layer, num_layer
        )
        self.relu = nn.ReLU()
        self.query_embedding = QueryEmbedding(config)

        self.linear_score3 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score2 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1)

    def forward(
        self, pair_features: torch.Tensor, pair_nums: List[int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Score all passages of all groups.

        Args:
            pair_features: 2-D tensor, one row per sequence, ordered as
                consecutive groups ``[query, passage_1, ..., passage_k]``.
            pair_nums: size of each group, *including* its query row.

        Returns:
            ``(logits, features)``: ``logits`` is 1-D with one score per
            passage, in input order (also for a single passage); ``features``
            are the transformer outputs of all non-padding rows, concatenated
            in input order.
        """
        # Work on the device of the input instead of relying on an externally
        # assigned ``self.device``.
        device = pair_features.device
        batch_pair_features = pair_features.split(pair_nums)

        # Pre-transformer branch: [query ; passage] for every passage.
        pair_feature_query_passage_concat_list = []
        for group, pair_num in zip(batch_pair_features, pair_nums):
            query_ft = group[0].unsqueeze(0).repeat(pair_num - 1, 1)
            pair_feature_query_passage_concat_list.append(
                torch.cat([query_ft, group[1:]], dim=1)
            )
        pair_feature_query_passage_concat = torch.cat(
            pair_feature_query_passage_concat_list, dim=0
        )

        # Pad the groups into (num_groups, max_group_size, hidden).
        padded_features = nn.utils.rnn.pad_sequence(
            batch_pair_features, batch_first=True
        )

        # Tag slot 0 (the query) with embedding id 1, all other slots with 0.
        query_embedding_tags = torch.zeros(
            padded_features.size(0),
            padded_features.size(1),
            dtype=torch.long,
            device=device,
        )
        query_embedding_tags[:, 0] = 1
        padded_features = self.query_embedding(padded_features, query_embedding_tags)

        mask = self.generate_attention_mask(pair_nums, device=device)
        query_mask = self.generate_attention_mask_custom(pair_nums, device=device)
        pair_list_features = self.list_transformer(
            padded_features, src_key_padding_mask=mask, mask=query_mask
        )

        # Post-transformer branch: [query_out ; passage_out] per passage,
        # dropping padding rows.
        pair_features_after_transformer_list = []
        pair_features_after_transformer_cat_query_list = []
        for idx, pair_num in enumerate(pair_nums):
            group_out = pair_list_features[idx, :pair_num, :]
            pair_features_after_transformer_list.append(group_out)
            query_ft = group_out[0].unsqueeze(0).repeat(pair_num - 1, 1)
            pair_features_after_transformer_cat_query_list.append(
                torch.cat([query_ft, group_out[1:]], dim=1)
            )
        pair_features_after_transformer_cat_query = torch.cat(
            pair_features_after_transformer_cat_query_list, dim=0
        )

        pair_feature_query_passage_concat = self.relu(
            self.linear_score2(pair_feature_query_passage_concat)
        )
        pair_features_after_transformer_cat_query = self.relu(
            self.linear_score3(pair_features_after_transformer_cat_query)
        )
        final_ft = torch.cat(
            [
                pair_feature_query_passage_concat,
                pair_features_after_transformer_cat_query,
            ],
            dim=1,
        )
        # squeeze(-1), not squeeze(): with a single passage the latter would
        # produce a 0-d tensor that downstream indexing cannot handle.
        logits = self.linear_score1(final_ft).squeeze(-1)
        return logits, torch.cat(pair_features_after_transformer_list, dim=0)

    def generate_attention_mask(self, pair_num, device=None):
        """Build the key-padding mask for the padded group batch.

        Returns a ``(len(pair_num), max(pair_num))`` bool tensor that is True at
        padding slots (index >= group size), as ``src_key_padding_mask`` expects.
        ``device`` defaults to ``self.device`` if that attribute was assigned,
        else torch's default device.
        """
        if device is None:
            device = getattr(self, "device", None)
        max_len = max(pair_num)
        mask = torch.zeros(len(pair_num), max_len, dtype=torch.bool, device=device)
        for i, length in enumerate(pair_num):
            mask[i, length:] = True
        return mask

    def generate_attention_mask_custom(self, pair_num, device=None):
        """Build the ``(max_len, max_len)`` attention mask shared by all groups.

        True entries are blocked: row 0 (the query slot) may attend only to
        itself, other rows are unrestricted here (padding is handled by the
        key-padding mask). ``device`` defaults as in ``generate_attention_mask``.
        """
        if device is None:
            device = getattr(self, "device", None)
        max_len = max(pair_num)
        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=device)
        mask[0, 1:] = True
        return mask
|
||
|
|
|
||
|
|
|
||
|
|
class ListConRankerModel(PreTrainedModel):
    """
    ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.

    Each input row is a tokenized (query, passage) pair. ``forward`` splits every
    pair, groups rows that share a query, encodes the query once plus each of its
    passages with BERT, mean-pools, projects with ``linear_in_embedding`` and lets
    ``ListTransformer`` score the passages of a query jointly. Inference only.
    """

    config_class = ListConRankerConfig
    base_model_prefix = "listconranker"
    # Lazily loaded default tokenizer shared by the convenience helpers.
    _cached_default_tokenizer = None

    def __init__(self, config: ListConRankerConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.hf_model = BertModel(config.bert_config)

        self.sigmoid = nn.Sigmoid()

        self.linear_in_embedding = nn.Linear(
            config.hidden_size, config.list_con_hidden_size
        )
        self.list_transformer = ListTransformer(
            config.list_transformer_layers,
            config,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        """Score a batch of tokenized (query, passage) pairs.

        ``position_ids``, ``head_mask``, ``inputs_embeds``, ``labels``,
        ``output_attentions`` and ``output_hidden_states`` are accepted for API
        compatibility but not used. A missing ``attention_mask`` means "attend
        to every token".

        Returns:
            When ``return_dict`` is falsy (including the default ``None``): a
            tuple ``(probs, logits)``, each ``(batch, 1)`` in the original row
            order, with ``probs = sigmoid(logits)``. Otherwise a
            ``SequenceClassifierOutput`` carrying the same ``logits``; its
            ``hidden_states``/``attentions`` come from the inner BERT call
            (which does not request them) and follow the regrouped row order.

        Raises:
            NotImplementedError: if the module is in training mode.
            ValueError: if a row does not contain two SEP tokens.
        """
        if self.training:
            raise NotImplementedError("Training not supported; use eval mode.")
        device = input_ids.device
        # ListTransformer may read its working device from this attribute.
        self.list_transformer.device = device
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Reorganize by unique queries and their passages
        (
            reorganized_input_ids,
            reorganized_attention_mask,
            reorganized_token_type_ids,
            pair_nums,
            group_indices,
        ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids)

        out = self.hf_model(
            input_ids=reorganized_input_ids,
            attention_mask=reorganized_attention_mask,
            token_type_ids=reorganized_token_type_ids,
            return_dict=True,
        )
        pooled = self.average_pooling(
            out.last_hidden_state, reorganized_attention_mask
        )
        embedded = self.linear_in_embedding(pooled)
        logits, _ = self.list_transformer(embedded, pair_nums)
        probs = self.sigmoid(logits)

        # Restore original order
        sorted_probs = self._restore_original_order(probs, group_indices)
        sorted_logits = self._restore_original_order(logits, group_indices)
        if not return_dict:
            return (sorted_probs, sorted_logits)

        return SequenceClassifierOutput(
            loss=None,
            logits=sorted_logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )

    def _reorganize_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[
        torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]]
    ]:
        """
        Group inputs by unique queries: for each query, produce [query] + its passages,
        then flatten, pad, and return pair sizes and original indices mapping.

        Each row is split at its first two SEP tokens: the query becomes
        ``row[:first_sep + 1]`` and the passage ``row[first_sep:second_sep + 1]``
        with its first token replaced by CLS. Both get all-zero token types and
        are truncated to ``config.max_position_embeddings`` with a trailing SEP.
        Rows whose query token ids (ignoring CLS/SEP) are identical form one
        group; groups keep first-seen order. ``input_ids`` is never modified.

        Returns:
            ``(input_ids, attention_mask, token_type_ids, pair_nums, group_indices)``:
            tensors are right-padded with zeros; ``token_type_ids`` is ``None``
            when none were given; ``pair_nums[g]`` is the size of group ``g``
            including its query; ``group_indices[g]`` lists the original row
            index of each of its passages.

        Raises:
            ValueError: if a row contains fewer than two SEP tokens.
        """
        max_len = self.config.max_position_embeddings
        cls_id = self.config.cls_token_id
        sep_id = self.config.sep_token_id

        # query_key -> {"query": (seq, mask, tt), "passages": [...], "indices": [...]}
        grouped = {}
        for idx in range(input_ids.size(0)):
            seq = input_ids[idx]
            mask = attention_mask[idx]

            sep_idxs = (seq == sep_id).nonzero(as_tuple=True)[0]
            if sep_idxs.numel() == 0:
                raise ValueError(f"No SEP in sequence {idx}")
            if sep_idxs.numel() < 2:
                raise ValueError(f"Expected two SEP tokens in sequence {idx}")
            first_sep = sep_idxs[0].item()
            second_sep = sep_idxs[1].item()

            # Clone both slices: they are views of input_ids and get written below.
            q_seq = seq[: first_sep + 1].clone()
            q_mask = mask[: first_sep + 1]
            p_seq = seq[first_sep : second_sep + 1].clone()
            p_seq[0] = cls_id  # reuse the query's SEP slot as the passage's CLS
            p_mask = mask[first_sep : second_sep + 1]

            # Group key excludes CLS/SEP and is taken before truncation.
            key = tuple(q_seq[(q_seq != cls_id) & (q_seq != sep_id)].tolist())

            # Truncate to the position-embedding limit, keeping a trailing SEP.
            q_seq = q_seq[:max_len]
            q_seq[-1] = sep_id
            p_seq = p_seq[:max_len]
            p_seq[-1] = sep_id
            q_mask = q_mask[:max_len]
            p_mask = p_mask[:max_len]
            q_tt = torch.zeros_like(q_seq)
            p_tt = torch.zeros_like(p_seq)

            if key not in grouped:
                grouped[key] = {
                    "query": (q_seq, q_mask, q_tt),
                    "passages": [],
                    "indices": [],
                }
            grouped[key]["passages"].append((p_seq, p_mask, p_tt))
            grouped[key]["indices"].append(idx)

        # Flatten according to group insertion order: query first, then passages.
        seqs, masks, tts, pair_nums, group_indices = [], [], [], [], []
        for data in grouped.values():
            pair_nums.append(len(data["passages"]) + 1)  # +1 for the query
            group_indices.append(data["indices"])
            for s, m, t in [data["query"], *data["passages"]]:
                seqs.append(s)
                masks.append(m)
                tts.append(t)

        # Right-pad every sequence with zeros to the longest one.
        rid = nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        ram = nn.utils.rnn.pad_sequence(masks, batch_first=True)
        rtt = (
            nn.utils.rnn.pad_sequence(tts, batch_first=True)
            if token_type_ids is not None
            else None
        )
        return rid, ram, rtt, pair_nums, group_indices

    def _restore_original_order(
        self,
        logits: torch.Tensor,
        group_indices: List[List[int]],
    ) -> torch.Tensor:
        """
        Map flattened logits back so each original index gets its passage score.

        ``logits`` holds one value per passage in grouped order (a 0-d tensor
        is accepted for a single passage); returns a ``(batch, 1)`` tensor.
        """
        order = [idx for indices in group_indices for idx in indices]
        out = torch.zeros(len(order), dtype=logits.dtype, device=logits.device)
        target = torch.tensor(order, dtype=torch.long, device=logits.device)
        out[target] = logits.reshape(-1)
        return out.reshape(-1, 1)

    def average_pooling(self, hidden_state, attention_mask):
        """Mean of ``hidden_state`` over positions where ``attention_mask`` is set.

        Assumes ``hidden_state`` is ``(batch, seq, hidden)`` and
        ``attention_mask`` is ``(batch, seq)``; rows with an all-zero mask yield
        zeros instead of NaN.
        """
        extended_attention_mask = (
            attention_mask.unsqueeze(-1)
            .expand(hidden_state.size())
            .to(dtype=hidden_state.dtype)
        )
        sum_embeddings = torch.sum(hidden_state * extended_attention_mask, dim=1)
        sum_mask = extended_attention_mask.sum(dim=1).clamp(
            min=torch.finfo(extended_attention_mask.dtype).tiny
        )
        return sum_embeddings / sum_mask

    @classmethod
    def from_pretrained(
        cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
    ):
        """Load the model, its BERT encoder, and the two extra weight files.

        ``linear_in_embedding.pt`` and ``list_transformer.pt`` are read from
        ``model_name_or_path``: directly when it is a local directory, otherwise
        via ``hf_hub_download`` from that Hub repo, honouring ``revision``
        (default ``"main"``) and ``cache_dir`` from ``kwargs``.
        """
        model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
        model.hf_model = BertModel.from_pretrained(
            model_name_or_path, config=model.config.bert_config, **kwargs
        )

        def resolve(filename):
            # Prefer a local checkpoint directory; otherwise fetch from the Hub.
            if os.path.isdir(model_name_or_path):
                return os.path.join(model_name_or_path, filename)
            return hf_hub_download(
                repo_id=model_name_or_path,
                filename=filename,
                revision=kwargs.get("revision") or "main",
                cache_dir=kwargs.get("cache_dir"),
            )

        # NOTE(review): torch.load unpickles these files -- only load checkpoints
        # from trusted sources (torch>=1.13 also offers weights_only=True).
        # map_location="cpu" lets GPU-saved files load anywhere; load_state_dict
        # then copies the tensors onto the modules' current device.
        model.linear_in_embedding.load_state_dict(
            torch.load(resolve("linear_in_embedding.pt"), map_location="cpu")
        )
        model.list_transformer.load_state_dict(
            torch.load(resolve("list_transformer.pt"), map_location="cpu")
        )
        return model

    @classmethod
    def _get_default_tokenizer(cls):
        """Load the ``ByteDance/ListConRanker`` tokenizer on first use and cache it."""
        if ListConRankerModel._cached_default_tokenizer is None:
            ListConRankerModel._cached_default_tokenizer = AutoTokenizer.from_pretrained(
                "ByteDance/ListConRanker"
            )
        return ListConRankerModel._cached_default_tokenizer

    def multi_passage(
        self,
        sentences: List[List[str]],
        batch_size: int = 32,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        """
        Process multiple passages for each query.

        :param sentences: List of lists ``[query, passage_1, ..., passage_k]`` (k >= 1).
        :param batch_size: number of (query, passage) pairs per forward pass. Pairs
            are chunked in order, so one query's passages may be scored in
            separate forward passes.
        :param tokenizer: tokenizer used to encode the pairs; defaults to the
            ``ByteDance/ListConRanker`` tokenizer, loaded lazily on first use.
        :return: list of floats, one per passage in input order: the first element
            of ``forward``'s output (sigmoid probabilities in tuple mode).
        :raises ValueError: if a query has no passage or ``batch_size`` < 1.
        """
        pairs = []
        for group in sentences:
            if len(group) < 2:
                raise ValueError("Each query must have at least one passage.")
            query, passages = group[0], group[1:]
            pairs.extend((query, passage) for passage in passages)

        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if tokenizer is None:
            tokenizer = self._get_default_tokenizer()

        total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device)
        # Inference only: no autograd graph is needed for the scores.
        with torch.no_grad():
            for start in range(0, len(pairs), batch_size):
                batch_pairs = pairs[start : start + batch_size]
                inputs = tokenizer(
                    batch_pairs,
                    padding=True,
                    truncation=False,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                logits = self(**inputs)[0]
                total_logits[start : start + len(batch_pairs)] = logits.squeeze(1)
        return total_logits.tolist()

    def multi_passage_in_iterative_inference(
        self,
        sentences: List[str],
        stop_num: int = 20,
        decrement_rate: float = 0.2,
        min_filter_num: int = 10,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        """
        Process multiple passages for one query in iterative inference.

        While more than ``stop_num`` passages remain, all of them are scored
        together and the lowest ``max(ceil(n * decrement_rate), min_filter_num)``
        (plus any passages tied with the last one dropped) are removed; at least
        one passage always survives a round. A passage's final score is its
        score in the round it left plus that round's index, so later survivors
        rank above earlier drop-outs.

        :param sentences: ``[query, passage_1, ..., passage_k]`` for a single query.
        :param tokenizer: defaults to the ``ByteDance/ListConRanker`` tokenizer,
            loaded lazily on first use.
        :return: list of scores, one per passage in input order.
        :raises ValueError: on out-of-range parameters or when no passage is given.
        """
        if stop_num < 1:
            raise ValueError("stop_num must be greater than 0")
        if decrement_rate <= 0 or decrement_rate >= 1:
            raise ValueError("decrement_rate must be in (0, 1)")
        if min_filter_num < 1:
            raise ValueError("min_filter_num must be greater than 0")
        if tokenizer is None:
            tokenizer = self._get_default_tokenizer()

        query = sentences[0]
        passage = sentences[1:]

        filter_times = 0
        passage2score = defaultdict(list)
        while len(passage) > stop_num:
            pred_scores = self.multi_passage(
                [[query] + passage], batch_size=len(passage), tokenizer=tokenizer
            )
            # Indices of passages sorted by increasing score.
            pred_scores_argsort = np.argsort(pred_scores).tolist()

            passage_len = len(passage)
            to_filter_num = max(math.ceil(passage_len * decrement_rate), min_filter_num)
            # Keep at least one passage for the next round; without this bound
            # the indexing below would run past the end of the list.
            keep_limit = passage_len - 1

            have_filter_num = 0
            while have_filter_num < min(to_filter_num, keep_limit):
                idx = pred_scores_argsort[have_filter_num]
                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
                have_filter_num += 1
            # Also drop passages tied with the last dropped one.
            while (
                have_filter_num < keep_limit
                and pred_scores[pred_scores_argsort[have_filter_num - 1]]
                == pred_scores[pred_scores_argsort[have_filter_num]]
            ):
                idx = pred_scores_argsort[have_filter_num]
                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
                have_filter_num += 1

            passage = [passage[idx] for idx in pred_scores_argsort[have_filter_num:]]
            filter_times += 1

        pred_scores = self.multi_passage(
            [[query] + passage], batch_size=len(passage), tokenizer=tokenizer
        )
        for p, score in zip(passage, pred_scores):
            passage2score[p].append(score + filter_times)

        return [passage2score[p][0] for p in sentences[1:]]
|