Support decode token logprobs (#130)

This commit is contained in:
Cody Yu
2024-02-06 12:24:55 -08:00
committed by GitHub
parent ee1df26a77
commit a7334aeea1
10 changed files with 233 additions and 96 deletions

View File

@@ -14,28 +14,11 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata):
if not input_metadata.return_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
last_hidden = hidden_states[last_index]
hidden_states = None
last_index = None
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, (None, None)
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
# Compute the last index (the first decode token) of each requeast
# if we are in prefill or extend mode.
if input_metadata.forward_mode != ForwardMode.DECODE:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
@@ -45,29 +28,54 @@ class LogitsProcessor(nn.Module):
- 1
)
if not input_metadata.return_logprob:
# When logprob is not requested, only compute the last logits.
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_hidden = hidden_states[last_index]
hidden_states = None
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, (None, None, None)
else:
# When logprob is requested, compute the logits for all tokens.
logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
if input_metadata.forward_mode == ForwardMode.DECODE:
last_logits = logits
last_logprobs = all_logprobs
prefill_logprobs = normalized_logprobs = None
else:
# Compute the logprobs for the last token of each request.
last_logits = logits[last_index]
last_logprobs = all_logprobs[last_index]
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
last_logits = logits[last_index]
return last_logits, (logprobs, normalized_logprobs)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
if __name__ == "__main__":