[2/2] Support deterministic inference for temperature > 0 (#10678)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-09-21 19:36:08 -07:00
committed by GitHub
parent 86527a4799
commit e2ac7888b8
12 changed files with 117 additions and 11 deletions

View File

@@ -2049,7 +2049,6 @@ class ModelRunner:
)
self._preprocess_logits(logits_output, forward_batch.sampling_info)
# Sample the next tokens
next_token_ids = self.sampler(
logits_output,
@@ -2057,6 +2056,12 @@ class ModelRunner:
forward_batch.return_logprob,
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
# For prefill, we only use the position of the last token.
(
forward_batch.positions
if forward_batch.forward_mode.is_decode()
else forward_batch.seq_lens - 1
),
)
return next_token_ids