diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9bfd4a646..a74e8eef7 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -98,8 +98,8 @@ class CudaGraphRunner: self.req_pool_indices = torch.zeros( (self.max_bs,), dtype=torch.int32, device="cuda" ) - self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") - self.position_ids_offsets = torch.zeros( + self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.position_ids_offsets = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" ) self.out_cache_loc = torch.zeros( @@ -201,7 +201,7 @@ class CudaGraphRunner: out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, - positions=(seq_lens - 1).to(torch.int64), + positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) @@ -225,8 +225,8 @@ class CudaGraphRunner: index = bisect.bisect_left(self.batch_size_list, raw_bs) bs = self.batch_size_list[index] if bs != raw_bs: - self.seq_lens.fill_(1) - self.position_ids_offsets.zero_() + self.seq_lens.zero_() + self.position_ids_offsets.fill_(1) self.out_cache_loc.zero_() # Common inputs