Add dtype for more operations (#1705)

This commit is contained in:
Lianmin Zheng
2024-10-18 12:18:15 -07:00
committed by GitHub
parent 6d0fa73ece
commit 392f2863c8
3 changed files with 5 additions and 4 deletions

View File

@@ -537,8 +537,8 @@ class ScheduleBatch:
# Set fields
with out_cache_loc.device:
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices)
self.seq_lens = torch.tensor(seq_lens)
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc