From 392f2863c8da8697fd7fb6f72222a9c82f198ed6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 18 Oct 2024 12:18:15 -0700 Subject: [PATCH] Add dtype for more operations (#1705) --- python/sglang/srt/managers/schedule_batch.py | 4 ++-- python/sglang/srt/model_executor/forward_batch_info.py | 3 ++- python/sglang/srt/sampling/sampling_batch_info.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f64b0699a..6cd5127bd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -537,8 +537,8 @@ class ScheduleBatch: # Set fields with out_cache_loc.device: self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) - self.req_pool_indices = torch.tensor(req_pool_indices) - self.seq_lens = torch.tensor(seq_lens) + self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fb7109464..555e3db95 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -145,8 +145,9 @@ class ForwardBatch: ], axis=0, ), + dtype=torch.int64, device=device, - ).to(torch.int64) + ) ret.image_inputs = batch.image_inputs ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 779af5101..cc4229ff5 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -57,7 +57,7 @@ class SamplingBatchInfo: [r.sampling_params.top_p for r in reqs], dtype=torch.float ) top_ks = torch.tensor( - [r.sampling_params.top_k for r in reqs], dtype=torch.int + [r.sampling_params.top_k for r in reqs], dtype=torch.int32 ) min_ps = torch.tensor( [r.sampling_params.min_p for r in reqs], dtype=torch.float