Add dtype for more operations (#1705)
This commit is contained in:
@@ -537,8 +537,8 @@ class ScheduleBatch:
|
||||
# Set fields
|
||||
with out_cache_loc.device:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices)
|
||||
self.seq_lens = torch.tensor(seq_lens)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
@@ -145,8 +145,9 @@ class ForwardBatch:
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
).to(torch.int64)
|
||||
)
|
||||
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
||||
|
||||
@@ -57,7 +57,7 @@ class SamplingBatchInfo:
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
|
||||
Reference in New Issue
Block a user