Simplify logits penalizer (#2086)
This commit is contained in:
@@ -27,10 +27,10 @@ class SamplingBatchInfo:
|
||||
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
grammars: Optional[List] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
return SamplingBatchInfo(
|
||||
temperatures=self.temperatures,
|
||||
top_ps=self.top_ps,
|
||||
top_ks=self.top_ks,
|
||||
min_ps=self.min_ps,
|
||||
is_all_greedy=self.is_all_greedy,
|
||||
need_min_p_sampling=self.need_min_p_sampling,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def to(self, device: str):
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
value = getattr(self, item)
|
||||
setattr(self, item, value.to(device, non_blocking=True))
|
||||
|
||||
Reference in New Issue
Block a user