Add dtype for more operations (#1705)
This commit is contained in:
@@ -537,8 +537,8 @@ class ScheduleBatch:
|
||||
# Set fields
|
||||
with out_cache_loc.device:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices)
|
||||
self.seq_lens = torch.tensor(seq_lens)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
Reference in New Issue
Block a user