Fix select (#64)
This commit is contained in:
@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
|
||||
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
|
||||
logprobs = torch.zeros(
|
||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||
end.sub_(1)
|
||||
torch.cuda.synchronize()
|
||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||
res = sum_logp / len_add_1
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user