Simplify the usage of device (#1734)
This commit is contained in:
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
|
||||
disable_penalizer: bool,
|
||||
):
|
||||
reqs = batch.reqs
|
||||
device = batch.input_ids.device
|
||||
device = batch.device
|
||||
temperatures = (
|
||||
torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.input_ids.device,
|
||||
device=batch.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
|
||||
Reference in New Issue
Block a user