Make input_ids a torch.Tensor (#1568)
This commit is contained in:
@@ -123,7 +123,7 @@ class ForwardBatch:
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=len(batch.seq_lens),
|
||||
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
|
||||
input_ids=batch.input_ids,
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
|
||||
Reference in New Issue
Block a user