Fix select (#64)

This commit is contained in:
Lianmin Zheng
2024-01-20 23:20:35 -08:00
committed by GitHub
parent ca13f3b8c5
commit 11f3cca64f
2 changed files with 12 additions and 2 deletions

View File

@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
logprobs = torch.zeros(
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
)
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
end = torch.cumsum(len_add_1.sub_(1), dim=0)
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
end.sub_(1)
torch.cuda.synchronize()
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
res = sum_logp / len_add_1
return res