Add dtype for more operations (#1705)

This commit is contained in:
Lianmin Zheng
2024-10-18 12:18:15 -07:00
committed by GitHub
parent 6d0fa73ece
commit 392f2863c8
3 changed files with 5 additions and 4 deletions

View File

@@ -145,8 +145,9 @@ class ForwardBatch:
],
axis=0,
),
dtype=torch.int64,
device=device,
).to(torch.int64)
)
ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)