[Minor] Improve the style and fix flaky tests (#1584)
This commit is contained in:
@@ -40,7 +40,7 @@ class SamplingBatchInfo:
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
with torch.device("cuda"):
|
||||
with batch.input_ids.device:
|
||||
temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
|
||||
Reference in New Issue
Block a user