From 8f790ac1005cfb5403a0a1e847bb0e050a4282da Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 14 Aug 2024 03:25:38 -0700
Subject: [PATCH] Fix a bug in cuda graph runner (#1094)

---
 python/sglang/srt/model_executor/cuda_graph_runner.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 9bfd4a646..a74e8eef7 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -98,8 +98,8 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros(
+        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
         self.out_cache_loc = torch.zeros(
@@ -201,7 +201,7 @@ class CudaGraphRunner:
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                positions=(seq_lens - 1).to(torch.int64),
+                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
                 flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
 
@@ -225,8 +225,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
-            self.position_ids_offsets.zero_()
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs