Memory pool: Minor optimize to avoid to (#2901)
This commit is contained in:
@@ -668,7 +668,7 @@ class ScheduleBatch:
|
||||
or len(req.prefix_indices) >= im.num_image_tokens
|
||||
)
|
||||
|
||||
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
|
||||
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
@@ -702,7 +702,7 @@ class ScheduleBatch:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
@@ -778,10 +778,10 @@ class ScheduleBatch:
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.input_embeds = (
|
||||
@@ -1014,9 +1014,9 @@ class ScheduleBatch:
|
||||
def prepare_for_idle(self):
|
||||
self.forward_mode = ForwardMode.IDLE
|
||||
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
@@ -1084,7 +1084,7 @@ class ScheduleBatch:
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
|
||||
Reference in New Issue
Block a user