Format code & move functions (#155)

This commit is contained in:
Lianmin Zheng
2024-02-06 13:27:46 -08:00
committed by GitHub
parent a7334aeea1
commit 23f05005fd
14 changed files with 94 additions and 54 deletions

View File

@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
logprobs_cumsum = torch.cumsum(
prefill_logprobs, dim=0, dtype=torch.float32
)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_logprobs[start]
)
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)