Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
prefill_token_logprobs=torch.ones_like(input_ids),
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=torch.ones_like(input_ids),
input_top_logprobs=None,
output_top_logprobs=None,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):