Add an option to disable penalizer (#1651)
This commit is contained in:
@@ -531,7 +531,9 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self, vocab_size, global_server_args_dict["disable_penalizer"]
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
self.forward_mode = ForwardMode.MIXED
|
||||
|
||||
@@ -671,9 +671,10 @@ class Scheduler:
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids = result
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
)
|
||||
if batch.sampling_info.penalizer_orchestrator:
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
)
|
||||
|
||||
if logits_output:
|
||||
# Move logprobs to cpu
|
||||
@@ -755,9 +756,10 @@ class Scheduler:
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids = result
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
)
|
||||
if batch.sampling_info.penalizer_orchestrator:
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
)
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
# Move logprobs to cpu
|
||||
|
||||
Reference in New Issue
Block a user