Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -931,14 +931,14 @@ class Scheduler:
# Check finish conditions
logprob_pt = 0
for i, req in enumerate(batch.reqs):
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
if req.is_being_chunked <= 0:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
@@ -947,7 +947,7 @@ class Scheduler:
self.tree_cache.cache_unfinished_req(req)
if req.grammar is not None:
req.grammar.accept_token(next_token_ids[i])
req.grammar.accept_token(next_token_id)
if req.return_logprob:
logprob_pt += self.add_logprob_return_values(