From 2bd18e2d767e3a0f8afb5aff427bc8e6e4d297c0 Mon Sep 17 00:00:00 2001
From: Yang Zheng <50227060+zhengy001@users.noreply.github.com>
Date: Sun, 19 Jan 2025 11:35:12 +0800
Subject: [PATCH] Memory pool: Minor optimize to avoid to (#2901)

---
 python/sglang/srt/managers/schedule_batch.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index faf05a7ff..77e5faca4 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -668,7 +668,7 @@ class ScheduleBatch:
                     or len(req.prefix_indices) >= im.num_image_tokens
                 )
 
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -702,7 +702,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -778,10 +778,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1014,9 +1014,9 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1084,7 +1084,7 @@ class ScheduleBatch:
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
        )
         self.req_pool_indices = self.req_pool_indices[new_indices]