Add an option to disable penalizer (#1651)

This commit is contained in:
Lianmin Zheng
2024-10-12 17:53:23 -07:00
committed by GitHub
parent 69aa937aa5
commit 9da5a60b18
5 changed files with 111 additions and 90 deletions

View File

@@ -671,9 +671,10 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids = result
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
if logits_output:
# Move logprobs to cpu
@@ -755,9 +756,10 @@ class Scheduler:
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu