Use int64 as indices for set_kv_buffer (#3039)
This commit is contained in:
@@ -550,13 +550,13 @@ class ScheduleBatch:
|
||||
next_batch_sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
input_embeds: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
input_ids: torch.Tensor = None # shape: [b], int32
|
||||
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int32
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
out_cache_loc: torch.Tensor = None
|
||||
output_ids: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int32
|
||||
output_ids: torch.Tensor = None # shape: [b], int32
|
||||
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int = None
|
||||
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
|
||||
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
|
||||
Reference in New Issue
Block a user