Make input_ids a torch.Tensor (#1568)

This commit is contained in:
Lianmin Zheng
2024-10-04 01:09:59 -07:00
committed by GitHub
parent 114bbc8651
commit 45473d4b2b
3 changed files with 11 additions and 7 deletions

View File

@@ -123,7 +123,7 @@ class ForwardBatch:
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
input_ids=batch.input_ids,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,