Simplify the usage of device (#1734)

This commit is contained in:
Lianmin Zheng
2024-10-20 18:17:41 -07:00
committed by GitHub
parent 554fbf93cd
commit e12358dc91
3 changed files with 29 additions and 23 deletions

View File

@@ -51,7 +51,7 @@ class SamplingBatchInfo:
disable_penalizer: bool,
):
reqs = batch.reqs
device = batch.input_ids.device
device = batch.device
temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device=batch.input_ids.device,
device=batch.device,
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,