[2/2] Support deterministic inference for temperature > 0 (#10678)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
@@ -2049,7 +2049,6 @@ class ModelRunner:
         )
-
         self._preprocess_logits(logits_output, forward_batch.sampling_info)

         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -2057,6 +2056,12 @@ class ModelRunner:
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids
Reference in New Issue
Block a user