Return logprob for choices (#87)
This commit is contained in:
@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
||||
if not input_metadata.return_normalized_logprob:
|
||||
if not input_metadata.return_logprob:
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
|
||||
if self.tp_size > 1:
|
||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||
last_logits = last_logits[:, : self.config.vocab_size]
|
||||
return last_logits, None
|
||||
return last_logits, (None, None)
|
||||
else:
|
||||
assert input_metadata.forward_mode != ForwardMode.DECODE
|
||||
last_index = (
|
||||
@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
|
||||
logits = logits[:, : self.config.vocab_size]
|
||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
||||
|
||||
normalized_logprobs = compute_normalized_logprobs(
|
||||
all_logprobs,
|
||||
input_ids,
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.extend_start_loc,
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = input_metadata.extend_start_loc.clone()
|
||||
end = start + input_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
normalized_logprobs = sum_logp / (
|
||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
)
|
||||
|
||||
last_logits = logits[last_index]
|
||||
return last_logits, normalized_logprobs
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = start_loc.clone()
|
||||
end = start + seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
return sum_logp / ((seq_lens - 1).clamp(min=1))
|
||||
return last_logits, (logprobs, normalized_logprobs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user