Format code & move functions (#155)
This commit is contained in:
@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
|
||||
logprobs_cumsum = torch.cumsum(
|
||||
prefill_logprobs, dim=0, dtype=torch.float32
|
||||
)
|
||||
|
||||
start = input_metadata.extend_start_loc.clone()
|
||||
end = start + input_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
|
||||
sum_logp = (
|
||||
logprobs_cumsum[end]
|
||||
- logprobs_cumsum[start]
|
||||
+ prefill_logprobs[start]
|
||||
)
|
||||
normalized_logprobs = sum_logp / (
|
||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user