Use int64 as indices for set_kv_buffer (#3039)

This commit is contained in:
Lianmin Zheng
2025-01-21 19:46:09 -08:00
committed by GitHub
parent a42213dbd4
commit 3d8f1c9bcf
6 changed files with 30 additions and 37 deletions

View File

@@ -550,13 +550,13 @@ class ScheduleBatch:
next_batch_sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None
input_embeds: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
input_ids: torch.Tensor = None # shape: [b], int32
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
req_pool_indices: torch.Tensor = None # shape: [b], int32
seq_lens: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None
output_ids: torch.Tensor = None
out_cache_loc: torch.Tensor = None # shape: [b], int32
output_ids: torch.Tensor = None # shape: [b], int32
# The sum of all sequence lengths
seq_lens_sum: int = None
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
self.sampling_info = SamplingBatchInfo.from_schedule_batch(