Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -27,10 +27,10 @@ class SamplingBatchInfo:
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
grammars: Optional[List] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
def copy(self):
return SamplingBatchInfo(
temperatures=self.temperatures,
top_ps=self.top_ps,
top_ks=self.top_ks,
min_ps=self.min_ps,
is_all_greedy=self.is_all_greedy,
need_min_p_sampling=self.need_min_p_sampling,
vocab_size=self.vocab_size,
device=self.device,
)
def to(self, device: str):
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
value = getattr(self, item)
setattr(self, item, value.to(device, non_blocking=True))