Add an option to disable penalizer (#1651)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -33,15 +33,20 @@ class SamplingBatchInfo:
|
||||
regex_fsm_states: List[int] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
linear_penalties: torch.Tensor = None
|
||||
scaling_penalties: torch.Tensor = None
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
linear_penalties: Optional[torch.Tensor] = None
|
||||
scaling_penalties: Optional[torch.Tensor] = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
def from_schedule_batch(
|
||||
cls,
|
||||
batch: ScheduleBatch,
|
||||
vocab_size: int,
|
||||
disable_penalizer: bool,
|
||||
):
|
||||
reqs = batch.reqs
|
||||
with batch.input_ids.device:
|
||||
temperatures = torch.tensor(
|
||||
@@ -76,17 +81,20 @@ class SamplingBatchInfo:
|
||||
# While we could choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially since we need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.input_ids.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
},
|
||||
)
|
||||
if disable_penalizer:
|
||||
ret.penalizer_orchestrator = None
|
||||
else:
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.input_ids.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
ret.logit_bias = None
|
||||
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
|
||||
return len(self.temperatures)
|
||||
|
||||
def update_penalties(self):
|
||||
if not self.penalizer_orchestrator:
|
||||
return
|
||||
|
||||
self.scaling_penalties = None
|
||||
self.linear_penalties = None
|
||||
|
||||
@@ -117,26 +128,26 @@ class SamplingBatchInfo:
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
||||
if not has_regex:
|
||||
self.vocab_mask = None
|
||||
return
|
||||
|
||||
# Reset the vocab mask
|
||||
self.vocab_mask = None
|
||||
|
||||
if has_regex:
|
||||
self.vocab_mask = torch.zeros(
|
||||
len(self.temperatures),
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
self.vocab_mask[i].fill_(1)
|
||||
self.vocab_mask[i][
|
||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
||||
] = 0
|
||||
self.vocab_mask = torch.zeros(
|
||||
len(self.temperatures),
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
self.vocab_mask[i].fill_(1)
|
||||
self.vocab_mask[i][
|
||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
||||
] = 0
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
if self.penalizer_orchestrator:
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
@@ -175,7 +186,8 @@ class SamplingBatchInfo:
|
||||
return None
|
||||
|
||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
if self.penalizer_orchestrator:
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
|
||||
Reference in New Issue
Block a user