From 490a1f39dd54115b56e3c587b457cca49e0a9bfc Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 20 Jul 2024 02:43:55 -0700
Subject: [PATCH] Fix cuda graph with flashinfer (#675)

---
 benchmark/gsm8k/bench_sglang.py                            | 2 +-
 python/sglang/srt/managers/controller/cuda_graph_runner.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py
index d1ed22cbe..298ec11d7 100644
--- a/benchmark/gsm8k/bench_sglang.py
+++ b/benchmark/gsm8k/bench_sglang.py
@@ -64,7 +64,7 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=256, stop="Question")
+        s += sgl.gen("answer", max_tokens=512, stop="Question")
 
     #####################################
     ########## SGL Program End ##########
diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
index b37a82729..2a9a0af6d 100644
--- a/python/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -150,8 +150,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
-            self.position_ids_offsets.fill_(1)
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()
 
         # Common inputs