Optimize the memory usage of logits processor (#420)

commit 09deb20dee (parent 33b242df30)
Author: Lianmin Zheng
Date: 2024-05-11 16:56:42 -07:00
Committed by: GitHub
2 changed files with 4 additions and 2 deletions


@@ -98,7 +98,9 @@ class LogitsProcessor(nn.Module):
         all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size]
-        all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
+        all_logprobs = all_logits.float()
+        all_logits = None
+        all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
         prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
             all_logprobs, input_metadata
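
The change removes the chained torch.log(torch.softmax(...) + 1e-6) call, which allocated a float32 copy of the logits plus separate softmax and log outputs while the original all_logits tensor stayed alive. The new code makes one float32 copy, drops the reference to all_logits so its buffer can be freed, and overwrites the copy in place with log_softmax, which is also numerically stabler than adding a 1e-6 epsilon before the log. Below is a minimal standalone sketch of the before/after pattern; the tensor shape and function names are illustrative assumptions, not part of the commit:

import torch

def logprobs_old(logits: torch.Tensor) -> torch.Tensor:
    # Old path: allocates a float32 copy, a softmax output, and a log
    # output, while the caller's logits buffer stays referenced.
    return torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

def logprobs_new(logits: torch.Tensor) -> torch.Tensor:
    # New path: one float32 buffer, rewritten in place.
    logprobs = logits.float()  # copies, since the input is half precision
    logits = None  # drop the local reference so the input buffer can be freed
    logprobs[:] = torch.nn.functional.log_softmax(logprobs, dim=-1)
    return logprobs

if __name__ == "__main__":
    # Hypothetical size: 8 tokens over a 32k-entry vocabulary.
    logits = torch.randn(8, 32_000, dtype=torch.float16)
    old = logprobs_old(logits.clone())
    new = logprobs_new(logits.clone())
    # The outputs differ only where the old formula's 1e-6 epsilon
    # clamps near-zero probabilities.
    print((old - new).abs().max())

Note the slice assignment: log_softmax still allocates its result tensor, but copying it back into all_logprobs lets that temporary be freed right after the statement, so the peak is roughly two full-vocab float32 buffers instead of three, and the half-precision logits are released earlier.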