Optimize the memory usage of logits processor (#420)
This commit is contained in:
@@ -98,7 +98,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
all_logits = all_logits[:, : self.config.vocab_size]
|
all_logits = all_logits[:, : self.config.vocab_size]
|
||||||
|
|
||||||
all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
|
all_logprobs = all_logits.float()
|
||||||
|
all_logits = None
|
||||||
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||||
all_logprobs, input_metadata
|
all_logprobs, input_metadata
|
||||||
|
|||||||
@@ -589,7 +589,7 @@ class ModelRpcServer:
|
|||||||
+ len(req.output_ids)
|
+ len(req.output_ids)
|
||||||
- req.prompt_tokens,
|
- req.prompt_tokens,
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
"finish_reason": req.finish_reason,
|
"finish_reason": str(req.finish_reason),
|
||||||
"hit_stop_str": req.hit_stop_str,
|
"hit_stop_str": req.hit_stop_str,
|
||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
|
|||||||
Reference in New Issue
Block a user