初始化项目,由ModelHub XC社区提供模型
Model: transformers-community/group-beam-search Source: Original Platform
This commit is contained in:
501
custom_generate/beam_search.py
Normal file
501
custom_generate/beam_search.py
Normal file
@@ -0,0 +1,501 @@
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.utils import add_start_docstrings
|
||||
|
||||
PROCESS_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
||||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
||||
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
||||
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`Union[int, list[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
beam_indices (`torch.LongTensor`, *optional*):
|
||||
Beam indices indicating to which beam hypothesis each token correspond.
|
||||
group_index (`int`, *optional*):
|
||||
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
|
||||
|
||||
Return:
|
||||
`UserDict`: A dictionary composed of the fields as defined above:
|
||||
|
||||
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
|
||||
non-finished beams.
|
||||
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
|
||||
to the non-finished beam_hypotheses.
|
||||
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
||||
indicating to which beam the next tokens shall be added.
|
||||
|
||||
"""
|
||||
|
||||
FINALIZE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
||||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
||||
The final scores of all non-finished beams.
|
||||
final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
||||
The last tokens to be added to the non-finished beam_hypotheses.
|
||||
final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
||||
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`Union[int, list[int]]`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
|
||||
Return:
|
||||
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
||||
The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
|
||||
due to the `eos_token_id`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class BeamHypotheses:
|
||||
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.beams = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
if not isinstance(self.early_stopping, bool) and self.max_length is None:
|
||||
raise ValueError(
|
||||
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
|
||||
" BeamScorer class instance at initialization time."
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return len(self.beams)
|
||||
|
||||
def add(
|
||||
self,
|
||||
hyp: torch.LongTensor,
|
||||
sum_logprobs: float,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
generated_len: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
if generated_len is not None:
|
||||
score = sum_logprobs / (generated_len**self.length_penalty)
|
||||
# This 'else' case exists for retrocompatibility
|
||||
else:
|
||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||
|
||||
if len(self) < self.num_beams or score > self.worst_score:
|
||||
self.beams.append((score, hyp, beam_indices))
|
||||
if len(self) > self.num_beams:
|
||||
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
||||
del self.beams[sorted_next_scores[0][1]]
|
||||
self.worst_score = sorted_next_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
||||
one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
|
||||
if len(self) < self.num_beams:
|
||||
return False
|
||||
|
||||
# `True`: stop as soon as at least `num_beams` hypotheses are finished
|
||||
if self.early_stopping is True:
|
||||
return True
|
||||
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
|
||||
# when `length_penalty` is positive. See the discussion below for more details.
|
||||
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||
elif self.early_stopping is False:
|
||||
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||
ret = self.worst_score >= highest_attainable_score
|
||||
return ret
|
||||
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
||||
else:
|
||||
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
|
||||
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
||||
# its max this way
|
||||
if self.length_penalty > 0.0:
|
||||
if self.max_length <= decoder_prompt_len:
|
||||
raise ValueError("max_length is not larger than decoder prompt length")
|
||||
highest_attainable_score = (
|
||||
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
||||
)
|
||||
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
||||
else:
|
||||
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||
ret = self.worst_score >= highest_attainable_score
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
class BeamScorer(ABC):
|
||||
"""
|
||||
Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
|
||||
[`~PreTrainedModel.beam_sample`].
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
|
||||
def process(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
raise NotImplementedError("This is an abstract method.")
|
||||
|
||||
@abstractmethod
|
||||
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
|
||||
def finalize(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
**kwargs,
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError("This is an abstract method.")
|
||||
|
||||
|
||||
class BeamSearchScorer(BeamScorer):
|
||||
r"""
|
||||
[`BeamScorer`] implementing standard beam search decoding.
|
||||
|
||||
Adapted in part from [Facebook's XLM beam search
|
||||
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
|
||||
|
||||
Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
|
||||
implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
||||
num_beams (`int`):
|
||||
Number of beams for beam search.
|
||||
device (`torch.device`):
|
||||
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
||||
allocated.
|
||||
length_penalty (`float`, *optional*, defaults to 1.0):
|
||||
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
||||
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
||||
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
||||
`length_penalty` < 0.0 encourages shorter sequences.
|
||||
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
||||
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
||||
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
||||
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
||||
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
||||
beam search algorithm).
|
||||
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
||||
The number of beam hypotheses that shall be returned upon calling
|
||||
[`~transformers.BeamSearchScorer.finalize`].
|
||||
num_beam_groups (`int`, *optional*, defaults to 1):
|
||||
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
||||
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
|
||||
max_length (`int`, *optional*):
|
||||
The maximum length of the sequence to be generated.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_beams: int,
|
||||
device: torch.device,
|
||||
length_penalty: Optional[float] = 1.0,
|
||||
do_early_stopping: Optional[Union[bool, str]] = False,
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
num_beam_groups: Optional[int] = 1,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
self.num_beams = num_beams
|
||||
self.device = device
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
self.num_beam_groups = num_beam_groups
|
||||
self.group_size = self.num_beams // self.num_beam_groups
|
||||
|
||||
self._is_init = False
|
||||
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
|
||||
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
|
||||
self._beam_hyps = [
|
||||
BeamHypotheses(
|
||||
num_beams=self.group_size,
|
||||
length_penalty=self.length_penalty,
|
||||
early_stopping=self.do_early_stopping,
|
||||
max_length=max_length,
|
||||
)
|
||||
for _ in range(batch_size * self.num_beam_groups)
|
||||
]
|
||||
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
|
||||
# in the i-th mini-batch is complete.
|
||||
self._done = torch.tensor(
|
||||
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||
raise ValueError(
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
||||
" one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
||||
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._done.all()
|
||||
|
||||
def process(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
group_index: Optional[int] = 0,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if batch_size != (input_ids.shape[0] // self.group_size):
|
||||
if self.num_beam_groups > 1:
|
||||
raise ValueError(
|
||||
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
||||
f"size of {self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
||||
f"{self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
|
||||
device = input_ids.device
|
||||
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
||||
if self._done[batch_group_idx]:
|
||||
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
|
||||
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
||||
if eos_token_id is None or pad_token_id is None:
|
||||
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
||||
# pad the batch
|
||||
next_beam_scores[batch_idx, :] = 0
|
||||
next_beam_tokens[batch_idx, :] = pad_token_id
|
||||
next_beam_indices[batch_idx, :] = 0
|
||||
continue
|
||||
|
||||
# next tokens for this sentence
|
||||
beam_idx = 0
|
||||
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
if beam_indices is not None:
|
||||
beam_index = beam_indices[batch_beam_idx]
|
||||
beam_index = beam_index + (batch_beam_idx,)
|
||||
else:
|
||||
beam_index = None
|
||||
|
||||
self._beam_hyps[batch_group_idx].add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
beam_indices=beam_index,
|
||||
generated_len=cur_len - decoder_prompt_len,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
next_beam_scores[batch_idx, beam_idx] = next_score
|
||||
next_beam_tokens[batch_idx, beam_idx] = next_token
|
||||
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
||||
beam_idx += 1
|
||||
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if beam_idx == self.group_size:
|
||||
break
|
||||
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
||||
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||
)
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"next_beam_scores": next_beam_scores.view(-1),
|
||||
"next_beam_tokens": next_beam_tokens.view(-1),
|
||||
"next_beam_indices": next_beam_indices.view(-1),
|
||||
}
|
||||
)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
final_beam_scores: torch.FloatTensor,
|
||||
final_beam_tokens: torch.LongTensor,
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
) -> tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_group_idx]:
|
||||
continue
|
||||
|
||||
# all open beam hypotheses are added to the beam hypothesis
|
||||
# beam hypothesis class automatically keeps the best beams
|
||||
for index_per_group in range(self.group_size):
|
||||
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
best = []
|
||||
best_indices = []
|
||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||
|
||||
# retrieve best hypotheses
|
||||
for i in range(batch_size):
|
||||
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
|
||||
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
|
||||
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
|
||||
for j in range(self.num_beam_hyps_to_keep):
|
||||
best_hyp_tuple = sorted_hyps.pop()
|
||||
best_score = best_hyp_tuple[0]
|
||||
best_hyp = best_hyp_tuple[1]
|
||||
best_index = best_hyp_tuple[2]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
|
||||
# append hyp to lists
|
||||
best.append(best_hyp)
|
||||
|
||||
# append indices to list
|
||||
best_indices.append(best_index)
|
||||
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_lengths_max = sent_lengths.max().item() + 1
|
||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
|
||||
if len(best_indices) > 0 and best_indices[0] is not None:
|
||||
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
else:
|
||||
indices = None
|
||||
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
if pad_token_id is None:
|
||||
raise ValueError("`pad_token_id` has to be defined")
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
if indices is not None:
|
||||
indices.fill_(-1)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
|
||||
if indices is not None:
|
||||
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
# inserting only the first eos_token_id
|
||||
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
"sequence_scores": best_scores,
|
||||
"beam_indices": indices,
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user